aboutsummaryrefslogtreecommitdiffstats
path: root/gallery_dl/util.py
diff options
context:
space:
mode:
authorLibravatarUnit 193 <unit193@ubuntu.com>2019-07-02 04:33:45 -0400
committerLibravatarUnit 193 <unit193@ubuntu.com>2019-07-02 04:33:45 -0400
commit195c45911e79c33cf0bb986721365fb06df5a153 (patch)
treeac0c9b6ef40bea7aa7ab0c5c3cb500eb510668fa /gallery_dl/util.py
Import Upstream version 1.8.7upstream/1.8.7
Diffstat (limited to 'gallery_dl/util.py')
-rw-r--r--gallery_dl/util.py673
1 files changed, 673 insertions, 0 deletions
diff --git a/gallery_dl/util.py b/gallery_dl/util.py
new file mode 100644
index 0000000..5c0ae41
--- /dev/null
+++ b/gallery_dl/util.py
@@ -0,0 +1,673 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2017-2019 Mike Fährmann
+#
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License version 2 as
+# published by the Free Software Foundation.
+
+"""Utility functions and classes"""
+
+import re
+import os
+import sys
+import json
+import shutil
+import string
+import _string
+import sqlite3
+import datetime
+import operator
+import itertools
+import urllib.parse
+from . import text, exception
+
+
+def bencode(num, alphabet="0123456789"):
+ """Encode an integer into a base-N encoded string"""
+ data = ""
+ base = len(alphabet)
+ while num:
+ num, remainder = divmod(num, base)
+ data = alphabet[remainder] + data
+ return data
+
+
+def bdecode(data, alphabet="0123456789"):
+ """Decode a base-N encoded string ( N = len(alphabet) )"""
+ num = 0
+ base = len(alphabet)
+ for c in data:
+ num *= base
+ num += alphabet.index(c)
+ return num
+
+
+def advance(iterable, num):
+ """"Advance the iterable by 'num' steps"""
+ iterator = iter(iterable)
+ next(itertools.islice(iterator, num, num), None)
+ return iterator
+
+
+def raises(obj):
+ """Returns a function that raises 'obj' as exception"""
+ def wrap():
+ raise obj
+ return wrap
+
+
+def combine_dict(a, b):
+ """Recursively combine the contents of 'b' into 'a'"""
+ for key, value in b.items():
+ if key in a and isinstance(value, dict) and isinstance(a[key], dict):
+ combine_dict(a[key], value)
+ else:
+ a[key] = value
+ return a
+
+
+def transform_dict(a, func):
+ """Recursively apply 'func' to all values in 'a'"""
+ for key, value in a.items():
+ if isinstance(value, dict):
+ transform_dict(value, func)
+ else:
+ a[key] = func(value)
+
+
+def number_to_string(value, numbers=(int, float)):
+ """Convert numbers (int, float) to string; Return everything else as is."""
+ return str(value) if value.__class__ in numbers else value
+
+
+def to_string(value):
+ """str() with "better" defaults"""
+ if not value:
+ return ""
+ if value.__class__ is list:
+ try:
+ return ", ".join(value)
+ except Exception:
+ return ", ".join(map(str, value))
+ return str(value)
+
+
+def dump_json(obj, fp=sys.stdout, ensure_ascii=True, indent=4):
+ """Serialize 'obj' as JSON and write it to 'fp'"""
+ json.dump(
+ obj, fp,
+ ensure_ascii=ensure_ascii,
+ indent=indent,
+ default=str,
+ sort_keys=True,
+ )
+ fp.write("\n")
+
+
+def expand_path(path):
+ """Expand environment variables and tildes (~)"""
+ if not path:
+ return path
+ if not isinstance(path, str):
+ path = os.path.join(*path)
+ return os.path.expandvars(os.path.expanduser(path))
+
+
+def code_to_language(code, default=None):
+ """Map an ISO 639-1 language code to its actual name"""
+ return CODES.get((code or "").lower(), default)
+
+
+def language_to_code(lang, default=None):
+ """Map a language name to its ISO 639-1 code"""
+ if lang is None:
+ return default
+ lang = lang.capitalize()
+ for code, language in CODES.items():
+ if language == lang:
+ return code
+ return default
+
+
+CODES = {
+ "ar": "Arabic",
+ "bg": "Bulgarian",
+ "ca": "Catalan",
+ "cs": "Czech",
+ "da": "Danish",
+ "de": "German",
+ "el": "Greek",
+ "en": "English",
+ "es": "Spanish",
+ "fi": "Finnish",
+ "fr": "French",
+ "he": "Hebrew",
+ "hu": "Hungarian",
+ "id": "Indonesian",
+ "it": "Italian",
+ "jp": "Japanese",
+ "ko": "Korean",
+ "ms": "Malay",
+ "nl": "Dutch",
+ "no": "Norwegian",
+ "pl": "Polish",
+ "pt": "Portuguese",
+ "ro": "Romanian",
+ "ru": "Russian",
+ "sv": "Swedish",
+ "th": "Thai",
+ "tr": "Turkish",
+ "vi": "Vietnamese",
+ "zh": "Chinese",
+}
+
+SPECIAL_EXTRACTORS = {"oauth", "recursive", "test"}
+
+
+class UniversalNone():
+ """None-style object that supports more operations than None itself"""
+ __slots__ = ()
+
+ def __getattribute__(self, _):
+ return self
+
+ def __getitem__(self, _):
+ return self
+
+ @staticmethod
+ def __bool__():
+ return False
+
+ @staticmethod
+ def __str__():
+ return "None"
+
+ __repr__ = __str__
+
+
+NONE = UniversalNone()
+
+
+def build_predicate(predicates):
+ if not predicates:
+ return lambda url, kwds: True
+ elif len(predicates) == 1:
+ return predicates[0]
+ else:
+ return ChainPredicate(predicates)
+
+
+class RangePredicate():
+ """Predicate; True if the current index is in the given range"""
+ def __init__(self, rangespec):
+ self.ranges = self.optimize_range(self.parse_range(rangespec))
+ self.index = 0
+
+ if self.ranges:
+ self.lower, self.upper = self.ranges[0][0], self.ranges[-1][1]
+ else:
+ self.lower, self.upper = 0, 0
+
+ def __call__(self, url, kwds):
+ self.index += 1
+
+ if self.index > self.upper:
+ raise exception.StopExtraction()
+
+ for lower, upper in self.ranges:
+ if lower <= self.index <= upper:
+ return True
+ return False
+
+ @staticmethod
+ def parse_range(rangespec):
+ """Parse an integer range string and return the resulting ranges
+
+ Examples:
+ parse_range("-2,4,6-8,10-") -> [(1,2), (4,4), (6,8), (10,INTMAX)]
+ parse_range(" - 3 , 4- 4, 2-6") -> [(1,3), (4,4), (2,6)]
+ """
+ ranges = []
+
+ for group in rangespec.split(","):
+ if not group:
+ continue
+ first, sep, last = group.partition("-")
+ if not sep:
+ beg = end = int(first)
+ else:
+ beg = int(first) if first.strip() else 1
+ end = int(last) if last.strip() else sys.maxsize
+ ranges.append((beg, end) if beg <= end else (end, beg))
+
+ return ranges
+
+ @staticmethod
+ def optimize_range(ranges):
+ """Simplify/Combine a parsed list of ranges
+
+ Examples:
+ optimize_range([(2,4), (4,6), (5,8)]) -> [(2,8)]
+ optimize_range([(1,1), (2,2), (3,6), (8,9))]) -> [(1,6), (8,9)]
+ """
+ if len(ranges) <= 1:
+ return ranges
+
+ ranges.sort()
+ riter = iter(ranges)
+ result = []
+
+ beg, end = next(riter)
+ for lower, upper in riter:
+ if lower > end+1:
+ result.append((beg, end))
+ beg, end = lower, upper
+ elif upper > end:
+ end = upper
+ result.append((beg, end))
+ return result
+
+
+class UniquePredicate():
+ """Predicate; True if given URL has not been encountered before"""
+ def __init__(self):
+ self.urls = set()
+
+ def __call__(self, url, kwds):
+ if url.startswith("text:"):
+ return True
+ if url not in self.urls:
+ self.urls.add(url)
+ return True
+ return False
+
+
+class FilterPredicate():
+ """Predicate; True if evaluating the given expression returns True"""
+ globalsdict = {
+ "parse_int": text.parse_int,
+ "urlsplit": urllib.parse.urlsplit,
+ "datetime": datetime.datetime,
+ "abort": raises(exception.StopExtraction()),
+ "re": re,
+ }
+
+ def __init__(self, filterexpr, target="image"):
+ name = "<{} filter>".format(target)
+ self.codeobj = compile(filterexpr, name, "eval")
+
+ def __call__(self, url, kwds):
+ try:
+ return eval(self.codeobj, self.globalsdict, kwds)
+ except exception.GalleryDLException:
+ raise
+ except Exception as exc:
+ raise exception.FilterError(exc)
+
+
+class ChainPredicate():
+ """Predicate; True if all of its predicates return True"""
+ def __init__(self, predicates):
+ self.predicates = predicates
+
+ def __call__(self, url, kwds):
+ for pred in self.predicates:
+ if not pred(url, kwds):
+ return False
+ return True
+
+
+class ExtendedUrl():
+ """URL with attached config key-value pairs"""
+ def __init__(self, url, gconf, lconf):
+ self.value, self.gconfig, self.lconfig = url, gconf, lconf
+
+ def __str__(self):
+ return self.value
+
+
+class Formatter():
+ """Custom, extended version of string.Formatter
+
+ This string formatter implementation is a mostly performance-optimized
+ variant of the original string.Formatter class. Unnecessary features have
+ been removed (positional arguments, unused argument check) and new
+ formatting options have been added.
+
+ Extra Conversions:
+ - "l": calls str.lower on the target value
+ - "u": calls str.upper
+ - "c": calls str.capitalize
+ - "C": calls string.capwords
+ - "U": calls urllib.parse.unquote
+ - "S": calls util.to_string()
+ - Example: {f!l} -> "example"; {f!u} -> "EXAMPLE"
+
+ Extra Format Specifiers:
+ - "?<before>/<after>/":
+ Adds <before> and <after> to the actual value if it evaluates to True.
+ Otherwise the whole replacement field becomes an empty string.
+ Example: {f:?-+/+-/} -> "-+Example+-" (if "f" contains "Example")
+ -> "" (if "f" is None, 0, "")
+
+ - "L<maxlen>/<replacement>/":
+ Replaces the output with <replacement> if its length (in characters)
+ exceeds <maxlen>. Otherwise everything is left as is.
+ Example: {f:L5/too long/} -> "foo" (if "f" is "foo")
+ -> "too long" (if "f" is "foobar")
+
+ - "J<separator>/":
+ Joins elements of a list (or string) using <separator>
+ Example: {f:J - /} -> "a - b - c" (if "f" is ["a", "b", "c"])
+
+ - "R<old>/<new>/":
+ Replaces all occurrences of <old> with <new>
+ Example: {f:R /_/} -> "f_o_o_b_a_r" (if "f" is "f o o b a r")
+ """
+ CONVERSIONS = {
+ "l": str.lower,
+ "u": str.upper,
+ "c": str.capitalize,
+ "C": string.capwords,
+ "U": urllib.parse.unquote,
+ "S": to_string,
+ "s": str,
+ "r": repr,
+ "a": ascii,
+ }
+
+ def __init__(self, format_string, default=None):
+ self.default = default
+ self.result = []
+ self.fields = []
+
+ for literal_text, field_name, format_spec, conversion in \
+ _string.formatter_parser(format_string):
+ if literal_text:
+ self.result.append(literal_text)
+ if field_name:
+ self.fields.append((
+ len(self.result),
+ self._field_access(field_name, format_spec, conversion)
+ ))
+ self.result.append("")
+
+ def format_map(self, kwargs):
+ """Apply 'kwargs' to the initial format_string and return its result"""
+ for index, func in self.fields:
+ self.result[index] = func(kwargs)
+ return "".join(self.result)
+
+ def _field_access(self, field_name, format_spec, conversion):
+ first, rest = _string.formatter_field_name_split(field_name)
+
+ funcs = []
+ for is_attr, key in rest:
+ if is_attr:
+ func = operator.attrgetter
+ elif ":" in key:
+ func = self._slicegetter
+ else:
+ func = operator.itemgetter
+ funcs.append(func(key))
+
+ if conversion:
+ funcs.append(self.CONVERSIONS[conversion])
+
+ if format_spec:
+ if format_spec[0] == "?":
+ func = self._format_optional
+ elif format_spec[0] == "L":
+ func = self._format_maxlen
+ elif format_spec[0] == "J":
+ func = self._format_join
+ elif format_spec[0] == "R":
+ func = self._format_replace
+ else:
+ func = self._format_default
+ fmt = func(format_spec)
+ else:
+ fmt = str
+
+ if funcs:
+ return self._apply(first, funcs, fmt)
+ return self._apply_simple(first, fmt)
+
+ def _apply_simple(self, key, fmt):
+ def wrap(obj):
+ if key in obj:
+ obj = obj[key]
+ else:
+ obj = self.default
+ return fmt(obj)
+ return wrap
+
+ def _apply(self, key, funcs, fmt):
+ def wrap(obj):
+ try:
+ obj = obj[key]
+ for func in funcs:
+ obj = func(obj)
+ except Exception:
+ obj = self.default
+ return fmt(obj)
+ return wrap
+
+ @staticmethod
+ def _slicegetter(key):
+ start, _, stop = key.partition(":")
+ stop, _, step = stop.partition(":")
+ start = int(start) if start else None
+ stop = int(stop) if stop else None
+ step = int(step) if step else None
+ return operator.itemgetter(slice(start, stop, step))
+
+ @staticmethod
+ def _format_optional(format_spec):
+ def wrap(obj):
+ if not obj:
+ return ""
+ return before + format(obj, format_spec) + after
+ before, after, format_spec = format_spec.split("/", 2)
+ before = before[1:]
+ return wrap
+
+ @staticmethod
+ def _format_maxlen(format_spec):
+ def wrap(obj):
+ obj = format(obj, format_spec)
+ return obj if len(obj) <= maxlen else replacement
+ maxlen, replacement, format_spec = format_spec.split("/", 2)
+ maxlen = text.parse_int(maxlen[1:])
+ return wrap
+
+ @staticmethod
+ def _format_join(format_spec):
+ def wrap(obj):
+ obj = separator.join(obj)
+ return format(obj, format_spec)
+ separator, _, format_spec = format_spec.partition("/")
+ separator = separator[1:]
+ return wrap
+
+ @staticmethod
+ def _format_replace(format_spec):
+ def wrap(obj):
+ obj = obj.replace(old, new)
+ return format(obj, format_spec)
+ old, new, format_spec = format_spec.split("/", 2)
+ old = old[1:]
+ return wrap
+
+ @staticmethod
+ def _format_default(format_spec):
+ def wrap(obj):
+ return format(obj, format_spec)
+ return wrap
+
+
+class PathFormat():
+
+ def __init__(self, extractor):
+ self.filename_fmt = extractor.config(
+ "filename", extractor.filename_fmt)
+ self.directory_fmt = extractor.config(
+ "directory", extractor.directory_fmt)
+ self.kwdefault = extractor.config("keywords-default")
+
+ try:
+ self.formatter = Formatter(self.filename_fmt, self.kwdefault)
+ except Exception as exc:
+ raise exception.FormatError(exc, "filename")
+
+ self.delete = False
+ self.has_extension = False
+ self.keywords = {}
+ self.filename = ""
+ self.directory = self.realdirectory = ""
+ self.path = self.realpath = self.temppath = ""
+
+ self.basedirectory = expand_path(
+ extractor.config("base-directory", (".", "gallery-dl")))
+ if os.altsep:
+ self.basedirectory = self.basedirectory.replace(os.altsep, os.sep)
+
+ def open(self, mode="wb"):
+ """Open file and return a corresponding file object"""
+ return open(self.temppath, mode)
+
+ def exists(self, archive=None):
+ """Return True if the file exists on disk or in 'archive'"""
+ if (archive and archive.check(self.keywords) or
+ self.has_extension and os.path.exists(self.realpath)):
+ if not self.has_extension:
+ # adjust display name
+ self.set_extension("")
+ if self.path[-1] == ".":
+ self.path = self.path[:-1]
+ return True
+ return False
+
+ def set_directory(self, keywords):
+ """Build directory path and create it if necessary"""
+ try:
+ segments = [
+ text.clean_path(
+ Formatter(segment, self.kwdefault)
+ .format_map(keywords).strip())
+ for segment in self.directory_fmt
+ ]
+ except Exception as exc:
+ raise exception.FormatError(exc, "directory")
+
+ self.directory = os.path.join(
+ self.basedirectory,
+ *segments
+ )
+
+ # remove trailing path separator;
+ # occurs if the last argument to os.path.join() is an empty string
+ if self.directory[-1] == os.sep:
+ self.directory = self.directory[:-1]
+
+ self.realdirectory = self.adjust_path(self.directory)
+ os.makedirs(self.realdirectory, exist_ok=True)
+
+ def set_keywords(self, keywords):
+ """Set filename keywords"""
+ self.keywords = keywords
+ self.temppath = ""
+ self.has_extension = bool(keywords.get("extension"))
+ if self.has_extension:
+ self.build_path()
+
+ def set_extension(self, extension, real=True):
+ """Set the 'extension' keyword"""
+ self.has_extension = real
+ self.keywords["extension"] = extension
+ self.build_path()
+
+ def build_path(self):
+ """Use filename-keywords and directory to build a full path"""
+ try:
+ self.filename = text.clean_path(
+ self.formatter.format_map(self.keywords))
+ except Exception as exc:
+ raise exception.FormatError(exc, "filename")
+
+ filename = os.sep + self.filename
+ self.path = self.directory + filename
+ self.realpath = self.realdirectory + filename
+ if not self.temppath:
+ self.temppath = self.realpath
+
+ def part_enable(self, part_directory=None):
+ """Enable .part file usage"""
+ if self.has_extension:
+ self.temppath += ".part"
+ else:
+ self.set_extension("part", False)
+ if part_directory:
+ self.temppath = os.path.join(
+ part_directory,
+ os.path.basename(self.temppath),
+ )
+
+ def part_size(self):
+ """Return size of .part file"""
+ try:
+ return os.stat(self.temppath).st_size
+ except OSError:
+ pass
+ return 0
+
+ def finalize(self):
+ """Move tempfile to its target location"""
+ if self.delete:
+ self.delete = False
+ os.unlink(self.temppath)
+ return
+
+ if self.temppath == self.realpath:
+ return
+
+ try:
+ os.replace(self.temppath, self.realpath)
+ return
+ except OSError:
+ pass
+
+ shutil.copyfile(self.temppath, self.realpath)
+ os.unlink(self.temppath)
+
+ @staticmethod
+ def adjust_path(path):
+ """Enable longer-than-260-character paths on windows"""
+ return "\\\\?\\" + os.path.abspath(path) if os.name == "nt" else path
+
+
+class DownloadArchive():
+
+ def __init__(self, path, extractor):
+ con = sqlite3.connect(path)
+ con.isolation_level = None
+ self.cursor = con.cursor()
+ self.cursor.execute("CREATE TABLE IF NOT EXISTS archive "
+ "(entry PRIMARY KEY) WITHOUT ROWID")
+ self.keygen = (extractor.category + extractor.config(
+ "archive-format", extractor.archive_fmt)
+ ).format_map
+
+ def check(self, kwdict):
+ """Return True if item described by 'kwdict' exists in archive"""
+ key = self.keygen(kwdict)
+ self.cursor.execute(
+ "SELECT 1 FROM archive WHERE entry=? LIMIT 1", (key,))
+ return self.cursor.fetchone()
+
+ def add(self, kwdict):
+ """Add item described by 'kwdict' to archive"""
+ key = self.keygen(kwdict)
+ self.cursor.execute(
+ "INSERT OR IGNORE INTO archive VALUES (?)", (key,))