aboutsummaryrefslogtreecommitdiffstats
path: root/test/test_downloader.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_downloader.py')
-rw-r--r--test/test_downloader.py54
1 files changed, 20 insertions, 34 deletions
diff --git a/test/test_downloader.py b/test/test_downloader.py
index 9393040..5d73a4c 100644
--- a/test/test_downloader.py
+++ b/test/test_downloader.py
@@ -14,21 +14,30 @@ from unittest.mock import Mock, MagicMock, patch
import re
import base64
+import logging
import os.path
import tempfile
import threading
import http.server
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-from gallery_dl import downloader, extractor, config, util # noqa E402
-from gallery_dl.downloader.common import DownloaderBase # noqa E402
-from gallery_dl.output import NullOutput # noqa E402
+from gallery_dl import downloader, extractor, output, config, util # noqa E402
class MockDownloaderModule(Mock):
__downloader__ = "mock"
+class FakeJob():
+
+ def __init__(self):
+ self.extractor = extractor.find("test:")
+ self.pathfmt = util.PathFormat(self.extractor)
+ self.out = output.NullOutput()
+ self.get_logger = logging.getLogger
+
+
class TestDownloaderModule(unittest.TestCase):
@classmethod
@@ -96,11 +105,10 @@ class TestDownloaderBase(unittest.TestCase):
@classmethod
def setUpClass(cls):
- cls.extractor = extractor.find("test:")
- cls.extractor.log.job = None
cls.dir = tempfile.TemporaryDirectory()
cls.fnum = 0
config.set((), "base-directory", cls.dir.name)
+ cls.job = FakeJob()
@classmethod
def tearDownClass(cls):
@@ -113,12 +121,13 @@ class TestDownloaderBase(unittest.TestCase):
cls.fnum += 1
kwdict = {
- "category": "test",
+ "category" : "test",
"subcategory": "test",
- "filename": name,
- "extension": extension,
+ "filename" : name,
+ "extension" : extension,
}
- pathfmt = util.PathFormat(cls.extractor)
+
+ pathfmt = cls.job.pathfmt
pathfmt.set_directory(kwdict)
pathfmt.set_filename(kwdict)
@@ -159,7 +168,7 @@ class TestHTTPDownloader(TestDownloaderBase):
@classmethod
def setUpClass(cls):
TestDownloaderBase.setUpClass()
- cls.downloader = downloader.find("http")(cls.extractor, NullOutput())
+ cls.downloader = downloader.find("http")(cls.job)
port = 8088
cls.address = "http://127.0.0.1:{}".format(port)
@@ -196,7 +205,7 @@ class TestTextDownloader(TestDownloaderBase):
@classmethod
def setUpClass(cls):
TestDownloaderBase.setUpClass()
- cls.downloader = downloader.find("text")(cls.extractor, NullOutput())
+ cls.downloader = downloader.find("text")(cls.job)
def test_text_download(self):
self._run_test("text:foobar", None, "foobar", "txt", "txt")
@@ -208,29 +217,6 @@ class TestTextDownloader(TestDownloaderBase):
self._run_test("text:", None, "", "txt", "txt")
-class FakeDownloader(DownloaderBase):
- scheme = "fake"
-
- def __init__(self, extractor, output):
- DownloaderBase.__init__(self, extractor, output)
-
- def connect(self, url, offset):
- pass
-
- def receive(self, file):
- pass
-
- def reset(self):
- pass
-
- def get_extension(self):
- pass
-
- @staticmethod
- def _check_extension(file, pathfmt):
- pass
-
-
class HttpRequestHandler(http.server.BaseHTTPRequestHandler):
def do_GET(self):