aboutsummaryrefslogtreecommitdiffstats
path: root/test
diff options
context:
space:
mode:
authorLibravatarUnit 193 <unit193@ubuntu.com>2019-07-02 04:33:45 -0400
committerLibravatarUnit 193 <unit193@ubuntu.com>2019-07-02 04:33:45 -0400
commit195c45911e79c33cf0bb986721365fb06df5a153 (patch)
treeac0c9b6ef40bea7aa7ab0c5c3cb500eb510668fa /test
Import Upstream version 1.8.7upstream/1.8.7
Diffstat (limited to 'test')
-rw-r--r--test/__init__.py0
-rw-r--r--test/test_config.py81
-rw-r--r--test/test_cookies.py130
-rw-r--r--test/test_downloader.py235
-rw-r--r--test/test_extractor.py186
-rw-r--r--test/test_oauth.py104
-rw-r--r--test/test_results.py344
-rw-r--r--test/test_text.py409
-rw-r--r--test/test_util.py395
9 files changed, 1884 insertions, 0 deletions
diff --git a/test/__init__.py b/test/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/test/__init__.py
diff --git a/test/test_config.py b/test/test_config.py
new file mode 100644
index 0000000..8cdb3da
--- /dev/null
+++ b/test/test_config.py
@@ -0,0 +1,81 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+# Copyright 2015-2017 Mike Fährmann
+#
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License version 2 as
+# published by the Free Software Foundation.
+
+import unittest
+import gallery_dl.config as config
+import os
+import tempfile
+
+
+class TestConfig(unittest.TestCase):
+
+ def setUp(self):
+ fd, self._configfile = tempfile.mkstemp()
+ with os.fdopen(fd, "w") as file:
+ file.write('{"a": "1", "b": {"a": 2, "c": "text"}}')
+ config.load((self._configfile,))
+
+ def tearDown(self):
+ config.clear()
+ os.remove(self._configfile)
+
+ def test_get(self):
+ self.assertEqual(config.get(["a"]), "1")
+ self.assertEqual(config.get(["b", "c"]), "text")
+ self.assertEqual(config.get(["d"]), None)
+ self.assertEqual(config.get(["e", "f", "g"], 123), 123)
+
+ def test_interpolate(self):
+ self.assertEqual(config.interpolate(["a"]), "1")
+ self.assertEqual(config.interpolate(["b", "a"]), "1")
+ self.assertEqual(config.interpolate(["b", "c"], "2"), "text")
+ self.assertEqual(config.interpolate(["b", "d"], "2"), "2")
+ config.set(["d"], 123)
+ self.assertEqual(config.interpolate(["b", "d"], "2"), 123)
+ self.assertEqual(config.interpolate(["d", "d"], "2"), 123)
+
+ def test_set(self):
+ config.set(["b", "c"], [1, 2, 3])
+ config.set(["e", "f", "g"], value=234)
+ self.assertEqual(config.get(["b", "c"]), [1, 2, 3])
+ self.assertEqual(config.get(["e", "f", "g"]), 234)
+
+ def test_setdefault(self):
+ config.setdefault(["b", "c"], [1, 2, 3])
+ config.setdefault(["e", "f", "g"], value=234)
+ self.assertEqual(config.get(["b", "c"]), "text")
+ self.assertEqual(config.get(["e", "f", "g"]), 234)
+
+ def test_unset(self):
+ config.unset(["a"])
+ config.unset(["b", "c"])
+ config.unset(["c", "d"])
+ self.assertEqual(config.get(["a"]), None)
+ self.assertEqual(config.get(["b", "a"]), 2)
+ self.assertEqual(config.get(["b", "c"]), None)
+
+ def test_apply(self):
+ options = (
+ (["b", "c"], [1, 2, 3]),
+ (["e", "f", "g"], 234),
+ )
+
+ self.assertEqual(config.get(["b", "c"]), "text")
+ self.assertEqual(config.get(["e", "f", "g"]), None)
+
+ with config.apply(options):
+ self.assertEqual(config.get(["b", "c"]), [1, 2, 3])
+ self.assertEqual(config.get(["e", "f", "g"]), 234)
+
+ self.assertEqual(config.get(["b", "c"]), "text")
+ self.assertEqual(config.get(["e", "f", "g"]), None)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/test/test_cookies.py b/test/test_cookies.py
new file mode 100644
index 0000000..a786df6
--- /dev/null
+++ b/test/test_cookies.py
@@ -0,0 +1,130 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+# Copyright 2017 Mike Fährmann
+#
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License version 2 as
+# published by the Free Software Foundation.
+
+import unittest
+from unittest import mock
+
+import logging
+import tempfile
+import http.cookiejar
+from os.path import join
+
+import gallery_dl.config as config
+import gallery_dl.extractor as extractor
+
+CKEY = ("cookies",)
+
+
+class TestCookiejar(unittest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ cls.path = tempfile.TemporaryDirectory()
+
+ cls.cookiefile = join(cls.path.name, "cookies.txt")
+ with open(cls.cookiefile, "w") as file:
+ file.write("""# HTTP Cookie File
+.example.org\tTRUE\t/\tFALSE\t253402210800\tNAME\tVALUE
+""")
+
+ cls.invalid_cookiefile = join(cls.path.name, "invalid.txt")
+ with open(cls.invalid_cookiefile, "w") as file:
+ file.write("""# asd
+.example.org\tTRUE\t/\tFALSE\t253402210800\tNAME\tVALUE
+""")
+
+ @classmethod
+ def tearDownClass(cls):
+ cls.path.cleanup()
+ config.clear()
+
+ def test_cookiefile(self):
+ config.set(CKEY, self.cookiefile)
+
+ cookies = extractor.find("test:").session.cookies
+ self.assertEqual(len(cookies), 1)
+
+ cookie = next(iter(cookies))
+ self.assertEqual(cookie.domain, ".example.org")
+ self.assertEqual(cookie.path , "/")
+ self.assertEqual(cookie.name , "NAME")
+ self.assertEqual(cookie.value , "VALUE")
+
+ def test_invalid_cookiefile(self):
+ self._test_warning(self.invalid_cookiefile, http.cookiejar.LoadError)
+
+ def test_invalid_filename(self):
+ self._test_warning(join(self.path.name, "nothing"), FileNotFoundError)
+
+ def _test_warning(self, filename, exc):
+ config.set(CKEY, filename)
+ log = logging.getLogger("test")
+ with mock.patch.object(log, "warning") as mock_warning:
+ cookies = extractor.find("test:").session.cookies
+ self.assertEqual(len(cookies), 0)
+ self.assertEqual(mock_warning.call_count, 1)
+ self.assertEqual(mock_warning.call_args[0][0], "cookies: %s")
+ self.assertIsInstance(mock_warning.call_args[0][1], exc)
+
+
+class TestCookiedict(unittest.TestCase):
+
+ def setUp(self):
+ self.cdict = {"NAME1": "VALUE1", "NAME2": "VALUE2"}
+ config.set(CKEY, self.cdict)
+
+ def tearDown(self):
+ config.clear()
+
+ def test_dict(self):
+ cookies = extractor.find("test:").session.cookies
+ self.assertEqual(len(cookies), len(self.cdict))
+ self.assertEqual(sorted(cookies.keys()), sorted(self.cdict.keys()))
+ self.assertEqual(sorted(cookies.values()), sorted(self.cdict.values()))
+
+ def test_domain(self):
+ for category in ["exhentai", "nijie", "sankaku", "seiga"]:
+ extr = _get_extractor(category)
+ cookies = extr.session.cookies
+ for key in self.cdict:
+ self.assertTrue(key in cookies)
+ for c in cookies:
+ self.assertEqual(c.domain, extr.cookiedomain)
+
+
+class TestCookieLogin(unittest.TestCase):
+
+ def tearDown(self):
+ config.clear()
+
+ def test_cookie_login(self):
+ extr_cookies = {
+ "exhentai": ("ipb_member_id", "ipb_pass_hash"),
+ "nijie" : ("nemail", "nlogin"),
+ "sankaku" : ("login", "pass_hash"),
+ "seiga" : ("user_session",),
+ }
+ for category, cookienames in extr_cookies.items():
+ cookies = {name: "value" for name in cookienames}
+ config.set(CKEY, cookies)
+ extr = _get_extractor(category)
+ with mock.patch.object(extr, "_login_impl") as mock_login:
+ extr.login()
+ mock_login.assert_not_called()
+
+
+def _get_extractor(category):
+ for extr in extractor.extractors():
+ if extr.category == category and hasattr(extr, "_login_impl"):
+ url = next(extr._get_tests())[0]
+ return extr.from_url(url)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/test/test_downloader.py b/test/test_downloader.py
new file mode 100644
index 0000000..3f301b0
--- /dev/null
+++ b/test/test_downloader.py
@@ -0,0 +1,235 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+# Copyright 2018 Mike Fährmann
+#
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License version 2 as
+# published by the Free Software Foundation.
+
+import re
+import base64
+import os.path
+import tempfile
+import unittest
+import threading
+import http.server
+
+import gallery_dl.downloader as downloader
+import gallery_dl.extractor as extractor
+import gallery_dl.config as config
+from gallery_dl.downloader.common import DownloaderBase
+from gallery_dl.output import NullOutput
+from gallery_dl.util import PathFormat
+
+
+class TestDownloaderBase(unittest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ cls.extractor = extractor.find("test:")
+ cls.dir = tempfile.TemporaryDirectory()
+ cls.fnum = 0
+ config.set(("base-directory",), cls.dir.name)
+
+ @classmethod
+ def tearDownClass(cls):
+ cls.dir.cleanup()
+ config.clear()
+
+ @classmethod
+ def _prepare_destination(cls, content=None, part=True, extension=None):
+ name = "file-{}".format(cls.fnum)
+ cls.fnum += 1
+
+ kwdict = {
+ "category": "test",
+ "subcategory": "test",
+ "filename": name,
+ "extension": extension,
+ }
+ pathfmt = PathFormat(cls.extractor)
+ pathfmt.set_directory(kwdict)
+ pathfmt.set_keywords(kwdict)
+
+ if content:
+ mode = "w" + ("b" if isinstance(content, bytes) else "")
+ with pathfmt.open(mode) as file:
+ file.write(content)
+
+ return pathfmt
+
+ def _run_test(self, url, input, output,
+ extension, expected_extension=None):
+ pathfmt = self._prepare_destination(input, extension=extension)
+ success = self.downloader.download(url, pathfmt)
+
+ # test successful download
+ self.assertTrue(success, "downloading '{}' failed".format(url))
+
+ # test content
+ mode = "r" + ("b" if isinstance(output, bytes) else "")
+ with pathfmt.open(mode) as file:
+ content = file.read()
+ self.assertEqual(content, output)
+
+ # test filename extension
+ self.assertEqual(
+ pathfmt.keywords["extension"],
+ expected_extension,
+ )
+ self.assertEqual(
+ os.path.splitext(pathfmt.realpath)[1][1:],
+ expected_extension,
+ )
+
+
+class TestHTTPDownloader(TestDownloaderBase):
+
+ @classmethod
+ def setUpClass(cls):
+ TestDownloaderBase.setUpClass()
+ cls.downloader = downloader.find("http")(cls.extractor, NullOutput())
+
+ port = 8088
+ cls.address = "http://127.0.0.1:{}".format(port)
+ cls._jpg = cls.address + "/image.jpg"
+ cls._png = cls.address + "/image.png"
+ cls._gif = cls.address + "/image.gif"
+
+ server = http.server.HTTPServer(("", port), HttpRequestHandler)
+ threading.Thread(target=server.serve_forever, daemon=True).start()
+
+ def test_http_download(self):
+ self._run_test(self._jpg, None, DATA_JPG, "jpg", "jpg")
+ self._run_test(self._png, None, DATA_PNG, "png", "png")
+ self._run_test(self._gif, None, DATA_GIF, "gif", "gif")
+
+ def test_http_offset(self):
+ self._run_test(self._jpg, DATA_JPG[:123], DATA_JPG, "jpg", "jpg")
+ self._run_test(self._png, DATA_PNG[:12] , DATA_PNG, "png", "png")
+ self._run_test(self._gif, DATA_GIF[:1] , DATA_GIF, "gif", "gif")
+
+ def test_http_extension(self):
+ self._run_test(self._jpg, None, DATA_JPG, None, "jpg")
+ self._run_test(self._png, None, DATA_PNG, None, "png")
+ self._run_test(self._gif, None, DATA_GIF, None, "gif")
+
+ def test_http_adjust_extension(self):
+ self._run_test(self._jpg, None, DATA_JPG, "png", "jpg")
+ self._run_test(self._png, None, DATA_PNG, "gif", "png")
+ self._run_test(self._gif, None, DATA_GIF, "jpg", "gif")
+
+
+class TestTextDownloader(TestDownloaderBase):
+
+ @classmethod
+ def setUpClass(cls):
+ TestDownloaderBase.setUpClass()
+ cls.downloader = downloader.find("text")(cls.extractor, NullOutput())
+
+ def test_text_download(self):
+ self._run_test("text:foobar", None, "foobar", "txt", "txt")
+
+ def test_text_offset(self):
+ self._run_test("text:foobar", "foo", "foobar", "txt", "txt")
+
+ def test_text_extension(self):
+ self._run_test("text:foobar", None, "foobar", None, "txt")
+
+ def test_text_empty(self):
+ self._run_test("text:", None, "", "txt", "txt")
+
+
+class FakeDownloader(DownloaderBase):
+ scheme = "fake"
+
+ def __init__(self, extractor, output):
+ DownloaderBase.__init__(self, extractor, output)
+
+ def connect(self, url, offset):
+ pass
+
+ def receive(self, file):
+ pass
+
+ def reset(self):
+ pass
+
+ def get_extension(self):
+ pass
+
+ @staticmethod
+ def _check_extension(file, pathfmt):
+ pass
+
+
+class HttpRequestHandler(http.server.BaseHTTPRequestHandler):
+
+ def do_GET(self):
+ if self.path == "/image.jpg":
+ content_type = "image/jpeg"
+ output = DATA_JPG
+ elif self.path == "/image.png":
+ content_type = "image/png"
+ output = DATA_PNG
+ elif self.path == "/image.gif":
+ content_type = "image/gif"
+ output = DATA_GIF
+ else:
+ self.send_response(404)
+ self.wfile.write(self.path.encode())
+ return
+
+ headers = {
+ "Content-Type": content_type,
+ "Content-Length": len(output),
+ }
+
+ if "Range" in self.headers:
+ status = 206
+
+ match = re.match(r"bytes=(\d+)-", self.headers["Range"])
+ start = int(match.group(1))
+
+ headers["Content-Range"] = "bytes {}-{}/{}".format(
+ start, len(output)-1, len(output))
+ output = output[start:]
+ else:
+ status = 200
+
+ self.send_response(status)
+ for key, value in headers.items():
+ self.send_header(key, value)
+ self.end_headers()
+ self.wfile.write(output)
+
+
+DATA_JPG = base64.standard_b64decode("""
+/9j/4AAQSkZJRgABAQEASABIAAD/2wBD
+AAEBAQEBAQEBAQEBAQEBAQEBAQEBAQEB
+AQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEB
+AQEBAQEBAQEBAQEBAQEBAQH/2wBDAQEB
+AQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEB
+AQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEB
+AQEBAQEBAQEBAQEBAQH/wAARCAABAAED
+AREAAhEBAxEB/8QAFAABAAAAAAAAAAAA
+AAAAAAAACv/EABQQAQAAAAAAAAAAAAAA
+AAAAAAD/xAAUAQEAAAAAAAAAAAAAAAAA
+AAAA/8QAFBEBAAAAAAAAAAAAAAAAAAAA
+AP/aAAwDAQACEQMRAD8AfwD/2Q==""")
+
+
+DATA_PNG = base64.standard_b64decode("""
+iVBORw0KGgoAAAANSUhEUgAAAAEAAAAB
+CAAAAAA6fptVAAAACklEQVQIHWP4DwAB
+AQEANl9ngAAAAABJRU5ErkJggg==""")
+
+
+DATA_GIF = base64.standard_b64decode("""
+R0lGODdhAQABAIAAAP///////ywAAAAA
+AQABAAACAkQBADs=""")
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/test/test_extractor.py b/test/test_extractor.py
new file mode 100644
index 0000000..fa0709b
--- /dev/null
+++ b/test/test_extractor.py
@@ -0,0 +1,186 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+# Copyright 2018-2019 Mike Fährmann
+#
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License version 2 as
+# published by the Free Software Foundation.
+
+import sys
+import unittest
+import string
+
+from gallery_dl import extractor
+from gallery_dl.extractor.common import Extractor, Message
+from gallery_dl.extractor.directlink import DirectlinkExtractor as DLExtractor
+
+
+class FakeExtractor(Extractor):
+ category = "fake"
+ subcategory = "test"
+ pattern = "fake:"
+
+ def items(self):
+ yield Message.Version, 1
+ yield Message.Url, "text:foobar", {}
+
+
+class TestExtractor(unittest.TestCase):
+ VALID_URIS = (
+ "https://example.org/file.jpg",
+ "tumblr:foobar",
+ "oauth:flickr",
+ "test:pixiv:",
+ "recursive:https://example.org/document.html",
+ )
+
+ def setUp(self):
+ extractor._cache.clear()
+ extractor._module_iter = iter(extractor.modules)
+
+ def test_find(self):
+ for uri in self.VALID_URIS:
+ result = extractor.find(uri)
+ self.assertIsInstance(result, Extractor, uri)
+
+ for not_found in ("", "/tmp/file.ext"):
+ self.assertIsNone(extractor.find(not_found))
+
+ for invalid in (None, [], {}, 123, b"test:"):
+ with self.assertRaises(TypeError):
+ extractor.find(invalid)
+
+ def test_add(self):
+ uri = "fake:foobar"
+ self.assertIsNone(extractor.find(uri))
+
+ extractor.add(FakeExtractor)
+ self.assertIsInstance(extractor.find(uri), FakeExtractor)
+
+ def test_add_module(self):
+ uri = "fake:foobar"
+ self.assertIsNone(extractor.find(uri))
+
+ classes = extractor.add_module(sys.modules[__name__])
+ self.assertEqual(len(classes), 1)
+ self.assertEqual(classes[0].pattern, FakeExtractor.pattern)
+ self.assertEqual(classes[0], FakeExtractor)
+ self.assertIsInstance(extractor.find(uri), FakeExtractor)
+
+ def test_blacklist(self):
+ link_uri = "https://example.org/file.jpg"
+ test_uri = "test:"
+ fake_uri = "fake:"
+
+ self.assertIsInstance(extractor.find(link_uri), DLExtractor)
+ self.assertIsInstance(extractor.find(test_uri), Extractor)
+ self.assertIsNone(extractor.find(fake_uri))
+
+ with extractor.blacklist(["directlink"]):
+ self.assertIsNone(extractor.find(link_uri))
+ self.assertIsInstance(extractor.find(test_uri), Extractor)
+ self.assertIsNone(extractor.find(fake_uri))
+
+ with extractor.blacklist([], [DLExtractor, FakeExtractor]):
+ self.assertIsNone(extractor.find(link_uri))
+ self.assertIsInstance(extractor.find(test_uri), Extractor)
+ self.assertIsNone(extractor.find(fake_uri))
+
+ with extractor.blacklist(["test"], [DLExtractor]):
+ self.assertIsNone(extractor.find(link_uri))
+ self.assertIsNone(extractor.find(test_uri))
+ self.assertIsNone(extractor.find(fake_uri))
+
+ def test_from_url(self):
+ for uri in self.VALID_URIS:
+ cls = extractor.find(uri).__class__
+ extr = cls.from_url(uri)
+ self.assertIs(type(extr), cls)
+ self.assertIsInstance(extr, Extractor)
+
+ for not_found in ("", "/tmp/file.ext"):
+ self.assertIsNone(FakeExtractor.from_url(not_found))
+
+ for invalid in (None, [], {}, 123, b"test:"):
+ with self.assertRaises(TypeError):
+ FakeExtractor.from_url(invalid)
+
+ def test_unique_pattern_matches(self):
+ test_urls = []
+
+ # collect testcase URLs
+ for extr in extractor.extractors():
+ for testcase in extr._get_tests():
+ test_urls.append((testcase[0], extr))
+
+ # iterate over all testcase URLs
+ for url, extr1 in test_urls:
+ matches = []
+
+ # ... and apply all regex patterns to each one
+ for extr2 in extractor._cache:
+
+ # skip DirectlinkExtractor pattern if it isn't tested
+ if extr1 != DLExtractor and extr2 == DLExtractor:
+ continue
+
+ match = extr2.pattern.match(url)
+ if match:
+ matches.append(match)
+
+ # fail if more or less than 1 match happened
+ if len(matches) > 1:
+ msg = "'{}' gets matched by more than one pattern:".format(url)
+ for match in matches:
+ msg += "\n- "
+ msg += match.re.pattern
+ self.fail(msg)
+
+ if len(matches) < 1:
+ msg = "'{}' isn't matched by any pattern".format(url)
+ self.fail(msg)
+
+ def test_docstrings(self):
+ """ensure docstring uniqueness"""
+ for extr1 in extractor.extractors():
+ for extr2 in extractor.extractors():
+ if extr1 != extr2 and extr1.__doc__ and extr2.__doc__:
+ self.assertNotEqual(
+ extr1.__doc__,
+ extr2.__doc__,
+ "{} <-> {}".format(extr1, extr2),
+ )
+
+ def test_names(self):
+ """Ensure extractor classes are named CategorySubcategoryExtractor"""
+ def capitalize(c):
+ if "-" in c:
+ return string.capwords(c.replace("-", " ")).replace(" ", "")
+ if "." in c:
+ c = c.replace(".", "")
+ return c.capitalize()
+
+ mapping = {
+ "2chan" : "futaba",
+ "3dbooru": "threedeebooru",
+ "4chan" : "fourchan",
+ "4plebs" : "fourplebs",
+ "8chan" : "infinitychan",
+ "oauth" : None,
+ }
+
+ for extr in extractor.extractors():
+ category = mapping.get(extr.category, extr.category)
+ if category:
+ expected = "{}{}Extractor".format(
+ capitalize(category),
+ capitalize(extr.subcategory),
+ )
+ if expected[0].isdigit():
+ expected = "_" + expected
+ self.assertEqual(expected, extr.__name__)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/test/test_oauth.py b/test/test_oauth.py
new file mode 100644
index 0000000..2ce5b43
--- /dev/null
+++ b/test/test_oauth.py
@@ -0,0 +1,104 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+# Copyright 2018 Mike Fährmann
+#
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License version 2 as
+# published by the Free Software Foundation.
+
+import unittest
+
+from gallery_dl import oauth, text
+
+TESTSERVER = "http://term.ie/oauth/example"
+CONSUMER_KEY = "key"
+CONSUMER_SECRET = "secret"
+REQUEST_TOKEN = "requestkey"
+REQUEST_TOKEN_SECRET = "requestsecret"
+ACCESS_TOKEN = "accesskey"
+ACCESS_TOKEN_SECRET = "accesssecret"
+
+
+class TestOAuthSession(unittest.TestCase):
+
+ def test_concat(self):
+ concat = oauth.concat
+
+ self.assertEqual(concat(), "")
+ self.assertEqual(concat("str"), "str")
+ self.assertEqual(concat("str1", "str2"), "str1&str2")
+
+ self.assertEqual(concat("&", "?/"), "%26&%3F%2F")
+ self.assertEqual(
+ concat("GET", "http://example.org/", "foo=bar&baz=a"),
+ "GET&http%3A%2F%2Fexample.org%2F&foo%3Dbar%26baz%3Da"
+ )
+
+ def test_nonce(self, size=16):
+ nonce_values = set(oauth.nonce(size) for _ in range(size))
+
+ # uniqueness
+ self.assertEqual(len(nonce_values), size)
+
+ # length
+ for nonce in nonce_values:
+ self.assertEqual(len(nonce), size)
+
+ def test_quote(self):
+ quote = oauth.quote
+
+ reserved = ",;:!\"§$%&/(){}[]=?`´+*'äöü"
+ unreserved = ("ABCDEFGHIJKLMNOPQRSTUVWXYZ"
+ "abcdefghijklmnopqrstuvwxyz"
+ "0123456789-._~")
+
+ for char in unreserved:
+ self.assertEqual(quote(char), char)
+
+ for char in reserved:
+ quoted = quote(char)
+ quoted_hex = quoted.replace("%", "")
+ self.assertTrue(quoted.startswith("%"))
+ self.assertTrue(len(quoted) >= 3)
+ self.assertEqual(quoted_hex.upper(), quoted_hex)
+
+ def test_request_token(self):
+ response = self._oauth_request(
+ "/request_token.php", {})
+ expected = "oauth_token=requestkey&oauth_token_secret=requestsecret"
+ self.assertEqual(response, expected, msg=response)
+
+ data = text.parse_query(response)
+ self.assertTrue(data["oauth_token"], REQUEST_TOKEN)
+ self.assertTrue(data["oauth_token_secret"], REQUEST_TOKEN_SECRET)
+
+ def test_access_token(self):
+ response = self._oauth_request(
+ "/access_token.php", {}, REQUEST_TOKEN, REQUEST_TOKEN_SECRET)
+ expected = "oauth_token=accesskey&oauth_token_secret=accesssecret"
+ self.assertEqual(response, expected, msg=response)
+
+ data = text.parse_query(response)
+ self.assertTrue(data["oauth_token"], ACCESS_TOKEN)
+ self.assertTrue(data["oauth_token_secret"], ACCESS_TOKEN_SECRET)
+
+ def test_authenticated_call(self):
+ params = {"method": "foo", "a": "äöüß/?&#", "äöüß/?&#": "a"}
+ response = self._oauth_request(
+ "/echo_api.php", params, ACCESS_TOKEN, ACCESS_TOKEN_SECRET)
+
+ self.assertEqual(text.parse_query(response), params)
+
+ def _oauth_request(self, endpoint, params=None,
+ oauth_token=None, oauth_token_secret=None):
+ session = oauth.OAuth1Session(
+ CONSUMER_KEY, CONSUMER_SECRET,
+ oauth_token, oauth_token_secret,
+ )
+ url = TESTSERVER + endpoint
+ return session.get(url, params=params).text
+
+
+if __name__ == "__main__":
+ unittest.main(warnings="ignore")
diff --git a/test/test_results.py b/test/test_results.py
new file mode 100644
index 0000000..8f03f03
--- /dev/null
+++ b/test/test_results.py
@@ -0,0 +1,344 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+# Copyright 2015-2019 Mike Fährmann
+#
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License version 2 as
+# published by the Free Software Foundation.
+
+import os
+import sys
+import re
+import json
+import hashlib
+import unittest
+from gallery_dl import extractor, util, job, config, exception
+
+
+# these don't work on Travis CI
+TRAVIS_SKIP = {
+ "exhentai", "kissmanga", "mangafox", "dynastyscans", "nijie", "bobx",
+ "archivedmoe", "archiveofsins", "thebarchive", "fireden", "4plebs",
+ "sankaku", "idolcomplex", "mangahere", "readcomiconline", "mangadex",
+ "sankakucomplex",
+}
+
+# temporary issues, etc.
+BROKEN = {
+ "komikcast",
+ "mangapark",
+}
+
+
+class TestExtractorResults(unittest.TestCase):
+
+ def setUp(self):
+ setup_test_config()
+
+ def tearDown(self):
+ config.clear()
+
+ @classmethod
+ def setUpClass(cls):
+ cls._skipped = []
+
+ @classmethod
+ def tearDownClass(cls):
+ if cls._skipped:
+ print("\n\nSkipped tests:")
+ for url, exc in cls._skipped:
+ print('- {} ("{}")'.format(url, exc))
+
+ def _run_test(self, extr, url, result):
+ if result:
+ if "options" in result:
+ for key, value in result["options"]:
+ config.set(key.split("."), value)
+ if "range" in result:
+ config.set(("image-range",), result["range"])
+ config.set(("chapter-range",), result["range"])
+ content = "content" in result
+ else:
+ content = False
+
+ tjob = ResultJob(url, content=content)
+ self.assertEqual(extr, tjob.extractor.__class__)
+
+ if not result:
+ return
+ if "exception" in result:
+ with self.assertRaises(result["exception"]):
+ tjob.run()
+ return
+ try:
+ tjob.run()
+ except exception.StopExtraction:
+ pass
+ except exception.HttpError as exc:
+ exc = str(exc)
+ if re.match(r"5\d\d: ", exc) or \
+ re.search(r"\bRead timed out\b", exc):
+ self._skipped.append((url, exc))
+ self.skipTest(exc)
+ raise
+
+ # test archive-id uniqueness
+ self.assertEqual(len(set(tjob.list_archive)), len(tjob.list_archive))
+
+ # test '_extractor' entries
+ if tjob.queue:
+ for url, kwdict in zip(tjob.list_url, tjob.list_keyword):
+ if "_extractor" in kwdict:
+ extr = kwdict["_extractor"].from_url(url)
+ self.assertIsInstance(extr, kwdict["_extractor"])
+ self.assertEqual(extr.url, url)
+
+ # test extraction results
+ if "url" in result:
+ self.assertEqual(result["url"], tjob.hash_url.hexdigest())
+
+ if "content" in result:
+ self.assertEqual(result["content"], tjob.hash_content.hexdigest())
+
+ if "keyword" in result:
+ keyword = result["keyword"]
+ if isinstance(keyword, dict):
+ for kwdict in tjob.list_keyword:
+ self._test_kwdict(kwdict, keyword)
+ else: # assume SHA1 hash
+ self.assertEqual(keyword, tjob.hash_keyword.hexdigest())
+
+ if "count" in result:
+ count = result["count"]
+ if isinstance(count, str):
+ self.assertRegex(count, r"^ *(==|!=|<|<=|>|>=) *\d+ *$")
+ expr = "{} {}".format(len(tjob.list_url), count)
+ self.assertTrue(eval(expr), msg=expr)
+ else: # assume integer
+ self.assertEqual(len(tjob.list_url), count)
+
+ if "pattern" in result:
+ self.assertGreater(len(tjob.list_url), 0)
+ for url in tjob.list_url:
+ self.assertRegex(url, result["pattern"])
+
+ def _test_kwdict(self, kwdict, tests):
+ for key, test in tests.items():
+ if key.startswith("?"):
+ key = key[1:]
+ if key not in kwdict:
+ continue
+ self.assertIn(key, kwdict)
+ value = kwdict[key]
+
+ if isinstance(test, dict):
+ self._test_kwdict(value, test)
+ elif isinstance(test, type):
+ self.assertIsInstance(value, test, msg=key)
+ elif isinstance(test, str):
+ if test.startswith("re:"):
+ self.assertRegex(value, test[3:], msg=key)
+ elif test.startswith("type:"):
+ self.assertEqual(type(value).__name__, test[5:], msg=key)
+ else:
+ self.assertEqual(value, test, msg=key)
+ else:
+ self.assertEqual(value, test, msg=key)
+
+
+class ResultJob(job.DownloadJob):
+ """Generate test-results for extractor runs"""
+
+ def __init__(self, url, parent=None, content=False):
+ job.DownloadJob.__init__(self, url, parent)
+ self.queue = False
+ self.content = content
+ self.list_url = []
+ self.list_keyword = []
+ self.list_archive = []
+ self.hash_url = hashlib.sha1()
+ self.hash_keyword = hashlib.sha1()
+ self.hash_archive = hashlib.sha1()
+ self.hash_content = hashlib.sha1()
+ if content:
+ self.fileobj = TestPathfmt(self.hash_content)
+ self.get_downloader("http")._check_extension = lambda a, b: None
+
+ self.format_directory = TestFormatter(
+ "".join(self.extractor.directory_fmt))
+ self.format_filename = TestFormatter(self.extractor.filename_fmt)
+
+ def run(self):
+ for msg in self.extractor:
+ self.dispatch(msg)
+
+ def handle_url(self, url, keywords, fallback=None):
+ self.update_url(url)
+ self.update_keyword(keywords)
+ self.update_archive(keywords)
+ self.update_content(url)
+ self.format_filename.format_map(keywords)
+
+ def handle_directory(self, keywords):
+ self.update_keyword(keywords, False)
+ self.format_directory.format_map(keywords)
+
+ def handle_queue(self, url, keywords):
+ self.queue = True
+ self.update_url(url)
+ self.update_keyword(keywords)
+
+ def update_url(self, url):
+ self.list_url.append(url)
+ self.hash_url.update(url.encode())
+
+ def update_keyword(self, kwdict, to_list=True):
+ if to_list:
+ self.list_keyword.append(kwdict)
+ kwdict = self._filter(kwdict)
+ self.hash_keyword.update(
+ json.dumps(kwdict, sort_keys=True, default=str).encode())
+
+ def update_archive(self, kwdict):
+ archive_id = self.extractor.archive_fmt.format_map(kwdict)
+ self.list_archive.append(archive_id)
+ self.hash_archive.update(archive_id.encode())
+
+ def update_content(self, url):
+ if self.content:
+ scheme = url.partition(":")[0]
+ self.get_downloader(scheme).download(url, self.fileobj)
+
+
+class TestPathfmt():
+
+ def __init__(self, hashobj):
+ self.hashobj = hashobj
+ self.path = ""
+ self.size = 0
+ self.has_extension = True
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, *args):
+ pass
+
+ def open(self, mode):
+ self.size = 0
+ return self
+
+ def write(self, content):
+ """Update SHA1 hash"""
+ self.size += len(content)
+ self.hashobj.update(content)
+
+ def tell(self):
+ return self.size
+
+ def part_size(self):
+ return 0
+
+
+class TestFormatter(util.Formatter):
+
+ @staticmethod
+ def _noop(_):
+ return ""
+
+ def _apply_simple(self, key, fmt):
+ if key == "extension" or "._format_optional." in repr(fmt):
+ return self._noop
+
+ def wrap(obj):
+ return fmt(obj[key])
+ return wrap
+
+ def _apply(self, key, funcs, fmt):
+ if key == "extension" or "._format_optional." in repr(fmt):
+ return self._noop
+
+ def wrap(obj):
+ obj = obj[key]
+ for func in funcs:
+ obj = func(obj)
+ return fmt(obj)
+ return wrap
+
+
+def setup_test_config():
+ name = "gallerydl"
+ email = "gallerydl@openaliasbox.org"
+
+ config.clear()
+ config.set(("cache", "file"), ":memory:")
+ config.set(("downloader", "part"), False)
+ config.set(("extractor", "timeout"), 60)
+ config.set(("extractor", "username"), name)
+ config.set(("extractor", "password"), name)
+ config.set(("extractor", "nijie", "username"), email)
+ config.set(("extractor", "seiga", "username"), email)
+ config.set(("extractor", "danbooru", "username"), None)
+ config.set(("extractor", "twitter" , "username"), None)
+ config.set(("extractor", "mangoxo" , "password"), "VZ8DL3983u")
+
+ config.set(("extractor", "deviantart", "client-id"), "7777")
+ config.set(("extractor", "deviantart", "client-secret"),
+ "ff14994c744d9208e5caeec7aab4a026")
+
+ config.set(("extractor", "tumblr", "api-key"),
+ "0cXoHfIqVzMQcc3HESZSNsVlulGxEXGDTTZCDrRrjaa0jmuTc6")
+ config.set(("extractor", "tumblr", "api-secret"),
+ "6wxAK2HwrXdedn7VIoZWxGqVhZ8JdYKDLjiQjL46MLqGuEtyVj")
+ config.set(("extractor", "tumblr", "access-token"),
+ "N613fPV6tOZQnyn0ERTuoEZn0mEqG8m2K8M3ClSJdEHZJuqFdG")
+ config.set(("extractor", "tumblr", "access-token-secret"),
+ "sgOA7ZTT4FBXdOGGVV331sSp0jHYp4yMDRslbhaQf7CaS71i4O")
+
+
+def generate_tests():
+ """Dynamically generate extractor unittests"""
+ def _generate_test(extr, tcase):
+ def test(self):
+ url, result = tcase
+ print("\n", url, sep="")
+ self._run_test(extr, url, result)
+ return test
+
+ # enable selective testing for direct calls
+ if __name__ == '__main__' and len(sys.argv) > 1:
+ if sys.argv[1].lower() == "all":
+ fltr = lambda c, bc: True # noqa: E731
+ elif sys.argv[1].lower() == "broken":
+ fltr = lambda c, bc: c in BROKEN # noqa: E731
+ else:
+ argv = sys.argv[1:]
+ fltr = lambda c, bc: c in argv or bc in argv # noqa: E731
+ del sys.argv[1:]
+ else:
+ skip = set(BROKEN)
+ if "CI" in os.environ and "TRAVIS" in os.environ:
+ skip |= set(TRAVIS_SKIP)
+ if skip:
+ print("skipping:", ", ".join(skip))
+ fltr = lambda c, bc: c not in skip # noqa: E731
+
+ # filter available extractor classes
+ extractors = [
+ extr for extr in extractor.extractors()
+ if fltr(extr.category, getattr(extr, "basecategory", None))
+ ]
+
+ # add 'test_...' methods
+ for extr in extractors:
+ name = "test_" + extr.__name__ + "_"
+ for num, tcase in enumerate(extr._get_tests(), 1):
+ test = _generate_test(extr, tcase)
+ test.__name__ = name + str(num)
+ setattr(TestExtractorResults, test.__name__, test)
+
+
+generate_tests()
+if __name__ == '__main__':
+ unittest.main(warnings='ignore')
diff --git a/test/test_text.py b/test/test_text.py
new file mode 100644
index 0000000..405acd3
--- /dev/null
+++ b/test/test_text.py
@@ -0,0 +1,409 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+# Copyright 2015-2018 Mike Fährmann
+#
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License version 2 as
+# published by the Free Software Foundation.
+
+import unittest
+import datetime
+
+from gallery_dl import text
+
+
+INVALID = ((), [], {}, None, 1, 2.3)
+INVALID_ALT = ((), [], {}, None, "")
+
+
+class TestText(unittest.TestCase):
+
+ def test_clean_xml(self, f=text.clean_xml):
+ # standard usage
+ self.assertEqual(f(""), "")
+ self.assertEqual(f("foo"), "foo")
+ self.assertEqual(f("\tfoo\nbar\r"), "\tfoo\nbar\r")
+ self.assertEqual(f("<foo>\ab\ba\fr\v</foo>"), "<foo>bar</foo>")
+
+ # 'repl' argument
+ repl = "#"
+ self.assertEqual(f("", repl), "")
+ self.assertEqual(f("foo", repl), "foo")
+ self.assertEqual(f("\tfoo\nbar\r", repl), "\tfoo\nbar\r")
+ self.assertEqual(
+ f("<foo>\ab\ba\fr\v</foo>", repl), "<foo>#b#a#r#</foo>")
+
+ # removal of all illegal control characters
+ value = "".join(chr(x) for x in range(32))
+ self.assertEqual(f(value), "\t\n\r")
+
+ # 'invalid' arguments
+ for value in INVALID:
+ self.assertEqual(f(value), "")
+
+ def test_remove_html(self, f=text.remove_html):
+ result = "Hello World."
+
+ # standard usage
+ self.assertEqual(f(""), "")
+ self.assertEqual(f("Hello World."), result)
+ self.assertEqual(f(" Hello World. "), result)
+ self.assertEqual(f("Hello<br/>World."), result)
+ self.assertEqual(
+ f("<div><b class='a'>Hello</b><i>World.</i></div>"), result)
+
+ # empty HTML
+ self.assertEqual(f("<div></div>"), "")
+ self.assertEqual(f(" <div> </div> "), "")
+
+ # malformed HTML
+ self.assertEqual(f("<div</div>"), "")
+ self.assertEqual(f("<div<Hello World.</div>"), "")
+
+ # invalid arguments
+ for value in INVALID:
+ self.assertEqual(f(value), "")
+
+ def test_split_html(self, f=text.split_html):
+ result = ["Hello", "World."]
+ empty = []
+
+ # standard usage
+ self.assertEqual(f(""), empty)
+ self.assertEqual(f("Hello World."), ["Hello World."])
+ self.assertEqual(f(" Hello World. "), ["Hello World."])
+ self.assertEqual(f("Hello<br/>World."), result)
+ self.assertEqual(f(" Hello <br/> World. "), result)
+ self.assertEqual(
+ f("<div><b class='a'>Hello</b><i>World.</i></div>"), result)
+
+ # empty HTML
+ self.assertEqual(f("<div></div>"), empty)
+ self.assertEqual(f(" <div> </div> "), empty)
+
+ # malformed HTML
+ self.assertEqual(f("<div</div>"), empty)
+ self.assertEqual(f("<div<Hello World.</div>"), empty)
+
+ # invalid arguments
+ for value in INVALID:
+ self.assertEqual(f(value), empty)
+
+ def test_filename_from_url(self, f=text.filename_from_url):
+ result = "filename.ext"
+
+ # standard usage
+ self.assertEqual(f(""), "")
+ self.assertEqual(f("filename.ext"), result)
+ self.assertEqual(f("/filename.ext"), result)
+ self.assertEqual(f("example.org/filename.ext"), result)
+ self.assertEqual(f("http://example.org/v2/filename.ext"), result)
+ self.assertEqual(
+ f("http://example.org/v2/filename.ext?param=value#frag"), result)
+
+ # invalid arguments
+ for value in INVALID:
+ self.assertEqual(f(value), "")
+
+ def test_ext_from_url(self, f=text.ext_from_url):
+ result = "ext"
+
+ # standard usage
+ self.assertEqual(f(""), "")
+ self.assertEqual(f("filename.ext"), result)
+ self.assertEqual(f("/filename.ext"), result)
+ self.assertEqual(f("example.org/filename.ext"), result)
+ self.assertEqual(f("http://example.org/v2/filename.ext"), result)
+ self.assertEqual(
+ f("http://example.org/v2/filename.ext?param=value#frag"), result)
+
+ # invalid arguments
+ for value in INVALID:
+ self.assertEqual(f(value), "")
+
+ def test_nameext_from_url(self, f=text.nameext_from_url):
+ empty = {"filename": "", "extension": ""}
+ result = {"filename": "filename", "extension": "ext"}
+
+ # standard usage
+ self.assertEqual(f(""), empty)
+ self.assertEqual(f("filename.ext"), result)
+ self.assertEqual(f("/filename.ext"), result)
+ self.assertEqual(f("example.org/filename.ext"), result)
+ self.assertEqual(f("http://example.org/v2/filename.ext"), result)
+ self.assertEqual(
+ f("http://example.org/v2/filename.ext?param=value#frag"), result)
+
+ # invalid arguments
+ for value in INVALID:
+ self.assertEqual(f(value), empty)
+
+ def test_clean_path_windows(self, f=text.clean_path_windows):
+ self.assertEqual(f(""), "")
+ self.assertEqual(f("foo"), "foo")
+ self.assertEqual(f("foo/bar"), "foo_bar")
+ self.assertEqual(f("foo<>:\"\\/|?*bar"), "foo_________bar")
+
+ # invalid arguments
+ for value in INVALID:
+ self.assertEqual(f(value), "")
+
+ def test_clean_path_posix(self, f=text.clean_path_posix):
+ self.assertEqual(f(""), "")
+ self.assertEqual(f("foo"), "foo")
+ self.assertEqual(f("foo/bar"), "foo_bar")
+ self.assertEqual(f("foo<>:\"\\/|?*bar"), "foo<>:\"\\_|?*bar")
+
+ # invalid arguments
+ for value in INVALID:
+ self.assertEqual(f(value), "")
+
+ def test_extract(self, f=text.extract):
+ txt = "<a><b>"
+ self.assertEqual(f(txt, "<", ">"), ("a" , 3))
+ self.assertEqual(f(txt, "X", ">"), (None, 0))
+ self.assertEqual(f(txt, "<", "X"), (None, 0))
+
+ # 'pos' argument
+ for i in range(1, 4):
+ self.assertEqual(f(txt, "<", ">", i), ("b", 6))
+ for i in range(4, 10):
+ self.assertEqual(f(txt, "<", ">", i), (None, i))
+
+ # invalid arguments
+ for value in INVALID:
+ self.assertEqual(f(value, "<" , ">") , (None, 0))
+ self.assertEqual(f(txt , value, ">") , (None, 0))
+ self.assertEqual(f(txt , "<" , value), (None, 0))
+
+ def test_rextract(self, f=text.rextract):
+ txt = "<a><b>"
+ self.assertEqual(f(txt, "<", ">"), ("b" , 3))
+ self.assertEqual(f(txt, "X", ">"), (None, -1))
+ self.assertEqual(f(txt, "<", "X"), (None, -1))
+
+ # 'pos' argument
+ for i in range(10, 3, -1):
+ self.assertEqual(f(txt, "<", ">", i), ("b", 3))
+ for i in range(3, 0, -1):
+ self.assertEqual(f(txt, "<", ">", i), ("a", 0))
+
+ # invalid arguments
+ for value in INVALID:
+ self.assertEqual(f(value, "<" , ">") , (None, -1))
+ self.assertEqual(f(txt , value, ">") , (None, -1))
+ self.assertEqual(f(txt , "<" , value), (None, -1))
+
+ def test_extract_all(self, f=text.extract_all):
+ txt = "[c][b][a]: xyz! [d][e"
+
+ self.assertEqual(
+ f(txt, ()), ({}, 0))
+ self.assertEqual(
+ f(txt, (("C", "[", "]"), ("B", "[", "]"), ("A", "[", "]"))),
+ ({"A": "a", "B": "b", "C": "c"}, 9),
+ )
+
+ # 'None' as field name
+ self.assertEqual(
+ f(txt, ((None, "[", "]"), (None, "[", "]"), ("A", "[", "]"))),
+ ({"A": "a"}, 9),
+ )
+ self.assertEqual(
+ f(txt, ((None, "[", "]"), (None, "[", "]"), (None, "[", "]"))),
+ ({}, 9),
+ )
+
+ # failed matches
+ self.assertEqual(
+ f(txt, (("C", "[", "]"), ("X", "X", "X"), ("B", "[", "]"))),
+ ({"B": "b", "C": "c", "X": None}, 6),
+ )
+
+ # 'pos' argument
+ self.assertEqual(
+ f(txt, (("B", "[", "]"), ("A", "[", "]")), pos=1),
+ ({"A": "a", "B": "b"}, 9),
+ )
+
+ # 'values' argument
+ self.assertEqual(
+ f(txt, (("C", "[", "]"),), values={"A": "a", "B": "b"}),
+ ({"A": "a", "B": "b", "C": "c"}, 3),
+ )
+
+ vdict = {}
+ rdict, pos = f(txt, (), values=vdict)
+ self.assertIs(vdict, rdict)
+
+ def test_extract_iter(self, f=text.extract_iter):
+ txt = "[c][b][a]: xyz! [d][e"
+
+ def g(*args):
+ return list(f(*args))
+
+ self.assertEqual(
+ g("", "[", "]"), [])
+ self.assertEqual(
+ g("[a]", "[", "]"), ["a"])
+ self.assertEqual(
+ g(txt, "[", "]"), ["c", "b", "a", "d"])
+ self.assertEqual(
+ g(txt, "X", "X"), [])
+ self.assertEqual(
+ g(txt, "[", "]", 6), ["a", "d"])
+
+ def test_extract_from(self, f=text.extract_from):
+ txt = "[c][b][a]: xyz! [d][e"
+
+ e = f(txt)
+ self.assertEqual(e("[", "]"), "c")
+ self.assertEqual(e("[", "]"), "b")
+ self.assertEqual(e("[", "]"), "a")
+ self.assertEqual(e("[", "]"), "d")
+ self.assertEqual(e("[", "]"), "")
+ self.assertEqual(e("[", "]"), "")
+
+ e = f(txt, pos=6, default="END")
+ self.assertEqual(e("[", "]"), "a")
+ self.assertEqual(e("[", "]"), "d")
+ self.assertEqual(e("[", "]"), "END")
+ self.assertEqual(e("[", "]"), "END")
+
+ def test_parse_unicode_escapes(self, f=text.parse_unicode_escapes):
+ self.assertEqual(f(""), "")
+ self.assertEqual(f("foobar"), "foobar")
+ self.assertEqual(f("foo’bar"), "foo’bar")
+ self.assertEqual(f("foo\\u2019bar"), "foo’bar")
+ self.assertEqual(f("foo\\u201bar"), "foo‛ar")
+ self.assertEqual(f("foo\\u201zar"), "foo\\u201zar")
+ self.assertEqual(
+ f("\\u2018foo\\u2019\\u2020bar\\u00ff"),
+ "‘foo’†barÿ",
+ )
+
+ def test_parse_bytes(self, f=text.parse_bytes):
+ self.assertEqual(f("0"), 0)
+ self.assertEqual(f("50"), 50)
+ self.assertEqual(f("50k"), 50 * 1024**1)
+ self.assertEqual(f("50m"), 50 * 1024**2)
+ self.assertEqual(f("50g"), 50 * 1024**3)
+ self.assertEqual(f("50t"), 50 * 1024**4)
+ self.assertEqual(f("50p"), 50 * 1024**5)
+
+ # fractions
+ self.assertEqual(f("123.456"), 123)
+ self.assertEqual(f("123.567"), 124)
+ self.assertEqual(f("0.5M"), round(0.5 * 1024**2))
+
+ # invalid arguments
+ for value in INVALID_ALT:
+ self.assertEqual(f(value), 0)
+ self.assertEqual(f("NaN"), 0)
+ self.assertEqual(f("invalid"), 0)
+ self.assertEqual(f(" 123 kb "), 0)
+
+ def test_parse_int(self, f=text.parse_int):
+ self.assertEqual(f(0), 0)
+ self.assertEqual(f("0"), 0)
+ self.assertEqual(f(123), 123)
+ self.assertEqual(f("123"), 123)
+
+ # invalid arguments
+ for value in INVALID_ALT:
+ self.assertEqual(f(value), 0)
+ self.assertEqual(f("123.456"), 0)
+ self.assertEqual(f("zzz"), 0)
+ self.assertEqual(f([1, 2, 3]), 0)
+ self.assertEqual(f({1: 2, 3: 4}), 0)
+
+ # 'default' argument
+ default = "default"
+ for value in INVALID_ALT:
+ self.assertEqual(f(value, default), default)
+ self.assertEqual(f("zzz", default), default)
+
+ def test_parse_float(self, f=text.parse_float):
+ self.assertEqual(f(0), 0.0)
+ self.assertEqual(f("0"), 0.0)
+ self.assertEqual(f(123), 123.0)
+ self.assertEqual(f("123"), 123.0)
+ self.assertEqual(f(123.456), 123.456)
+ self.assertEqual(f("123.456"), 123.456)
+
+ # invalid arguments
+ for value in INVALID_ALT:
+ self.assertEqual(f(value), 0.0)
+ self.assertEqual(f("zzz"), 0.0)
+ self.assertEqual(f([1, 2, 3]), 0.0)
+ self.assertEqual(f({1: 2, 3: 4}), 0.0)
+
+ # 'default' argument
+ default = "default"
+ for value in INVALID_ALT:
+ self.assertEqual(f(value, default), default)
+ self.assertEqual(f("zzz", default), default)
+
+ def test_parse_query(self, f=text.parse_query):
+ # standard usage
+ self.assertEqual(f(""), {})
+ self.assertEqual(f("foo=1"), {"foo": "1"})
+ self.assertEqual(f("foo=1&bar=2"), {"foo": "1", "bar": "2"})
+
+ # missing value
+ self.assertEqual(f("bar"), {})
+ self.assertEqual(f("foo=1&bar"), {"foo": "1"})
+ self.assertEqual(f("foo=1&bar&baz=3"), {"foo": "1", "baz": "3"})
+
+ # keys with identical names
+ self.assertEqual(f("foo=1&foo=2"), {"foo": "1"})
+ self.assertEqual(
+ f("foo=1&bar=2&foo=3&bar=4"),
+ {"foo": "1", "bar": "2"},
+ )
+
+ # invalid arguments
+ for value in INVALID:
+ self.assertEqual(f(value), {})
+
+ def test_parse_timestamp(self, f=text.parse_timestamp):
+ null = datetime.datetime.utcfromtimestamp(0)
+ value = datetime.datetime.utcfromtimestamp(1555816235)
+
+ self.assertEqual(f(0) , null)
+ self.assertEqual(f("0") , null)
+ self.assertEqual(f(1555816235) , value)
+ self.assertEqual(f("1555816235"), value)
+
+ for value in INVALID_ALT:
+ self.assertEqual(f(value), None)
+ self.assertEqual(f(value, "foo"), "foo")
+
+ def test_parse_datetime(self, f=text.parse_datetime):
+ null = datetime.datetime.utcfromtimestamp(0)
+
+ self.assertEqual(f("1970-01-01T00:00:00+00:00"), null)
+ self.assertEqual(f("1970-01-01T00:00:00+0000") , null)
+ self.assertEqual(f("1970.01.01", "%Y.%m.%d") , null)
+
+ self.assertEqual(
+ f("2019-05-07T21:25:02+09:00"),
+ datetime.datetime(2019, 5, 7, 12, 25, 2),
+ )
+ self.assertEqual(
+ f("2019-05-07T21:25:02+0900"),
+ datetime.datetime(2019, 5, 7, 12, 25, 2),
+ )
+ self.assertEqual(
+ f("2019-05-07 21:25:02"),
+ "2019-05-07 21:25:02",
+ )
+
+ for value in INVALID:
+ self.assertEqual(f(value), None)
+ self.assertEqual(f("1970.01.01"), "1970.01.01")
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/test/test_util.py b/test/test_util.py
new file mode 100644
index 0000000..815b2d8
--- /dev/null
+++ b/test/test_util.py
@@ -0,0 +1,395 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+# Copyright 2015-2019 Mike Fährmann
+#
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License version 2 as
+# published by the Free Software Foundation.
+
+import unittest
+import sys
+import random
+import string
+
+from gallery_dl import util, text, exception
+
+
+class TestRange(unittest.TestCase):
+
+ def test_parse_range(self, f=util.RangePredicate.parse_range):
+ self.assertEqual(
+ f(""),
+ [])
+ self.assertEqual(
+ f("1-2"),
+ [(1, 2)])
+ self.assertEqual(
+ f("-"),
+ [(1, sys.maxsize)])
+ self.assertEqual(
+ f("-2,4,6-8,10-"),
+ [(1, 2), (4, 4), (6, 8), (10, sys.maxsize)])
+ self.assertEqual(
+ f(" - 3 , 4- 4, 2-6"),
+ [(1, 3), (4, 4), (2, 6)])
+
+ def test_optimize_range(self, f=util.RangePredicate.optimize_range):
+ self.assertEqual(
+ f([]),
+ [])
+ self.assertEqual(
+ f([(2, 4)]),
+ [(2, 4)])
+ self.assertEqual(
+ f([(2, 4), (6, 8), (10, 12)]),
+ [(2, 4), (6, 8), (10, 12)])
+ self.assertEqual(
+ f([(2, 4), (4, 6), (5, 8)]),
+ [(2, 8)])
+ self.assertEqual(
+ f([(1, 1), (2, 2), (3, 6), (8, 9)]),
+ [(1, 6), (8, 9)])
+
+
+class TestPredicate(unittest.TestCase):
+
+ def test_range_predicate(self):
+ dummy = None
+
+ pred = util.RangePredicate(" - 3 , 4- 4, 2-6")
+ for i in range(6):
+ self.assertTrue(pred(dummy, dummy))
+ with self.assertRaises(exception.StopExtraction):
+ bool(pred(dummy, dummy))
+
+ pred = util.RangePredicate("1, 3, 5")
+ self.assertTrue(pred(dummy, dummy))
+ self.assertFalse(pred(dummy, dummy))
+ self.assertTrue(pred(dummy, dummy))
+ self.assertFalse(pred(dummy, dummy))
+ self.assertTrue(pred(dummy, dummy))
+ with self.assertRaises(exception.StopExtraction):
+ bool(pred(dummy, dummy))
+
+ pred = util.RangePredicate("")
+ with self.assertRaises(exception.StopExtraction):
+ bool(pred(dummy, dummy))
+
+ def test_unique_predicate(self):
+ dummy = None
+ pred = util.UniquePredicate()
+
+ # no duplicates
+ self.assertTrue(pred("1", dummy))
+ self.assertTrue(pred("2", dummy))
+ self.assertFalse(pred("1", dummy))
+ self.assertFalse(pred("2", dummy))
+ self.assertTrue(pred("3", dummy))
+ self.assertFalse(pred("3", dummy))
+
+ # duplicates for "text:"
+ self.assertTrue(pred("text:123", dummy))
+ self.assertTrue(pred("text:123", dummy))
+ self.assertTrue(pred("text:123", dummy))
+
+ def test_filter_predicate(self):
+ url = ""
+
+ pred = util.FilterPredicate("a < 3")
+ self.assertTrue(pred(url, {"a": 2}))
+ self.assertFalse(pred(url, {"a": 3}))
+
+ with self.assertRaises(SyntaxError):
+ util.FilterPredicate("(")
+
+ with self.assertRaises(exception.FilterError):
+ util.FilterPredicate("a > 1")(url, {"a": None})
+
+ with self.assertRaises(exception.FilterError):
+ util.FilterPredicate("b > 1")(url, {"a": 2})
+
+ def test_build_predicate(self):
+ pred = util.build_predicate([])
+ self.assertIsInstance(pred, type(lambda: True))
+
+ pred = util.build_predicate([util.UniquePredicate()])
+ self.assertIsInstance(pred, util.UniquePredicate)
+
+ pred = util.build_predicate([util.UniquePredicate(),
+ util.UniquePredicate()])
+ self.assertIsInstance(pred, util.ChainPredicate)
+
+
+class TestISO639_1(unittest.TestCase):
+
+ def test_code_to_language(self):
+ d = "default"
+ self._run_test(util.code_to_language, {
+ ("en",): "English",
+ ("FR",): "French",
+ ("xx",): None,
+ ("" ,): None,
+ (None,): None,
+ ("en", d): "English",
+ ("FR", d): "French",
+ ("xx", d): d,
+ ("" , d): d,
+ (None, d): d,
+ })
+
+ def test_language_to_code(self):
+ d = "default"
+ self._run_test(util.language_to_code, {
+ ("English",): "en",
+ ("fRENch",): "fr",
+ ("xx",): None,
+ ("" ,): None,
+ (None,): None,
+ ("English", d): "en",
+ ("fRENch", d): "fr",
+ ("xx", d): d,
+ ("" , d): d,
+ (None, d): d,
+ })
+
+ def _run_test(self, func, tests):
+ for args, result in tests.items():
+ self.assertEqual(func(*args), result)
+
+
+class TestFormatter(unittest.TestCase):
+
+ kwdict = {
+ "a": "hElLo wOrLd",
+ "b": "äöü",
+ "l": ["a", "b", "c"],
+ "n": None,
+ "u": "%27%3C%20/%20%3E%27",
+ "name": "Name",
+ "title1": "Title",
+ "title2": "",
+ "title3": None,
+ "title4": 0,
+ }
+
+ def test_conversions(self):
+ self._run_test("{a!l}", "hello world")
+ self._run_test("{a!u}", "HELLO WORLD")
+ self._run_test("{a!c}", "Hello world")
+ self._run_test("{a!C}", "Hello World")
+ self._run_test("{a!U}", self.kwdict["a"])
+ self._run_test("{u!U}", "'< / >'")
+ self._run_test("{a!s}", self.kwdict["a"])
+ self._run_test("{a!r}", "'" + self.kwdict["a"] + "'")
+ self._run_test("{a!a}", "'" + self.kwdict["a"] + "'")
+ self._run_test("{b!a}", "'\\xe4\\xf6\\xfc'")
+ self._run_test("{a!S}", self.kwdict["a"])
+ self._run_test("{l!S}", "a, b, c")
+ self._run_test("{n!S}", "")
+ with self.assertRaises(KeyError):
+ self._run_test("{a!q}", "hello world")
+
+ def test_optional(self):
+ self._run_test("{name}{title1}", "NameTitle")
+ self._run_test("{name}{title1:?//}", "NameTitle")
+ self._run_test("{name}{title1:? **/''/}", "Name **Title''")
+
+ self._run_test("{name}{title2}", "Name")
+ self._run_test("{name}{title2:?//}", "Name")
+ self._run_test("{name}{title2:? **/''/}", "Name")
+
+ self._run_test("{name}{title3}", "NameNone")
+ self._run_test("{name}{title3:?//}", "Name")
+ self._run_test("{name}{title3:? **/''/}", "Name")
+
+ self._run_test("{name}{title4}", "Name0")
+ self._run_test("{name}{title4:?//}", "Name")
+ self._run_test("{name}{title4:? **/''/}", "Name")
+
+ def test_missing(self):
+ replacement = "None"
+
+ self._run_test("{missing}", replacement)
+ self._run_test("{missing.attr}", replacement)
+ self._run_test("{missing[key]}", replacement)
+ self._run_test("{missing:?a//}", "")
+
+ self._run_test("{name[missing]}", replacement)
+ self._run_test("{name[missing].attr}", replacement)
+ self._run_test("{name[missing][key]}", replacement)
+ self._run_test("{name[missing]:?a//}", "")
+
+ def test_missing_custom_default(self):
+ replacement = default = "foobar"
+ self._run_test("{missing}" , replacement, default)
+ self._run_test("{missing.attr}", replacement, default)
+ self._run_test("{missing[key]}", replacement, default)
+ self._run_test("{missing:?a//}", "a" + default, default)
+
+ def test_slicing(self):
+ v = self.kwdict["a"]
+ self._run_test("{a[1:10]}" , v[1:10])
+ self._run_test("{a[-10:-1]}", v[-10:-1])
+ self._run_test("{a[5:]}" , v[5:])
+ self._run_test("{a[50:]}", v[50:])
+ self._run_test("{a[:5]}" , v[:5])
+ self._run_test("{a[:50]}", v[:50])
+ self._run_test("{a[:]}" , v)
+ self._run_test("{a[1:10:2]}" , v[1:10:2])
+ self._run_test("{a[-10:-1:2]}", v[-10:-1:2])
+ self._run_test("{a[5::2]}" , v[5::2])
+ self._run_test("{a[50::2]}", v[50::2])
+ self._run_test("{a[:5:2]}" , v[:5:2])
+ self._run_test("{a[:50:2]}", v[:50:2])
+ self._run_test("{a[::]}" , v)
+
+ def test_maxlen(self):
+ v = self.kwdict["a"]
+ self._run_test("{a:L5/foo/}" , "foo")
+ self._run_test("{a:L50/foo/}", v)
+ self._run_test("{a:L50/foo/>50}", " " * 39 + v)
+ self._run_test("{a:L50/foo/>51}", "foo")
+ self._run_test("{a:Lab/foo/}", "foo")
+
+ def test_join(self):
+ self._run_test("{l:J}" , "abc")
+ self._run_test("{l:J,}" , "a,b,c")
+ self._run_test("{l:J,/}" , "a,b,c")
+ self._run_test("{l:J,/>20}" , " a,b,c")
+ self._run_test("{l:J - }" , "a - b - c")
+ self._run_test("{l:J - /}" , "a - b - c")
+ self._run_test("{l:J - />20}", " a - b - c")
+
+ self._run_test("{a:J/}" , self.kwdict["a"])
+ self._run_test("{a:J, /}" , ", ".join(self.kwdict["a"]))
+
+ def test_replace(self):
+ self._run_test("{a:Rh/C/}" , "CElLo wOrLd")
+ self._run_test("{a!l:Rh/C/}", "Cello world")
+ self._run_test("{a!u:Rh/C/}", "HELLO WORLD")
+
+ self._run_test("{a!l:Rl/_/}", "he__o wor_d")
+ self._run_test("{a!l:Rl//}" , "heo word")
+ self._run_test("{name:Rame/othing/}", "Nothing")
+
+ def _run_test(self, format_string, result, default=None):
+ formatter = util.Formatter(format_string, default)
+ output = formatter.format_map(self.kwdict)
+ self.assertEqual(output, result, format_string)
+
+
+class TestOther(unittest.TestCase):
+
+ def test_bencode(self):
+ self.assertEqual(util.bencode(0), "")
+ self.assertEqual(util.bencode(123), "123")
+ self.assertEqual(util.bencode(123, "01"), "1111011")
+ self.assertEqual(util.bencode(123, "BA"), "AAAABAA")
+
+ def test_bdecode(self):
+ self.assertEqual(util.bdecode(""), 0)
+ self.assertEqual(util.bdecode("123"), 123)
+ self.assertEqual(util.bdecode("1111011", "01"), 123)
+ self.assertEqual(util.bdecode("AAAABAA", "BA"), 123)
+
+ def test_bencode_bdecode(self):
+ for _ in range(100):
+ value = random.randint(0, 1000000)
+ for alphabet in ("01", "0123456789", string.ascii_letters):
+ result = util.bdecode(util.bencode(value, alphabet), alphabet)
+ self.assertEqual(result, value)
+
+ def test_advance(self):
+ items = range(5)
+
+ self.assertCountEqual(
+ util.advance(items, 0), items)
+ self.assertCountEqual(
+ util.advance(items, 3), range(3, 5))
+ self.assertCountEqual(
+ util.advance(items, 9), [])
+ self.assertCountEqual(
+ util.advance(util.advance(items, 1), 2), range(3, 5))
+
+ def test_raises(self):
+ func = util.raises(Exception())
+ with self.assertRaises(Exception):
+ func()
+
+ func = util.raises(ValueError(1))
+ with self.assertRaises(ValueError):
+ func()
+ with self.assertRaises(ValueError):
+ func()
+ with self.assertRaises(ValueError):
+ func()
+
+ def test_combine_dict(self):
+ self.assertEqual(
+ util.combine_dict({}, {}),
+ {})
+ self.assertEqual(
+ util.combine_dict({1: 1, 2: 2}, {2: 4, 4: 8}),
+ {1: 1, 2: 4, 4: 8})
+ self.assertEqual(
+ util.combine_dict(
+ {1: {11: 22, 12: 24}, 2: {13: 26, 14: 28}},
+ {1: {11: 33, 13: 39}, 2: "str"}),
+ {1: {11: 33, 12: 24, 13: 39}, 2: "str"})
+ self.assertEqual(
+ util.combine_dict(
+ {1: {2: {3: {4: {"1": "a", "2": "b"}}}}},
+ {1: {2: {3: {4: {"1": "A", "3": "C"}}}}}),
+ {1: {2: {3: {4: {"1": "A", "2": "b", "3": "C"}}}}})
+
+ def test_transform_dict(self):
+ d = {}
+ util.transform_dict(d, str)
+ self.assertEqual(d, {})
+
+ d = {1: 123, 2: "123", 3: True, 4: None}
+ util.transform_dict(d, str)
+ self.assertEqual(
+ d, {1: "123", 2: "123", 3: "True", 4: "None"})
+
+ d = {1: 123, 2: "123", 3: "foo", 4: {11: 321, 12: "321", 13: "bar"}}
+ util.transform_dict(d, text.parse_int)
+ self.assertEqual(
+ d, {1: 123, 2: 123, 3: 0, 4: {11: 321, 12: 321, 13: 0}})
+
+ def test_number_to_string(self, f=util.number_to_string):
+ self.assertEqual(f(1) , "1")
+ self.assertEqual(f(1.0) , "1.0")
+ self.assertEqual(f("1.0") , "1.0")
+ self.assertEqual(f([1]) , [1])
+ self.assertEqual(f({1: 2}), {1: 2})
+ self.assertEqual(f(True) , True)
+ self.assertEqual(f(None) , None)
+
+ def test_to_string(self, f=util.to_string):
+ self.assertEqual(f(1) , "1")
+ self.assertEqual(f(1.0) , "1.0")
+ self.assertEqual(f("1.0"), "1.0")
+
+ self.assertEqual(f("") , "")
+ self.assertEqual(f(None) , "")
+ self.assertEqual(f(0) , "")
+
+ self.assertEqual(f(["a"]), "a")
+ self.assertEqual(f([1]) , "1")
+ self.assertEqual(f(["a", "b", "c"]), "a, b, c")
+ self.assertEqual(f([1, 2, 3]), "1, 2, 3")
+
+ def test_universal_none(self):
+ obj = util.NONE
+
+ self.assertFalse(obj)
+ self.assertEqual(str(obj), str(None))
+ self.assertEqual(repr(obj), repr(None))
+ self.assertIs(obj.attr, obj)
+ self.assertIs(obj["key"], obj)
+
+
+if __name__ == '__main__':
+ unittest.main()