summaryrefslogtreecommitdiffstats
path: root/gallery_dl/downloader/common.py
diff options
context:
space:
mode:
Diffstat (limited to 'gallery_dl/downloader/common.py')
-rw-r--r--gallery_dl/downloader/common.py170
1 files changed, 170 insertions, 0 deletions
diff --git a/gallery_dl/downloader/common.py b/gallery_dl/downloader/common.py
new file mode 100644
index 0000000..4803c85
--- /dev/null
+++ b/gallery_dl/downloader/common.py
@@ -0,0 +1,170 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2014-2018 Mike Fährmann
+#
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License version 2 as
+# published by the Free Software Foundation.
+
+"""Common classes and constants used by downloader modules."""
+
+import os
+import time
+import logging
+from .. import config, util, exception
+from requests.exceptions import RequestException
+from ssl import SSLError
+
+
+class DownloaderBase():
+ """Base class for downloaders"""
+ scheme = ""
+ retries = 1
+
+ def __init__(self, extractor, output):
+ self.session = extractor.session
+ self.out = output
+ self.log = logging.getLogger("downloader." + self.scheme)
+ self.downloading = False
+ self.part = self.config("part", True)
+ self.partdir = self.config("part-directory")
+
+ if self.partdir:
+ self.partdir = util.expand_path(self.partdir)
+ os.makedirs(self.partdir, exist_ok=True)
+
+ def config(self, key, default=None):
+ """Interpolate config value for 'key'"""
+ return config.interpolate(("downloader", self.scheme, key), default)
+
+ def download(self, url, pathfmt):
+ """Download the resource at 'url' and write it to a file-like object"""
+ try:
+ return self.download_impl(url, pathfmt)
+ except Exception:
+ print()
+ raise
+ finally:
+ # remove file from incomplete downloads
+ if self.downloading and not self.part:
+ try:
+ os.remove(pathfmt.temppath)
+ except (OSError, AttributeError):
+ pass
+
+ def download_impl(self, url, pathfmt):
+ """Actual implementaion of the download process"""
+ adj_ext = None
+ tries = 0
+ msg = ""
+
+ if self.part:
+ pathfmt.part_enable(self.partdir)
+
+ while True:
+ self.reset()
+ if tries:
+ self.log.warning("%s (%d/%d)", msg, tries, self.retries)
+ if tries >= self.retries:
+ return False
+ time.sleep(tries)
+ tries += 1
+
+ # check for .part file
+ filesize = pathfmt.part_size()
+
+ # connect to (remote) source
+ try:
+ offset, size = self.connect(url, filesize)
+ except exception.DownloadRetry as exc:
+ msg = exc
+ continue
+ except exception.DownloadComplete:
+ break
+ except Exception as exc:
+ self.log.warning(exc)
+ return False
+
+ # check response
+ if not offset:
+ mode = "w+b"
+ if filesize:
+ self.log.info("Unable to resume partial download")
+ else:
+ mode = "r+b"
+ self.log.info("Resuming download at byte %d", offset)
+
+ # set missing filename extension
+ if not pathfmt.has_extension:
+ pathfmt.set_extension(self.get_extension())
+ if pathfmt.exists():
+ pathfmt.temppath = ""
+ return True
+
+ self.out.start(pathfmt.path)
+ self.downloading = True
+ with pathfmt.open(mode) as file:
+ if offset:
+ file.seek(offset)
+
+ # download content
+ try:
+ self.receive(file)
+ except (RequestException, SSLError) as exc:
+ msg = exc
+ print()
+ continue
+
+ # check filesize
+ if size and file.tell() < size:
+ msg = "filesize mismatch ({} < {})".format(
+ file.tell(), size)
+ continue
+
+ # check filename extension
+ adj_ext = self._check_extension(file, pathfmt)
+
+ break
+
+ self.downloading = False
+ if adj_ext:
+ pathfmt.set_extension(adj_ext)
+ return True
+
+ def connect(self, url, offset):
+ """Connect to 'url' while respecting 'offset' if possible
+
+ Returns a 2-tuple containing the actual offset and expected filesize.
+ If the returned offset-value is greater than zero, all received data
+ will be appended to the existing .part file.
+ Return '0' as second tuple-field to indicate an unknown filesize.
+ """
+
+ def receive(self, file):
+ """Write data to 'file'"""
+
+ def reset(self):
+ """Reset internal state / cleanup"""
+
+ def get_extension(self):
+ """Return a filename extension appropriate for the current request"""
+
+ @staticmethod
+ def _check_extension(file, pathfmt):
+ """Check filename extension against fileheader"""
+ extension = pathfmt.keywords["extension"]
+ if extension in FILETYPE_CHECK:
+ file.seek(0)
+ header = file.read(8)
+ if len(header) >= 8 and not FILETYPE_CHECK[extension](header):
+ for ext, check in FILETYPE_CHECK.items():
+ if ext != extension and check(header):
+ return ext
+ return None
+
+
+FILETYPE_CHECK = {
+ "jpg": lambda h: h[0:2] == b"\xff\xd8",
+ "png": lambda h: h[0:8] == b"\x89\x50\x4e\x47\x0d\x0a\x1a\x0a",
+ "gif": lambda h: h[0:4] == b"GIF8" and h[5] == 97,
+}