From 195c45911e79c33cf0bb986721365fb06df5a153 Mon Sep 17 00:00:00 2001 From: Unit 193 Date: Tue, 2 Jul 2019 04:33:45 -0400 Subject: Import Upstream version 1.8.7 --- test/__init__.py | 0 test/test_config.py | 81 ++++++++++ test/test_cookies.py | 130 +++++++++++++++ test/test_downloader.py | 235 ++++++++++++++++++++++++++++ test/test_extractor.py | 186 ++++++++++++++++++++++ test/test_oauth.py | 104 ++++++++++++ test/test_results.py | 344 ++++++++++++++++++++++++++++++++++++++++ test/test_text.py | 409 ++++++++++++++++++++++++++++++++++++++++++++++++ test/test_util.py | 395 ++++++++++++++++++++++++++++++++++++++++++++++ 9 files changed, 1884 insertions(+) create mode 100644 test/__init__.py create mode 100644 test/test_config.py create mode 100644 test/test_cookies.py create mode 100644 test/test_downloader.py create mode 100644 test/test_extractor.py create mode 100644 test/test_oauth.py create mode 100644 test/test_results.py create mode 100644 test/test_text.py create mode 100644 test/test_util.py (limited to 'test') diff --git a/test/__init__.py b/test/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/test_config.py b/test/test_config.py new file mode 100644 index 0000000..8cdb3da --- /dev/null +++ b/test/test_config.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2015-2017 Mike Fährmann +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License version 2 as +# published by the Free Software Foundation. + +import unittest +import gallery_dl.config as config +import os +import tempfile + + +class TestConfig(unittest.TestCase): + + def setUp(self): + fd, self._configfile = tempfile.mkstemp() + with os.fdopen(fd, "w") as file: + file.write('{"a": "1", "b": {"a": 2, "c": "text"}}') + config.load((self._configfile,)) + + def tearDown(self): + config.clear() + os.remove(self._configfile) + + def test_get(self): + self.assertEqual(config.get(["a"]), "1") + self.assertEqual(config.get(["b", "c"]), "text") + self.assertEqual(config.get(["d"]), None) + self.assertEqual(config.get(["e", "f", "g"], 123), 123) + + def test_interpolate(self): + self.assertEqual(config.interpolate(["a"]), "1") + self.assertEqual(config.interpolate(["b", "a"]), "1") + self.assertEqual(config.interpolate(["b", "c"], "2"), "text") + self.assertEqual(config.interpolate(["b", "d"], "2"), "2") + config.set(["d"], 123) + self.assertEqual(config.interpolate(["b", "d"], "2"), 123) + self.assertEqual(config.interpolate(["d", "d"], "2"), 123) + + def test_set(self): + config.set(["b", "c"], [1, 2, 3]) + config.set(["e", "f", "g"], value=234) + self.assertEqual(config.get(["b", "c"]), [1, 2, 3]) + self.assertEqual(config.get(["e", "f", "g"]), 234) + + def test_setdefault(self): + config.setdefault(["b", "c"], [1, 2, 3]) + config.setdefault(["e", "f", "g"], value=234) + self.assertEqual(config.get(["b", "c"]), "text") + self.assertEqual(config.get(["e", "f", "g"]), 234) + + def test_unset(self): + config.unset(["a"]) + config.unset(["b", "c"]) + config.unset(["c", "d"]) + self.assertEqual(config.get(["a"]), None) + self.assertEqual(config.get(["b", "a"]), 2) + self.assertEqual(config.get(["b", "c"]), None) + + def test_apply(self): + options = ( + (["b", "c"], [1, 2, 3]), + (["e", "f", "g"], 234), + ) + + self.assertEqual(config.get(["b", "c"]), "text") + self.assertEqual(config.get(["e", "f", "g"]), None) + + with config.apply(options): + self.assertEqual(config.get(["b", "c"]), [1, 2, 3]) + self.assertEqual(config.get(["e", "f", "g"]), 234) + + self.assertEqual(config.get(["b", "c"]), "text") + self.assertEqual(config.get(["e", "f", "g"]), None) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/test_cookies.py b/test/test_cookies.py new file mode 100644 index 0000000..a786df6 --- /dev/null +++ b/test/test_cookies.py @@ -0,0 +1,130 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2017 Mike Fährmann +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License version 2 as +# published by the Free Software Foundation. + +import unittest +from unittest import mock + +import logging +import tempfile +import http.cookiejar +from os.path import join + +import gallery_dl.config as config +import gallery_dl.extractor as extractor + +CKEY = ("cookies",) + + +class TestCookiejar(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.path = tempfile.TemporaryDirectory() + + cls.cookiefile = join(cls.path.name, "cookies.txt") + with open(cls.cookiefile, "w") as file: + file.write("""# HTTP Cookie File +.example.org\tTRUE\t/\tFALSE\t253402210800\tNAME\tVALUE +""") + + cls.invalid_cookiefile = join(cls.path.name, "invalid.txt") + with open(cls.invalid_cookiefile, "w") as file: + file.write("""# asd +.example.org\tTRUE\t/\tFALSE\t253402210800\tNAME\tVALUE +""") + + @classmethod + def tearDownClass(cls): + cls.path.cleanup() + config.clear() + + def test_cookiefile(self): + config.set(CKEY, self.cookiefile) + + cookies = extractor.find("test:").session.cookies + self.assertEqual(len(cookies), 1) + + cookie = next(iter(cookies)) + self.assertEqual(cookie.domain, ".example.org") + self.assertEqual(cookie.path , "/") + self.assertEqual(cookie.name , "NAME") + self.assertEqual(cookie.value , "VALUE") + + def test_invalid_cookiefile(self): + self._test_warning(self.invalid_cookiefile, http.cookiejar.LoadError) + + def test_invalid_filename(self): + self._test_warning(join(self.path.name, "nothing"), FileNotFoundError) + + def _test_warning(self, filename, exc): + config.set(CKEY, filename) + log = logging.getLogger("test") + with mock.patch.object(log, "warning") as mock_warning: + cookies = extractor.find("test:").session.cookies + self.assertEqual(len(cookies), 0) + self.assertEqual(mock_warning.call_count, 1) + self.assertEqual(mock_warning.call_args[0][0], "cookies: %s") + self.assertIsInstance(mock_warning.call_args[0][1], exc) + + +class TestCookiedict(unittest.TestCase): + + def setUp(self): + self.cdict = {"NAME1": "VALUE1", "NAME2": "VALUE2"} + config.set(CKEY, self.cdict) + + def tearDown(self): + config.clear() + + def test_dict(self): + cookies = extractor.find("test:").session.cookies + self.assertEqual(len(cookies), len(self.cdict)) + self.assertEqual(sorted(cookies.keys()), sorted(self.cdict.keys())) + self.assertEqual(sorted(cookies.values()), sorted(self.cdict.values())) + + def test_domain(self): + for category in ["exhentai", "nijie", "sankaku", "seiga"]: + extr = _get_extractor(category) + cookies = extr.session.cookies + for key in self.cdict: + self.assertTrue(key in cookies) + for c in cookies: + self.assertEqual(c.domain, extr.cookiedomain) + + +class TestCookieLogin(unittest.TestCase): + + def tearDown(self): + config.clear() + + def test_cookie_login(self): + extr_cookies = { + "exhentai": ("ipb_member_id", "ipb_pass_hash"), + "nijie" : ("nemail", "nlogin"), + "sankaku" : ("login", "pass_hash"), + "seiga" : ("user_session",), + } + for category, cookienames in extr_cookies.items(): + cookies = {name: "value" for name in cookienames} + config.set(CKEY, cookies) + extr = _get_extractor(category) + with mock.patch.object(extr, "_login_impl") as mock_login: + extr.login() + mock_login.assert_not_called() + + +def _get_extractor(category): + for extr in extractor.extractors(): + if extr.category == category and hasattr(extr, "_login_impl"): + url = next(extr._get_tests())[0] + return extr.from_url(url) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_downloader.py b/test/test_downloader.py new file mode 100644 index 0000000..3f301b0 --- /dev/null +++ b/test/test_downloader.py @@ -0,0 +1,235 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2018 Mike Fährmann +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License version 2 as +# published by the Free Software Foundation. + +import re +import base64 +import os.path +import tempfile +import unittest +import threading +import http.server + +import gallery_dl.downloader as downloader +import gallery_dl.extractor as extractor +import gallery_dl.config as config +from gallery_dl.downloader.common import DownloaderBase +from gallery_dl.output import NullOutput +from gallery_dl.util import PathFormat + + +class TestDownloaderBase(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.extractor = extractor.find("test:") + cls.dir = tempfile.TemporaryDirectory() + cls.fnum = 0 + config.set(("base-directory",), cls.dir.name) + + @classmethod + def tearDownClass(cls): + cls.dir.cleanup() + config.clear() + + @classmethod + def _prepare_destination(cls, content=None, part=True, extension=None): + name = "file-{}".format(cls.fnum) + cls.fnum += 1 + + kwdict = { + "category": "test", + "subcategory": "test", + "filename": name, + "extension": extension, + } + pathfmt = PathFormat(cls.extractor) + pathfmt.set_directory(kwdict) + pathfmt.set_keywords(kwdict) + + if content: + mode = "w" + ("b" if isinstance(content, bytes) else "") + with pathfmt.open(mode) as file: + file.write(content) + + return pathfmt + + def _run_test(self, url, input, output, + extension, expected_extension=None): + pathfmt = self._prepare_destination(input, extension=extension) + success = self.downloader.download(url, pathfmt) + + # test successful download + self.assertTrue(success, "downloading '{}' failed".format(url)) + + # test content + mode = "r" + ("b" if isinstance(output, bytes) else "") + with pathfmt.open(mode) as file: + content = file.read() + self.assertEqual(content, output) + + # test filename extension + self.assertEqual( + pathfmt.keywords["extension"], + expected_extension, + ) + self.assertEqual( + os.path.splitext(pathfmt.realpath)[1][1:], + expected_extension, + ) + + +class TestHTTPDownloader(TestDownloaderBase): + + @classmethod + def setUpClass(cls): + TestDownloaderBase.setUpClass() + cls.downloader = downloader.find("http")(cls.extractor, NullOutput()) + + port = 8088 + cls.address = "http://127.0.0.1:{}".format(port) + cls._jpg = cls.address + "/image.jpg" + cls._png = cls.address + "/image.png" + cls._gif = cls.address + "/image.gif" + + server = http.server.HTTPServer(("", port), HttpRequestHandler) + threading.Thread(target=server.serve_forever, daemon=True).start() + + def test_http_download(self): + self._run_test(self._jpg, None, DATA_JPG, "jpg", "jpg") + self._run_test(self._png, None, DATA_PNG, "png", "png") + self._run_test(self._gif, None, DATA_GIF, "gif", "gif") + + def test_http_offset(self): + self._run_test(self._jpg, DATA_JPG[:123], DATA_JPG, "jpg", "jpg") + self._run_test(self._png, DATA_PNG[:12] , DATA_PNG, "png", "png") + self._run_test(self._gif, DATA_GIF[:1] , DATA_GIF, "gif", "gif") + + def test_http_extension(self): + self._run_test(self._jpg, None, DATA_JPG, None, "jpg") + self._run_test(self._png, None, DATA_PNG, None, "png") + self._run_test(self._gif, None, DATA_GIF, None, "gif") + + def test_http_adjust_extension(self): + self._run_test(self._jpg, None, DATA_JPG, "png", "jpg") + self._run_test(self._png, None, DATA_PNG, "gif", "png") + self._run_test(self._gif, None, DATA_GIF, "jpg", "gif") + + +class TestTextDownloader(TestDownloaderBase): + + @classmethod + def setUpClass(cls): + TestDownloaderBase.setUpClass() + cls.downloader = downloader.find("text")(cls.extractor, NullOutput()) + + def test_text_download(self): + self._run_test("text:foobar", None, "foobar", "txt", "txt") + + def test_text_offset(self): + self._run_test("text:foobar", "foo", "foobar", "txt", "txt") + + def test_text_extension(self): + self._run_test("text:foobar", None, "foobar", None, "txt") + + def test_text_empty(self): + self._run_test("text:", None, "", "txt", "txt") + + +class FakeDownloader(DownloaderBase): + scheme = "fake" + + def __init__(self, extractor, output): + DownloaderBase.__init__(self, extractor, output) + + def connect(self, url, offset): + pass + + def receive(self, file): + pass + + def reset(self): + pass + + def get_extension(self): + pass + + @staticmethod + def _check_extension(file, pathfmt): + pass + + +class HttpRequestHandler(http.server.BaseHTTPRequestHandler): + + def do_GET(self): + if self.path == "/image.jpg": + content_type = "image/jpeg" + output = DATA_JPG + elif self.path == "/image.png": + content_type = "image/png" + output = DATA_PNG + elif self.path == "/image.gif": + content_type = "image/gif" + output = DATA_GIF + else: + self.send_response(404) + self.wfile.write(self.path.encode()) + return + + headers = { + "Content-Type": content_type, + "Content-Length": len(output), + } + + if "Range" in self.headers: + status = 206 + + match = re.match(r"bytes=(\d+)-", self.headers["Range"]) + start = int(match.group(1)) + + headers["Content-Range"] = "bytes {}-{}/{}".format( + start, len(output)-1, len(output)) + output = output[start:] + else: + status = 200 + + self.send_response(status) + for key, value in headers.items(): + self.send_header(key, value) + self.end_headers() + self.wfile.write(output) + + +DATA_JPG = base64.standard_b64decode(""" +/9j/4AAQSkZJRgABAQEASABIAAD/2wBD +AAEBAQEBAQEBAQEBAQEBAQEBAQEBAQEB +AQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEB +AQEBAQEBAQEBAQEBAQEBAQH/2wBDAQEB +AQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEB +AQEBAQEBAQEBAQEBAQEBAQEBAQEBAQEB +AQEBAQEBAQEBAQEBAQH/wAARCAABAAED +AREAAhEBAxEB/8QAFAABAAAAAAAAAAAA +AAAAAAAACv/EABQQAQAAAAAAAAAAAAAA +AAAAAAD/xAAUAQEAAAAAAAAAAAAAAAAA +AAAA/8QAFBEBAAAAAAAAAAAAAAAAAAAA +AP/aAAwDAQACEQMRAD8AfwD/2Q==""") + + +DATA_PNG = base64.standard_b64decode(""" +iVBORw0KGgoAAAANSUhEUgAAAAEAAAAB +CAAAAAA6fptVAAAACklEQVQIHWP4DwAB +AQEANl9ngAAAAABJRU5ErkJggg==""") + + +DATA_GIF = base64.standard_b64decode(""" +R0lGODdhAQABAIAAAP///////ywAAAAA +AQABAAACAkQBADs=""") + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_extractor.py b/test/test_extractor.py new file mode 100644 index 0000000..fa0709b --- /dev/null +++ b/test/test_extractor.py @@ -0,0 +1,186 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2018-2019 Mike Fährmann +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License version 2 as +# published by the Free Software Foundation. + +import sys +import unittest +import string + +from gallery_dl import extractor +from gallery_dl.extractor.common import Extractor, Message +from gallery_dl.extractor.directlink import DirectlinkExtractor as DLExtractor + + +class FakeExtractor(Extractor): + category = "fake" + subcategory = "test" + pattern = "fake:" + + def items(self): + yield Message.Version, 1 + yield Message.Url, "text:foobar", {} + + +class TestExtractor(unittest.TestCase): + VALID_URIS = ( + "https://example.org/file.jpg", + "tumblr:foobar", + "oauth:flickr", + "test:pixiv:", + "recursive:https://example.org/document.html", + ) + + def setUp(self): + extractor._cache.clear() + extractor._module_iter = iter(extractor.modules) + + def test_find(self): + for uri in self.VALID_URIS: + result = extractor.find(uri) + self.assertIsInstance(result, Extractor, uri) + + for not_found in ("", "/tmp/file.ext"): + self.assertIsNone(extractor.find(not_found)) + + for invalid in (None, [], {}, 123, b"test:"): + with self.assertRaises(TypeError): + extractor.find(invalid) + + def test_add(self): + uri = "fake:foobar" + self.assertIsNone(extractor.find(uri)) + + extractor.add(FakeExtractor) + self.assertIsInstance(extractor.find(uri), FakeExtractor) + + def test_add_module(self): + uri = "fake:foobar" + self.assertIsNone(extractor.find(uri)) + + classes = extractor.add_module(sys.modules[__name__]) + self.assertEqual(len(classes), 1) + self.assertEqual(classes[0].pattern, FakeExtractor.pattern) + self.assertEqual(classes[0], FakeExtractor) + self.assertIsInstance(extractor.find(uri), FakeExtractor) + + def test_blacklist(self): + link_uri = "https://example.org/file.jpg" + test_uri = "test:" + fake_uri = "fake:" + + self.assertIsInstance(extractor.find(link_uri), DLExtractor) + self.assertIsInstance(extractor.find(test_uri), Extractor) + self.assertIsNone(extractor.find(fake_uri)) + + with extractor.blacklist(["directlink"]): + self.assertIsNone(extractor.find(link_uri)) + self.assertIsInstance(extractor.find(test_uri), Extractor) + self.assertIsNone(extractor.find(fake_uri)) + + with extractor.blacklist([], [DLExtractor, FakeExtractor]): + self.assertIsNone(extractor.find(link_uri)) + self.assertIsInstance(extractor.find(test_uri), Extractor) + self.assertIsNone(extractor.find(fake_uri)) + + with extractor.blacklist(["test"], [DLExtractor]): + self.assertIsNone(extractor.find(link_uri)) + self.assertIsNone(extractor.find(test_uri)) + self.assertIsNone(extractor.find(fake_uri)) + + def test_from_url(self): + for uri in self.VALID_URIS: + cls = extractor.find(uri).__class__ + extr = cls.from_url(uri) + self.assertIs(type(extr), cls) + self.assertIsInstance(extr, Extractor) + + for not_found in ("", "/tmp/file.ext"): + self.assertIsNone(FakeExtractor.from_url(not_found)) + + for invalid in (None, [], {}, 123, b"test:"): + with self.assertRaises(TypeError): + FakeExtractor.from_url(invalid) + + def test_unique_pattern_matches(self): + test_urls = [] + + # collect testcase URLs + for extr in extractor.extractors(): + for testcase in extr._get_tests(): + test_urls.append((testcase[0], extr)) + + # iterate over all testcase URLs + for url, extr1 in test_urls: + matches = [] + + # ... and apply all regex patterns to each one + for extr2 in extractor._cache: + + # skip DirectlinkExtractor pattern if it isn't tested + if extr1 != DLExtractor and extr2 == DLExtractor: + continue + + match = extr2.pattern.match(url) + if match: + matches.append(match) + + # fail if more or less than 1 match happened + if len(matches) > 1: + msg = "'{}' gets matched by more than one pattern:".format(url) + for match in matches: + msg += "\n- " + msg += match.re.pattern + self.fail(msg) + + if len(matches) < 1: + msg = "'{}' isn't matched by any pattern".format(url) + self.fail(msg) + + def test_docstrings(self): + """ensure docstring uniqueness""" + for extr1 in extractor.extractors(): + for extr2 in extractor.extractors(): + if extr1 != extr2 and extr1.__doc__ and extr2.__doc__: + self.assertNotEqual( + extr1.__doc__, + extr2.__doc__, + "{} <-> {}".format(extr1, extr2), + ) + + def test_names(self): + """Ensure extractor classes are named CategorySubcategoryExtractor""" + def capitalize(c): + if "-" in c: + return string.capwords(c.replace("-", " ")).replace(" ", "") + if "." in c: + c = c.replace(".", "") + return c.capitalize() + + mapping = { + "2chan" : "futaba", + "3dbooru": "threedeebooru", + "4chan" : "fourchan", + "4plebs" : "fourplebs", + "8chan" : "infinitychan", + "oauth" : None, + } + + for extr in extractor.extractors(): + category = mapping.get(extr.category, extr.category) + if category: + expected = "{}{}Extractor".format( + capitalize(category), + capitalize(extr.subcategory), + ) + if expected[0].isdigit(): + expected = "_" + expected + self.assertEqual(expected, extr.__name__) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_oauth.py b/test/test_oauth.py new file mode 100644 index 0000000..2ce5b43 --- /dev/null +++ b/test/test_oauth.py @@ -0,0 +1,104 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2018 Mike Fährmann +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License version 2 as +# published by the Free Software Foundation. + +import unittest + +from gallery_dl import oauth, text + +TESTSERVER = "http://term.ie/oauth/example" +CONSUMER_KEY = "key" +CONSUMER_SECRET = "secret" +REQUEST_TOKEN = "requestkey" +REQUEST_TOKEN_SECRET = "requestsecret" +ACCESS_TOKEN = "accesskey" +ACCESS_TOKEN_SECRET = "accesssecret" + + +class TestOAuthSession(unittest.TestCase): + + def test_concat(self): + concat = oauth.concat + + self.assertEqual(concat(), "") + self.assertEqual(concat("str"), "str") + self.assertEqual(concat("str1", "str2"), "str1&str2") + + self.assertEqual(concat("&", "?/"), "%26&%3F%2F") + self.assertEqual( + concat("GET", "http://example.org/", "foo=bar&baz=a"), + "GET&http%3A%2F%2Fexample.org%2F&foo%3Dbar%26baz%3Da" + ) + + def test_nonce(self, size=16): + nonce_values = set(oauth.nonce(size) for _ in range(size)) + + # uniqueness + self.assertEqual(len(nonce_values), size) + + # length + for nonce in nonce_values: + self.assertEqual(len(nonce), size) + + def test_quote(self): + quote = oauth.quote + + reserved = ",;:!\"§$%&/(){}[]=?`´+*'äöü" + unreserved = ("ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz" + "0123456789-._~") + + for char in unreserved: + self.assertEqual(quote(char), char) + + for char in reserved: + quoted = quote(char) + quoted_hex = quoted.replace("%", "") + self.assertTrue(quoted.startswith("%")) + self.assertTrue(len(quoted) >= 3) + self.assertEqual(quoted_hex.upper(), quoted_hex) + + def test_request_token(self): + response = self._oauth_request( + "/request_token.php", {}) + expected = "oauth_token=requestkey&oauth_token_secret=requestsecret" + self.assertEqual(response, expected, msg=response) + + data = text.parse_query(response) + self.assertTrue(data["oauth_token"], REQUEST_TOKEN) + self.assertTrue(data["oauth_token_secret"], REQUEST_TOKEN_SECRET) + + def test_access_token(self): + response = self._oauth_request( + "/access_token.php", {}, REQUEST_TOKEN, REQUEST_TOKEN_SECRET) + expected = "oauth_token=accesskey&oauth_token_secret=accesssecret" + self.assertEqual(response, expected, msg=response) + + data = text.parse_query(response) + self.assertTrue(data["oauth_token"], ACCESS_TOKEN) + self.assertTrue(data["oauth_token_secret"], ACCESS_TOKEN_SECRET) + + def test_authenticated_call(self): + params = {"method": "foo", "a": "äöüß/?&#", "äöüß/?&#": "a"} + response = self._oauth_request( + "/echo_api.php", params, ACCESS_TOKEN, ACCESS_TOKEN_SECRET) + + self.assertEqual(text.parse_query(response), params) + + def _oauth_request(self, endpoint, params=None, + oauth_token=None, oauth_token_secret=None): + session = oauth.OAuth1Session( + CONSUMER_KEY, CONSUMER_SECRET, + oauth_token, oauth_token_secret, + ) + url = TESTSERVER + endpoint + return session.get(url, params=params).text + + +if __name__ == "__main__": + unittest.main(warnings="ignore") diff --git a/test/test_results.py b/test/test_results.py new file mode 100644 index 0000000..8f03f03 --- /dev/null +++ b/test/test_results.py @@ -0,0 +1,344 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2015-2019 Mike Fährmann +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License version 2 as +# published by the Free Software Foundation. + +import os +import sys +import re +import json +import hashlib +import unittest +from gallery_dl import extractor, util, job, config, exception + + +# these don't work on Travis CI +TRAVIS_SKIP = { + "exhentai", "kissmanga", "mangafox", "dynastyscans", "nijie", "bobx", + "archivedmoe", "archiveofsins", "thebarchive", "fireden", "4plebs", + "sankaku", "idolcomplex", "mangahere", "readcomiconline", "mangadex", + "sankakucomplex", +} + +# temporary issues, etc. +BROKEN = { + "komikcast", + "mangapark", +} + + +class TestExtractorResults(unittest.TestCase): + + def setUp(self): + setup_test_config() + + def tearDown(self): + config.clear() + + @classmethod + def setUpClass(cls): + cls._skipped = [] + + @classmethod + def tearDownClass(cls): + if cls._skipped: + print("\n\nSkipped tests:") + for url, exc in cls._skipped: + print('- {} ("{}")'.format(url, exc)) + + def _run_test(self, extr, url, result): + if result: + if "options" in result: + for key, value in result["options"]: + config.set(key.split("."), value) + if "range" in result: + config.set(("image-range",), result["range"]) + config.set(("chapter-range",), result["range"]) + content = "content" in result + else: + content = False + + tjob = ResultJob(url, content=content) + self.assertEqual(extr, tjob.extractor.__class__) + + if not result: + return + if "exception" in result: + with self.assertRaises(result["exception"]): + tjob.run() + return + try: + tjob.run() + except exception.StopExtraction: + pass + except exception.HttpError as exc: + exc = str(exc) + if re.match(r"5\d\d: ", exc) or \ + re.search(r"\bRead timed out\b", exc): + self._skipped.append((url, exc)) + self.skipTest(exc) + raise + + # test archive-id uniqueness + self.assertEqual(len(set(tjob.list_archive)), len(tjob.list_archive)) + + # test '_extractor' entries + if tjob.queue: + for url, kwdict in zip(tjob.list_url, tjob.list_keyword): + if "_extractor" in kwdict: + extr = kwdict["_extractor"].from_url(url) + self.assertIsInstance(extr, kwdict["_extractor"]) + self.assertEqual(extr.url, url) + + # test extraction results + if "url" in result: + self.assertEqual(result["url"], tjob.hash_url.hexdigest()) + + if "content" in result: + self.assertEqual(result["content"], tjob.hash_content.hexdigest()) + + if "keyword" in result: + keyword = result["keyword"] + if isinstance(keyword, dict): + for kwdict in tjob.list_keyword: + self._test_kwdict(kwdict, keyword) + else: # assume SHA1 hash + self.assertEqual(keyword, tjob.hash_keyword.hexdigest()) + + if "count" in result: + count = result["count"] + if isinstance(count, str): + self.assertRegex(count, r"^ *(==|!=|<|<=|>|>=) *\d+ *$") + expr = "{} {}".format(len(tjob.list_url), count) + self.assertTrue(eval(expr), msg=expr) + else: # assume integer + self.assertEqual(len(tjob.list_url), count) + + if "pattern" in result: + self.assertGreater(len(tjob.list_url), 0) + for url in tjob.list_url: + self.assertRegex(url, result["pattern"]) + + def _test_kwdict(self, kwdict, tests): + for key, test in tests.items(): + if key.startswith("?"): + key = key[1:] + if key not in kwdict: + continue + self.assertIn(key, kwdict) + value = kwdict[key] + + if isinstance(test, dict): + self._test_kwdict(value, test) + elif isinstance(test, type): + self.assertIsInstance(value, test, msg=key) + elif isinstance(test, str): + if test.startswith("re:"): + self.assertRegex(value, test[3:], msg=key) + elif test.startswith("type:"): + self.assertEqual(type(value).__name__, test[5:], msg=key) + else: + self.assertEqual(value, test, msg=key) + else: + self.assertEqual(value, test, msg=key) + + +class ResultJob(job.DownloadJob): + """Generate test-results for extractor runs""" + + def __init__(self, url, parent=None, content=False): + job.DownloadJob.__init__(self, url, parent) + self.queue = False + self.content = content + self.list_url = [] + self.list_keyword = [] + self.list_archive = [] + self.hash_url = hashlib.sha1() + self.hash_keyword = hashlib.sha1() + self.hash_archive = hashlib.sha1() + self.hash_content = hashlib.sha1() + if content: + self.fileobj = TestPathfmt(self.hash_content) + self.get_downloader("http")._check_extension = lambda a, b: None + + self.format_directory = TestFormatter( + "".join(self.extractor.directory_fmt)) + self.format_filename = TestFormatter(self.extractor.filename_fmt) + + def run(self): + for msg in self.extractor: + self.dispatch(msg) + + def handle_url(self, url, keywords, fallback=None): + self.update_url(url) + self.update_keyword(keywords) + self.update_archive(keywords) + self.update_content(url) + self.format_filename.format_map(keywords) + + def handle_directory(self, keywords): + self.update_keyword(keywords, False) + self.format_directory.format_map(keywords) + + def handle_queue(self, url, keywords): + self.queue = True + self.update_url(url) + self.update_keyword(keywords) + + def update_url(self, url): + self.list_url.append(url) + self.hash_url.update(url.encode()) + + def update_keyword(self, kwdict, to_list=True): + if to_list: + self.list_keyword.append(kwdict) + kwdict = self._filter(kwdict) + self.hash_keyword.update( + json.dumps(kwdict, sort_keys=True, default=str).encode()) + + def update_archive(self, kwdict): + archive_id = self.extractor.archive_fmt.format_map(kwdict) + self.list_archive.append(archive_id) + self.hash_archive.update(archive_id.encode()) + + def update_content(self, url): + if self.content: + scheme = url.partition(":")[0] + self.get_downloader(scheme).download(url, self.fileobj) + + +class TestPathfmt(): + + def __init__(self, hashobj): + self.hashobj = hashobj + self.path = "" + self.size = 0 + self.has_extension = True + + def __enter__(self): + return self + + def __exit__(self, *args): + pass + + def open(self, mode): + self.size = 0 + return self + + def write(self, content): + """Update SHA1 hash""" + self.size += len(content) + self.hashobj.update(content) + + def tell(self): + return self.size + + def part_size(self): + return 0 + + +class TestFormatter(util.Formatter): + + @staticmethod + def _noop(_): + return "" + + def _apply_simple(self, key, fmt): + if key == "extension" or "._format_optional." in repr(fmt): + return self._noop + + def wrap(obj): + return fmt(obj[key]) + return wrap + + def _apply(self, key, funcs, fmt): + if key == "extension" or "._format_optional." in repr(fmt): + return self._noop + + def wrap(obj): + obj = obj[key] + for func in funcs: + obj = func(obj) + return fmt(obj) + return wrap + + +def setup_test_config(): + name = "gallerydl" + email = "gallerydl@openaliasbox.org" + + config.clear() + config.set(("cache", "file"), ":memory:") + config.set(("downloader", "part"), False) + config.set(("extractor", "timeout"), 60) + config.set(("extractor", "username"), name) + config.set(("extractor", "password"), name) + config.set(("extractor", "nijie", "username"), email) + config.set(("extractor", "seiga", "username"), email) + config.set(("extractor", "danbooru", "username"), None) + config.set(("extractor", "twitter" , "username"), None) + config.set(("extractor", "mangoxo" , "password"), "VZ8DL3983u") + + config.set(("extractor", "deviantart", "client-id"), "7777") + config.set(("extractor", "deviantart", "client-secret"), + "ff14994c744d9208e5caeec7aab4a026") + + config.set(("extractor", "tumblr", "api-key"), + "0cXoHfIqVzMQcc3HESZSNsVlulGxEXGDTTZCDrRrjaa0jmuTc6") + config.set(("extractor", "tumblr", "api-secret"), + "6wxAK2HwrXdedn7VIoZWxGqVhZ8JdYKDLjiQjL46MLqGuEtyVj") + config.set(("extractor", "tumblr", "access-token"), + "N613fPV6tOZQnyn0ERTuoEZn0mEqG8m2K8M3ClSJdEHZJuqFdG") + config.set(("extractor", "tumblr", "access-token-secret"), + "sgOA7ZTT4FBXdOGGVV331sSp0jHYp4yMDRslbhaQf7CaS71i4O") + + +def generate_tests(): + """Dynamically generate extractor unittests""" + def _generate_test(extr, tcase): + def test(self): + url, result = tcase + print("\n", url, sep="") + self._run_test(extr, url, result) + return test + + # enable selective testing for direct calls + if __name__ == '__main__' and len(sys.argv) > 1: + if sys.argv[1].lower() == "all": + fltr = lambda c, bc: True # noqa: E731 + elif sys.argv[1].lower() == "broken": + fltr = lambda c, bc: c in BROKEN # noqa: E731 + else: + argv = sys.argv[1:] + fltr = lambda c, bc: c in argv or bc in argv # noqa: E731 + del sys.argv[1:] + else: + skip = set(BROKEN) + if "CI" in os.environ and "TRAVIS" in os.environ: + skip |= set(TRAVIS_SKIP) + if skip: + print("skipping:", ", ".join(skip)) + fltr = lambda c, bc: c not in skip # noqa: E731 + + # filter available extractor classes + extractors = [ + extr for extr in extractor.extractors() + if fltr(extr.category, getattr(extr, "basecategory", None)) + ] + + # add 'test_...' methods + for extr in extractors: + name = "test_" + extr.__name__ + "_" + for num, tcase in enumerate(extr._get_tests(), 1): + test = _generate_test(extr, tcase) + test.__name__ = name + str(num) + setattr(TestExtractorResults, test.__name__, test) + + +generate_tests() +if __name__ == '__main__': + unittest.main(warnings='ignore') diff --git a/test/test_text.py b/test/test_text.py new file mode 100644 index 0000000..405acd3 --- /dev/null +++ b/test/test_text.py @@ -0,0 +1,409 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2015-2018 Mike Fährmann +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License version 2 as +# published by the Free Software Foundation. + +import unittest +import datetime + +from gallery_dl import text + + +INVALID = ((), [], {}, None, 1, 2.3) +INVALID_ALT = ((), [], {}, None, "") + + +class TestText(unittest.TestCase): + + def test_clean_xml(self, f=text.clean_xml): + # standard usage + self.assertEqual(f(""), "") + self.assertEqual(f("foo"), "foo") + self.assertEqual(f("\tfoo\nbar\r"), "\tfoo\nbar\r") + self.assertEqual(f("\ab\ba\fr\v"), "bar") + + # 'repl' argument + repl = "#" + self.assertEqual(f("", repl), "") + self.assertEqual(f("foo", repl), "foo") + self.assertEqual(f("\tfoo\nbar\r", repl), "\tfoo\nbar\r") + self.assertEqual( + f("\ab\ba\fr\v", repl), "#b#a#r#") + + # removal of all illegal control characters + value = "".join(chr(x) for x in range(32)) + self.assertEqual(f(value), "\t\n\r") + + # 'invalid' arguments + for value in INVALID: + self.assertEqual(f(value), "") + + def test_remove_html(self, f=text.remove_html): + result = "Hello World." + + # standard usage + self.assertEqual(f(""), "") + self.assertEqual(f("Hello World."), result) + self.assertEqual(f(" Hello World. "), result) + self.assertEqual(f("Hello
World."), result) + self.assertEqual( + f("
HelloWorld.
"), result) + + # empty HTML + self.assertEqual(f("
"), "") + self.assertEqual(f("
"), "") + + # malformed HTML + self.assertEqual(f(""), "") + self.assertEqual(f(""), "") + + # invalid arguments + for value in INVALID: + self.assertEqual(f(value), "") + + def test_split_html(self, f=text.split_html): + result = ["Hello", "World."] + empty = [] + + # standard usage + self.assertEqual(f(""), empty) + self.assertEqual(f("Hello World."), ["Hello World."]) + self.assertEqual(f(" Hello World. "), ["Hello World."]) + self.assertEqual(f("Hello
World."), result) + self.assertEqual(f(" Hello
World. "), result) + self.assertEqual( + f("
HelloWorld.
"), result) + + # empty HTML + self.assertEqual(f("
"), empty) + self.assertEqual(f("
"), empty) + + # malformed HTML + self.assertEqual(f(""), empty) + self.assertEqual(f(""), empty) + + # invalid arguments + for value in INVALID: + self.assertEqual(f(value), empty) + + def test_filename_from_url(self, f=text.filename_from_url): + result = "filename.ext" + + # standard usage + self.assertEqual(f(""), "") + self.assertEqual(f("filename.ext"), result) + self.assertEqual(f("/filename.ext"), result) + self.assertEqual(f("example.org/filename.ext"), result) + self.assertEqual(f("http://example.org/v2/filename.ext"), result) + self.assertEqual( + f("http://example.org/v2/filename.ext?param=value#frag"), result) + + # invalid arguments + for value in INVALID: + self.assertEqual(f(value), "") + + def test_ext_from_url(self, f=text.ext_from_url): + result = "ext" + + # standard usage + self.assertEqual(f(""), "") + self.assertEqual(f("filename.ext"), result) + self.assertEqual(f("/filename.ext"), result) + self.assertEqual(f("example.org/filename.ext"), result) + self.assertEqual(f("http://example.org/v2/filename.ext"), result) + self.assertEqual( + f("http://example.org/v2/filename.ext?param=value#frag"), result) + + # invalid arguments + for value in INVALID: + self.assertEqual(f(value), "") + + def test_nameext_from_url(self, f=text.nameext_from_url): + empty = {"filename": "", "extension": ""} + result = {"filename": "filename", "extension": "ext"} + + # standard usage + self.assertEqual(f(""), empty) + self.assertEqual(f("filename.ext"), result) + self.assertEqual(f("/filename.ext"), result) + self.assertEqual(f("example.org/filename.ext"), result) + self.assertEqual(f("http://example.org/v2/filename.ext"), result) + self.assertEqual( + f("http://example.org/v2/filename.ext?param=value#frag"), result) + + # invalid arguments + for value in INVALID: + self.assertEqual(f(value), empty) + + def test_clean_path_windows(self, f=text.clean_path_windows): + self.assertEqual(f(""), "") + self.assertEqual(f("foo"), "foo") + self.assertEqual(f("foo/bar"), "foo_bar") + self.assertEqual(f("foo<>:\"\\/|?*bar"), "foo_________bar") + + # invalid arguments + for value in INVALID: + self.assertEqual(f(value), "") + + def test_clean_path_posix(self, f=text.clean_path_posix): + self.assertEqual(f(""), "") + self.assertEqual(f("foo"), "foo") + self.assertEqual(f("foo/bar"), "foo_bar") + self.assertEqual(f("foo<>:\"\\/|?*bar"), "foo<>:\"\\_|?*bar") + + # invalid arguments + for value in INVALID: + self.assertEqual(f(value), "") + + def test_extract(self, f=text.extract): + txt = "" + self.assertEqual(f(txt, "<", ">"), ("a" , 3)) + self.assertEqual(f(txt, "X", ">"), (None, 0)) + self.assertEqual(f(txt, "<", "X"), (None, 0)) + + # 'pos' argument + for i in range(1, 4): + self.assertEqual(f(txt, "<", ">", i), ("b", 6)) + for i in range(4, 10): + self.assertEqual(f(txt, "<", ">", i), (None, i)) + + # invalid arguments + for value in INVALID: + self.assertEqual(f(value, "<" , ">") , (None, 0)) + self.assertEqual(f(txt , value, ">") , (None, 0)) + self.assertEqual(f(txt , "<" , value), (None, 0)) + + def test_rextract(self, f=text.rextract): + txt = "" + self.assertEqual(f(txt, "<", ">"), ("b" , 3)) + self.assertEqual(f(txt, "X", ">"), (None, -1)) + self.assertEqual(f(txt, "<", "X"), (None, -1)) + + # 'pos' argument + for i in range(10, 3, -1): + self.assertEqual(f(txt, "<", ">", i), ("b", 3)) + for i in range(3, 0, -1): + self.assertEqual(f(txt, "<", ">", i), ("a", 0)) + + # invalid arguments + for value in INVALID: + self.assertEqual(f(value, "<" , ">") , (None, -1)) + self.assertEqual(f(txt , value, ">") , (None, -1)) + self.assertEqual(f(txt , "<" , value), (None, -1)) + + def test_extract_all(self, f=text.extract_all): + txt = "[c][b][a]: xyz! [d][e" + + self.assertEqual( + f(txt, ()), ({}, 0)) + self.assertEqual( + f(txt, (("C", "[", "]"), ("B", "[", "]"), ("A", "[", "]"))), + ({"A": "a", "B": "b", "C": "c"}, 9), + ) + + # 'None' as field name + self.assertEqual( + f(txt, ((None, "[", "]"), (None, "[", "]"), ("A", "[", "]"))), + ({"A": "a"}, 9), + ) + self.assertEqual( + f(txt, ((None, "[", "]"), (None, "[", "]"), (None, "[", "]"))), + ({}, 9), + ) + + # failed matches + self.assertEqual( + f(txt, (("C", "[", "]"), ("X", "X", "X"), ("B", "[", "]"))), + ({"B": "b", "C": "c", "X": None}, 6), + ) + + # 'pos' argument + self.assertEqual( + f(txt, (("B", "[", "]"), ("A", "[", "]")), pos=1), + ({"A": "a", "B": "b"}, 9), + ) + + # 'values' argument + self.assertEqual( + f(txt, (("C", "[", "]"),), values={"A": "a", "B": "b"}), + ({"A": "a", "B": "b", "C": "c"}, 3), + ) + + vdict = {} + rdict, pos = f(txt, (), values=vdict) + self.assertIs(vdict, rdict) + + def test_extract_iter(self, f=text.extract_iter): + txt = "[c][b][a]: xyz! [d][e" + + def g(*args): + return list(f(*args)) + + self.assertEqual( + g("", "[", "]"), []) + self.assertEqual( + g("[a]", "[", "]"), ["a"]) + self.assertEqual( + g(txt, "[", "]"), ["c", "b", "a", "d"]) + self.assertEqual( + g(txt, "X", "X"), []) + self.assertEqual( + g(txt, "[", "]", 6), ["a", "d"]) + + def test_extract_from(self, f=text.extract_from): + txt = "[c][b][a]: xyz! [d][e" + + e = f(txt) + self.assertEqual(e("[", "]"), "c") + self.assertEqual(e("[", "]"), "b") + self.assertEqual(e("[", "]"), "a") + self.assertEqual(e("[", "]"), "d") + self.assertEqual(e("[", "]"), "") + self.assertEqual(e("[", "]"), "") + + e = f(txt, pos=6, default="END") + self.assertEqual(e("[", "]"), "a") + self.assertEqual(e("[", "]"), "d") + self.assertEqual(e("[", "]"), "END") + self.assertEqual(e("[", "]"), "END") + + def test_parse_unicode_escapes(self, f=text.parse_unicode_escapes): + self.assertEqual(f(""), "") + self.assertEqual(f("foobar"), "foobar") + self.assertEqual(f("foo’bar"), "foo’bar") + self.assertEqual(f("foo\\u2019bar"), "foo’bar") + self.assertEqual(f("foo\\u201bar"), "foo‛ar") + self.assertEqual(f("foo\\u201zar"), "foo\\u201zar") + self.assertEqual( + f("\\u2018foo\\u2019\\u2020bar\\u00ff"), + "‘foo’†barÿ", + ) + + def test_parse_bytes(self, f=text.parse_bytes): + self.assertEqual(f("0"), 0) + self.assertEqual(f("50"), 50) + self.assertEqual(f("50k"), 50 * 1024**1) + self.assertEqual(f("50m"), 50 * 1024**2) + self.assertEqual(f("50g"), 50 * 1024**3) + self.assertEqual(f("50t"), 50 * 1024**4) + self.assertEqual(f("50p"), 50 * 1024**5) + + # fractions + self.assertEqual(f("123.456"), 123) + self.assertEqual(f("123.567"), 124) + self.assertEqual(f("0.5M"), round(0.5 * 1024**2)) + + # invalid arguments + for value in INVALID_ALT: + self.assertEqual(f(value), 0) + self.assertEqual(f("NaN"), 0) + self.assertEqual(f("invalid"), 0) + self.assertEqual(f(" 123 kb "), 0) + + def test_parse_int(self, f=text.parse_int): + self.assertEqual(f(0), 0) + self.assertEqual(f("0"), 0) + self.assertEqual(f(123), 123) + self.assertEqual(f("123"), 123) + + # invalid arguments + for value in INVALID_ALT: + self.assertEqual(f(value), 0) + self.assertEqual(f("123.456"), 0) + self.assertEqual(f("zzz"), 0) + self.assertEqual(f([1, 2, 3]), 0) + self.assertEqual(f({1: 2, 3: 4}), 0) + + # 'default' argument + default = "default" + for value in INVALID_ALT: + self.assertEqual(f(value, default), default) + self.assertEqual(f("zzz", default), default) + + def test_parse_float(self, f=text.parse_float): + self.assertEqual(f(0), 0.0) + self.assertEqual(f("0"), 0.0) + self.assertEqual(f(123), 123.0) + self.assertEqual(f("123"), 123.0) + self.assertEqual(f(123.456), 123.456) + self.assertEqual(f("123.456"), 123.456) + + # invalid arguments + for value in INVALID_ALT: + self.assertEqual(f(value), 0.0) + self.assertEqual(f("zzz"), 0.0) + self.assertEqual(f([1, 2, 3]), 0.0) + self.assertEqual(f({1: 2, 3: 4}), 0.0) + + # 'default' argument + default = "default" + for value in INVALID_ALT: + self.assertEqual(f(value, default), default) + self.assertEqual(f("zzz", default), default) + + def test_parse_query(self, f=text.parse_query): + # standard usage + self.assertEqual(f(""), {}) + self.assertEqual(f("foo=1"), {"foo": "1"}) + self.assertEqual(f("foo=1&bar=2"), {"foo": "1", "bar": "2"}) + + # missing value + self.assertEqual(f("bar"), {}) + self.assertEqual(f("foo=1&bar"), {"foo": "1"}) + self.assertEqual(f("foo=1&bar&baz=3"), {"foo": "1", "baz": "3"}) + + # keys with identical names + self.assertEqual(f("foo=1&foo=2"), {"foo": "1"}) + self.assertEqual( + f("foo=1&bar=2&foo=3&bar=4"), + {"foo": "1", "bar": "2"}, + ) + + # invalid arguments + for value in INVALID: + self.assertEqual(f(value), {}) + + def test_parse_timestamp(self, f=text.parse_timestamp): + null = datetime.datetime.utcfromtimestamp(0) + value = datetime.datetime.utcfromtimestamp(1555816235) + + self.assertEqual(f(0) , null) + self.assertEqual(f("0") , null) + self.assertEqual(f(1555816235) , value) + self.assertEqual(f("1555816235"), value) + + for value in INVALID_ALT: + self.assertEqual(f(value), None) + self.assertEqual(f(value, "foo"), "foo") + + def test_parse_datetime(self, f=text.parse_datetime): + null = datetime.datetime.utcfromtimestamp(0) + + self.assertEqual(f("1970-01-01T00:00:00+00:00"), null) + self.assertEqual(f("1970-01-01T00:00:00+0000") , null) + self.assertEqual(f("1970.01.01", "%Y.%m.%d") , null) + + self.assertEqual( + f("2019-05-07T21:25:02+09:00"), + datetime.datetime(2019, 5, 7, 12, 25, 2), + ) + self.assertEqual( + f("2019-05-07T21:25:02+0900"), + datetime.datetime(2019, 5, 7, 12, 25, 2), + ) + self.assertEqual( + f("2019-05-07 21:25:02"), + "2019-05-07 21:25:02", + ) + + for value in INVALID: + self.assertEqual(f(value), None) + self.assertEqual(f("1970.01.01"), "1970.01.01") + + +if __name__ == '__main__': + unittest.main() diff --git a/test/test_util.py b/test/test_util.py new file mode 100644 index 0000000..815b2d8 --- /dev/null +++ b/test/test_util.py @@ -0,0 +1,395 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2015-2019 Mike Fährmann +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License version 2 as +# published by the Free Software Foundation. + +import unittest +import sys +import random +import string + +from gallery_dl import util, text, exception + + +class TestRange(unittest.TestCase): + + def test_parse_range(self, f=util.RangePredicate.parse_range): + self.assertEqual( + f(""), + []) + self.assertEqual( + f("1-2"), + [(1, 2)]) + self.assertEqual( + f("-"), + [(1, sys.maxsize)]) + self.assertEqual( + f("-2,4,6-8,10-"), + [(1, 2), (4, 4), (6, 8), (10, sys.maxsize)]) + self.assertEqual( + f(" - 3 , 4- 4, 2-6"), + [(1, 3), (4, 4), (2, 6)]) + + def test_optimize_range(self, f=util.RangePredicate.optimize_range): + self.assertEqual( + f([]), + []) + self.assertEqual( + f([(2, 4)]), + [(2, 4)]) + self.assertEqual( + f([(2, 4), (6, 8), (10, 12)]), + [(2, 4), (6, 8), (10, 12)]) + self.assertEqual( + f([(2, 4), (4, 6), (5, 8)]), + [(2, 8)]) + self.assertEqual( + f([(1, 1), (2, 2), (3, 6), (8, 9)]), + [(1, 6), (8, 9)]) + + +class TestPredicate(unittest.TestCase): + + def test_range_predicate(self): + dummy = None + + pred = util.RangePredicate(" - 3 , 4- 4, 2-6") + for i in range(6): + self.assertTrue(pred(dummy, dummy)) + with self.assertRaises(exception.StopExtraction): + bool(pred(dummy, dummy)) + + pred = util.RangePredicate("1, 3, 5") + self.assertTrue(pred(dummy, dummy)) + self.assertFalse(pred(dummy, dummy)) + self.assertTrue(pred(dummy, dummy)) + self.assertFalse(pred(dummy, dummy)) + self.assertTrue(pred(dummy, dummy)) + with self.assertRaises(exception.StopExtraction): + bool(pred(dummy, dummy)) + + pred = util.RangePredicate("") + with self.assertRaises(exception.StopExtraction): + bool(pred(dummy, dummy)) + + def test_unique_predicate(self): + dummy = None + pred = util.UniquePredicate() + + # no duplicates + self.assertTrue(pred("1", dummy)) + self.assertTrue(pred("2", dummy)) + self.assertFalse(pred("1", dummy)) + self.assertFalse(pred("2", dummy)) + self.assertTrue(pred("3", dummy)) + self.assertFalse(pred("3", dummy)) + + # duplicates for "text:" + self.assertTrue(pred("text:123", dummy)) + self.assertTrue(pred("text:123", dummy)) + self.assertTrue(pred("text:123", dummy)) + + def test_filter_predicate(self): + url = "" + + pred = util.FilterPredicate("a < 3") + self.assertTrue(pred(url, {"a": 2})) + self.assertFalse(pred(url, {"a": 3})) + + with self.assertRaises(SyntaxError): + util.FilterPredicate("(") + + with self.assertRaises(exception.FilterError): + util.FilterPredicate("a > 1")(url, {"a": None}) + + with self.assertRaises(exception.FilterError): + util.FilterPredicate("b > 1")(url, {"a": 2}) + + def test_build_predicate(self): + pred = util.build_predicate([]) + self.assertIsInstance(pred, type(lambda: True)) + + pred = util.build_predicate([util.UniquePredicate()]) + self.assertIsInstance(pred, util.UniquePredicate) + + pred = util.build_predicate([util.UniquePredicate(), + util.UniquePredicate()]) + self.assertIsInstance(pred, util.ChainPredicate) + + +class TestISO639_1(unittest.TestCase): + + def test_code_to_language(self): + d = "default" + self._run_test(util.code_to_language, { + ("en",): "English", + ("FR",): "French", + ("xx",): None, + ("" ,): None, + (None,): None, + ("en", d): "English", + ("FR", d): "French", + ("xx", d): d, + ("" , d): d, + (None, d): d, + }) + + def test_language_to_code(self): + d = "default" + self._run_test(util.language_to_code, { + ("English",): "en", + ("fRENch",): "fr", + ("xx",): None, + ("" ,): None, + (None,): None, + ("English", d): "en", + ("fRENch", d): "fr", + ("xx", d): d, + ("" , d): d, + (None, d): d, + }) + + def _run_test(self, func, tests): + for args, result in tests.items(): + self.assertEqual(func(*args), result) + + +class TestFormatter(unittest.TestCase): + + kwdict = { + "a": "hElLo wOrLd", + "b": "äöü", + "l": ["a", "b", "c"], + "n": None, + "u": "%27%3C%20/%20%3E%27", + "name": "Name", + "title1": "Title", + "title2": "", + "title3": None, + "title4": 0, + } + + def test_conversions(self): + self._run_test("{a!l}", "hello world") + self._run_test("{a!u}", "HELLO WORLD") + self._run_test("{a!c}", "Hello world") + self._run_test("{a!C}", "Hello World") + self._run_test("{a!U}", self.kwdict["a"]) + self._run_test("{u!U}", "'< / >'") + self._run_test("{a!s}", self.kwdict["a"]) + self._run_test("{a!r}", "'" + self.kwdict["a"] + "'") + self._run_test("{a!a}", "'" + self.kwdict["a"] + "'") + self._run_test("{b!a}", "'\\xe4\\xf6\\xfc'") + self._run_test("{a!S}", self.kwdict["a"]) + self._run_test("{l!S}", "a, b, c") + self._run_test("{n!S}", "") + with self.assertRaises(KeyError): + self._run_test("{a!q}", "hello world") + + def test_optional(self): + self._run_test("{name}{title1}", "NameTitle") + self._run_test("{name}{title1:?//}", "NameTitle") + self._run_test("{name}{title1:? **/''/}", "Name **Title''") + + self._run_test("{name}{title2}", "Name") + self._run_test("{name}{title2:?//}", "Name") + self._run_test("{name}{title2:? **/''/}", "Name") + + self._run_test("{name}{title3}", "NameNone") + self._run_test("{name}{title3:?//}", "Name") + self._run_test("{name}{title3:? **/''/}", "Name") + + self._run_test("{name}{title4}", "Name0") + self._run_test("{name}{title4:?//}", "Name") + self._run_test("{name}{title4:? **/''/}", "Name") + + def test_missing(self): + replacement = "None" + + self._run_test("{missing}", replacement) + self._run_test("{missing.attr}", replacement) + self._run_test("{missing[key]}", replacement) + self._run_test("{missing:?a//}", "") + + self._run_test("{name[missing]}", replacement) + self._run_test("{name[missing].attr}", replacement) + self._run_test("{name[missing][key]}", replacement) + self._run_test("{name[missing]:?a//}", "") + + def test_missing_custom_default(self): + replacement = default = "foobar" + self._run_test("{missing}" , replacement, default) + self._run_test("{missing.attr}", replacement, default) + self._run_test("{missing[key]}", replacement, default) + self._run_test("{missing:?a//}", "a" + default, default) + + def test_slicing(self): + v = self.kwdict["a"] + self._run_test("{a[1:10]}" , v[1:10]) + self._run_test("{a[-10:-1]}", v[-10:-1]) + self._run_test("{a[5:]}" , v[5:]) + self._run_test("{a[50:]}", v[50:]) + self._run_test("{a[:5]}" , v[:5]) + self._run_test("{a[:50]}", v[:50]) + self._run_test("{a[:]}" , v) + self._run_test("{a[1:10:2]}" , v[1:10:2]) + self._run_test("{a[-10:-1:2]}", v[-10:-1:2]) + self._run_test("{a[5::2]}" , v[5::2]) + self._run_test("{a[50::2]}", v[50::2]) + self._run_test("{a[:5:2]}" , v[:5:2]) + self._run_test("{a[:50:2]}", v[:50:2]) + self._run_test("{a[::]}" , v) + + def test_maxlen(self): + v = self.kwdict["a"] + self._run_test("{a:L5/foo/}" , "foo") + self._run_test("{a:L50/foo/}", v) + self._run_test("{a:L50/foo/>50}", " " * 39 + v) + self._run_test("{a:L50/foo/>51}", "foo") + self._run_test("{a:Lab/foo/}", "foo") + + def test_join(self): + self._run_test("{l:J}" , "abc") + self._run_test("{l:J,}" , "a,b,c") + self._run_test("{l:J,/}" , "a,b,c") + self._run_test("{l:J,/>20}" , " a,b,c") + self._run_test("{l:J - }" , "a - b - c") + self._run_test("{l:J - /}" , "a - b - c") + self._run_test("{l:J - />20}", " a - b - c") + + self._run_test("{a:J/}" , self.kwdict["a"]) + self._run_test("{a:J, /}" , ", ".join(self.kwdict["a"])) + + def test_replace(self): + self._run_test("{a:Rh/C/}" , "CElLo wOrLd") + self._run_test("{a!l:Rh/C/}", "Cello world") + self._run_test("{a!u:Rh/C/}", "HELLO WORLD") + + self._run_test("{a!l:Rl/_/}", "he__o wor_d") + self._run_test("{a!l:Rl//}" , "heo word") + self._run_test("{name:Rame/othing/}", "Nothing") + + def _run_test(self, format_string, result, default=None): + formatter = util.Formatter(format_string, default) + output = formatter.format_map(self.kwdict) + self.assertEqual(output, result, format_string) + + +class TestOther(unittest.TestCase): + + def test_bencode(self): + self.assertEqual(util.bencode(0), "") + self.assertEqual(util.bencode(123), "123") + self.assertEqual(util.bencode(123, "01"), "1111011") + self.assertEqual(util.bencode(123, "BA"), "AAAABAA") + + def test_bdecode(self): + self.assertEqual(util.bdecode(""), 0) + self.assertEqual(util.bdecode("123"), 123) + self.assertEqual(util.bdecode("1111011", "01"), 123) + self.assertEqual(util.bdecode("AAAABAA", "BA"), 123) + + def test_bencode_bdecode(self): + for _ in range(100): + value = random.randint(0, 1000000) + for alphabet in ("01", "0123456789", string.ascii_letters): + result = util.bdecode(util.bencode(value, alphabet), alphabet) + self.assertEqual(result, value) + + def test_advance(self): + items = range(5) + + self.assertCountEqual( + util.advance(items, 0), items) + self.assertCountEqual( + util.advance(items, 3), range(3, 5)) + self.assertCountEqual( + util.advance(items, 9), []) + self.assertCountEqual( + util.advance(util.advance(items, 1), 2), range(3, 5)) + + def test_raises(self): + func = util.raises(Exception()) + with self.assertRaises(Exception): + func() + + func = util.raises(ValueError(1)) + with self.assertRaises(ValueError): + func() + with self.assertRaises(ValueError): + func() + with self.assertRaises(ValueError): + func() + + def test_combine_dict(self): + self.assertEqual( + util.combine_dict({}, {}), + {}) + self.assertEqual( + util.combine_dict({1: 1, 2: 2}, {2: 4, 4: 8}), + {1: 1, 2: 4, 4: 8}) + self.assertEqual( + util.combine_dict( + {1: {11: 22, 12: 24}, 2: {13: 26, 14: 28}}, + {1: {11: 33, 13: 39}, 2: "str"}), + {1: {11: 33, 12: 24, 13: 39}, 2: "str"}) + self.assertEqual( + util.combine_dict( + {1: {2: {3: {4: {"1": "a", "2": "b"}}}}}, + {1: {2: {3: {4: {"1": "A", "3": "C"}}}}}), + {1: {2: {3: {4: {"1": "A", "2": "b", "3": "C"}}}}}) + + def test_transform_dict(self): + d = {} + util.transform_dict(d, str) + self.assertEqual(d, {}) + + d = {1: 123, 2: "123", 3: True, 4: None} + util.transform_dict(d, str) + self.assertEqual( + d, {1: "123", 2: "123", 3: "True", 4: "None"}) + + d = {1: 123, 2: "123", 3: "foo", 4: {11: 321, 12: "321", 13: "bar"}} + util.transform_dict(d, text.parse_int) + self.assertEqual( + d, {1: 123, 2: 123, 3: 0, 4: {11: 321, 12: 321, 13: 0}}) + + def test_number_to_string(self, f=util.number_to_string): + self.assertEqual(f(1) , "1") + self.assertEqual(f(1.0) , "1.0") + self.assertEqual(f("1.0") , "1.0") + self.assertEqual(f([1]) , [1]) + self.assertEqual(f({1: 2}), {1: 2}) + self.assertEqual(f(True) , True) + self.assertEqual(f(None) , None) + + def test_to_string(self, f=util.to_string): + self.assertEqual(f(1) , "1") + self.assertEqual(f(1.0) , "1.0") + self.assertEqual(f("1.0"), "1.0") + + self.assertEqual(f("") , "") + self.assertEqual(f(None) , "") + self.assertEqual(f(0) , "") + + self.assertEqual(f(["a"]), "a") + self.assertEqual(f([1]) , "1") + self.assertEqual(f(["a", "b", "c"]), "a, b, c") + self.assertEqual(f([1, 2, 3]), "1, 2, 3") + + def test_universal_none(self): + obj = util.NONE + + self.assertFalse(obj) + self.assertEqual(str(obj), str(None)) + self.assertEqual(repr(obj), repr(None)) + self.assertIs(obj.attr, obj) + self.assertIs(obj["key"], obj) + + +if __name__ == '__main__': + unittest.main() -- cgit v1.2.3