summaryrefslogtreecommitdiffstats
path: root/test
diff options
context:
space:
mode:
Diffstat (limited to 'test')
-rw-r--r--test/test_cookies.py125
-rw-r--r--test/test_downloader.py1
-rw-r--r--test/test_extractor.py37
-rw-r--r--test/test_formatter.py2
-rw-r--r--test/test_postprocessor.py42
-rw-r--r--test/test_results.py207
6 files changed, 266 insertions, 148 deletions
diff --git a/test/test_cookies.py b/test/test_cookies.py
index 335fa3d..a6ad05f 100644
--- a/test/test_cookies.py
+++ b/test/test_cookies.py
@@ -1,7 +1,7 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
-# Copyright 2017-2022 Mike Fährmann
+# Copyright 2017-2023 Mike Fährmann
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 2 as
@@ -46,8 +46,7 @@ class TestCookiejar(unittest.TestCase):
def test_cookiefile(self):
config.set((), "cookies", self.cookiefile)
-
- cookies = extractor.find("test:").session.cookies
+ cookies = _get_extractor("test").cookies
self.assertEqual(len(cookies), 1)
cookie = next(iter(cookies))
@@ -65,12 +64,14 @@ class TestCookiejar(unittest.TestCase):
def _test_warning(self, filename, exc):
config.set((), "cookies", filename)
log = logging.getLogger("test")
+
with mock.patch.object(log, "warning") as mock_warning:
- cookies = extractor.find("test:").session.cookies
- self.assertEqual(len(cookies), 0)
- self.assertEqual(mock_warning.call_count, 1)
- self.assertEqual(mock_warning.call_args[0][0], "cookies: %s")
- self.assertIsInstance(mock_warning.call_args[0][1], exc)
+ cookies = _get_extractor("test").cookies
+
+ self.assertEqual(len(cookies), 0)
+ self.assertEqual(mock_warning.call_count, 1)
+ self.assertEqual(mock_warning.call_args[0][0], "cookies: %s")
+ self.assertIsInstance(mock_warning.call_args[0][1], exc)
class TestCookiedict(unittest.TestCase):
@@ -83,7 +84,8 @@ class TestCookiedict(unittest.TestCase):
config.clear()
def test_dict(self):
- cookies = extractor.find("test:").session.cookies
+ cookies = _get_extractor("test").cookies
+
self.assertEqual(len(cookies), len(self.cdict))
self.assertEqual(sorted(cookies.keys()), sorted(self.cdict.keys()))
self.assertEqual(sorted(cookies.values()), sorted(self.cdict.values()))
@@ -91,11 +93,11 @@ class TestCookiedict(unittest.TestCase):
def test_domain(self):
for category in ["exhentai", "idolcomplex", "nijie", "horne"]:
extr = _get_extractor(category)
- cookies = extr.session.cookies
+ cookies = extr.cookies
for key in self.cdict:
self.assertTrue(key in cookies)
for c in cookies:
- self.assertEqual(c.domain, extr.cookiedomain)
+ self.assertEqual(c.domain, extr.cookies_domain)
class TestCookieLogin(unittest.TestCase):
@@ -122,91 +124,96 @@ class TestCookieLogin(unittest.TestCase):
class TestCookieUtils(unittest.TestCase):
def test_check_cookies(self):
- extr = extractor.find("test:")
- self.assertFalse(extr._cookiejar, "empty")
- self.assertFalse(extr.cookiedomain, "empty")
+ extr = _get_extractor("test")
+ self.assertFalse(extr.cookies, "empty")
+ self.assertFalse(extr.cookies_domain, "empty")
# always returns False when checking for empty cookie list
- self.assertFalse(extr._check_cookies(()))
+ self.assertFalse(extr.cookies_check(()))
- self.assertFalse(extr._check_cookies(("a",)))
- self.assertFalse(extr._check_cookies(("a", "b")))
- self.assertFalse(extr._check_cookies(("a", "b", "c")))
+ self.assertFalse(extr.cookies_check(("a",)))
+ self.assertFalse(extr.cookies_check(("a", "b")))
+ self.assertFalse(extr.cookies_check(("a", "b", "c")))
- extr._cookiejar.set("a", "1")
- self.assertTrue(extr._check_cookies(("a",)))
- self.assertFalse(extr._check_cookies(("a", "b")))
- self.assertFalse(extr._check_cookies(("a", "b", "c")))
+ extr.cookies.set("a", "1")
+ self.assertTrue(extr.cookies_check(("a",)))
+ self.assertFalse(extr.cookies_check(("a", "b")))
+ self.assertFalse(extr.cookies_check(("a", "b", "c")))
- extr._cookiejar.set("b", "2")
- self.assertTrue(extr._check_cookies(("a",)))
- self.assertTrue(extr._check_cookies(("a", "b")))
- self.assertFalse(extr._check_cookies(("a", "b", "c")))
+ extr.cookies.set("b", "2")
+ self.assertTrue(extr.cookies_check(("a",)))
+ self.assertTrue(extr.cookies_check(("a", "b")))
+ self.assertFalse(extr.cookies_check(("a", "b", "c")))
def test_check_cookies_domain(self):
- extr = extractor.find("test:")
- self.assertFalse(extr._cookiejar, "empty")
- extr.cookiedomain = ".example.org"
+ extr = _get_extractor("test")
+ self.assertFalse(extr.cookies, "empty")
+ extr.cookies_domain = ".example.org"
- self.assertFalse(extr._check_cookies(("a",)))
- self.assertFalse(extr._check_cookies(("a", "b")))
+ self.assertFalse(extr.cookies_check(("a",)))
+ self.assertFalse(extr.cookies_check(("a", "b")))
- extr._cookiejar.set("a", "1")
- self.assertFalse(extr._check_cookies(("a",)))
+ extr.cookies.set("a", "1")
+ self.assertFalse(extr.cookies_check(("a",)))
- extr._cookiejar.set("a", "1", domain=extr.cookiedomain)
- self.assertTrue(extr._check_cookies(("a",)))
+ extr.cookies.set("a", "1", domain=extr.cookies_domain)
+ self.assertTrue(extr.cookies_check(("a",)))
- extr._cookiejar.set("a", "1", domain="www" + extr.cookiedomain)
- self.assertEqual(len(extr._cookiejar), 3)
- self.assertTrue(extr._check_cookies(("a",)))
+ extr.cookies.set("a", "1", domain="www" + extr.cookies_domain)
+ self.assertEqual(len(extr.cookies), 3)
+ self.assertTrue(extr.cookies_check(("a",)))
- extr._cookiejar.set("b", "2", domain=extr.cookiedomain)
- extr._cookiejar.set("c", "3", domain=extr.cookiedomain)
- self.assertTrue(extr._check_cookies(("a", "b", "c")))
+ extr.cookies.set("b", "2", domain=extr.cookies_domain)
+ extr.cookies.set("c", "3", domain=extr.cookies_domain)
+ self.assertTrue(extr.cookies_check(("a", "b", "c")))
def test_check_cookies_expires(self):
- extr = extractor.find("test:")
- self.assertFalse(extr._cookiejar, "empty")
- self.assertFalse(extr.cookiedomain, "empty")
+ extr = _get_extractor("test")
+ self.assertFalse(extr.cookies, "empty")
+ self.assertFalse(extr.cookies_domain, "empty")
now = int(time.time())
log = logging.getLogger("test")
- extr._cookiejar.set("a", "1", expires=now-100)
+ extr.cookies.set("a", "1", expires=now-100)
with mock.patch.object(log, "warning") as mw:
- self.assertFalse(extr._check_cookies(("a",)))
+ self.assertFalse(extr.cookies_check(("a",)))
self.assertEqual(mw.call_count, 1)
self.assertEqual(mw.call_args[0], ("Cookie '%s' has expired", "a"))
- extr._cookiejar.set("a", "1", expires=now+100)
+ extr.cookies.set("a", "1", expires=now+100)
with mock.patch.object(log, "warning") as mw:
- self.assertTrue(extr._check_cookies(("a",)))
+ self.assertTrue(extr.cookies_check(("a",)))
self.assertEqual(mw.call_count, 1)
self.assertEqual(mw.call_args[0], (
"Cookie '%s' will expire in less than %s hour%s", "a", 1, ""))
- extr._cookiejar.set("a", "1", expires=now+100+7200)
+ extr.cookies.set("a", "1", expires=now+100+7200)
with mock.patch.object(log, "warning") as mw:
- self.assertTrue(extr._check_cookies(("a",)))
+ self.assertTrue(extr.cookies_check(("a",)))
self.assertEqual(mw.call_count, 1)
self.assertEqual(mw.call_args[0], (
"Cookie '%s' will expire in less than %s hour%s", "a", 3, "s"))
- extr._cookiejar.set("a", "1", expires=now+100+24*3600)
+ extr.cookies.set("a", "1", expires=now+100+24*3600)
with mock.patch.object(log, "warning") as mw:
- self.assertTrue(extr._check_cookies(("a",)))
+ self.assertTrue(extr.cookies_check(("a",)))
self.assertEqual(mw.call_count, 0)
def _get_extractor(category):
- URLS = {
- "exhentai" : "https://exhentai.org/g/1200119/d55c44d3d0/",
- "idolcomplex": "https://idol.sankakucomplex.com/post/show/1",
- "nijie" : "https://nijie.info/view.php?id=1",
- "horne" : "https://horne.red/view.php?id=1",
- }
- return extractor.find(URLS[category])
+ extr = extractor.find(URLS[category])
+ extr.initialize()
+ return extr
+
+
+URLS = {
+ "exhentai" : "https://exhentai.org/g/1200119/d55c44d3d0/",
+ "idolcomplex": "https://idol.sankakucomplex.com/post/show/1",
+ "nijie" : "https://nijie.info/view.php?id=1",
+ "horne" : "https://horne.red/view.php?id=1",
+ "test" : "test:",
+}
if __name__ == "__main__":
diff --git a/test/test_downloader.py b/test/test_downloader.py
index c65be95..840e078 100644
--- a/test/test_downloader.py
+++ b/test/test_downloader.py
@@ -34,6 +34,7 @@ class FakeJob():
def __init__(self):
self.extractor = extractor.find("test:")
+ self.extractor.initialize()
self.pathfmt = path.PathFormat(self.extractor)
self.out = output.NullOutput()
self.get_logger = logging.getLogger
diff --git a/test/test_extractor.py b/test/test_extractor.py
index 6516fa8..9387f5b 100644
--- a/test/test_extractor.py
+++ b/test/test_extractor.py
@@ -93,20 +93,24 @@ class TestExtractorModule(unittest.TestCase):
FakeExtractor.from_url(invalid)
def test_unique_pattern_matches(self):
- test_urls = []
+ try:
+ import test.results
+ except ImportError:
+ raise unittest.SkipTest("no test data")
# collect testcase URLs
+ test_urls = []
append = test_urls.append
- for extr in extractor.extractors():
- for testcase in extr._get_tests():
- append((testcase[0], extr))
+
+ for result in test.results.all():
+ append((result["#url"], result["#class"]))
# iterate over all testcase URLs
for url, extr1 in test_urls:
matches = []
# ... and apply all regex patterns to each one
- for extr2 in extractor._cache:
+ for extr2 in _list_classes():
# skip DirectlinkExtractor pattern if it isn't tested
if extr1 != DirectlinkExtractor and \
@@ -132,8 +136,29 @@ class TestExtractorModule(unittest.TestCase):
else:
self.assertIs(extr1, matches[0][1], url)
+ def test_init(self):
+ """Test for exceptions in Extractor.initialize() and .finalize()"""
+ for cls in extractor.extractors():
+ if cls.category == "ytdl":
+ continue
+ extr = cls.from_url(cls.example)
+ extr.initialize()
+ extr.finalize()
+
+ @unittest.skipIf(sys.hexversion < 0x3060000, "test fails in CI")
+ def test_init_ytdl(self):
+ try:
+ extr = extractor.find("ytdl:")
+ extr.initialize()
+ extr.finalize()
+ except ImportError as exc:
+ if exc.name in ("youtube_dl", "yt_dlp"):
+ raise unittest.SkipTest("cannot import module '{}'".format(
+ exc.name))
+ raise
+
def test_docstrings(self):
- """ensure docstring uniqueness"""
+ """Ensure docstring uniqueness"""
for extr1 in extractor.extractors():
for extr2 in extractor.extractors():
if extr1 != extr2 and extr1.__doc__ and extr2.__doc__:
diff --git a/test/test_formatter.py b/test/test_formatter.py
index 0992f4b..dbdccba 100644
--- a/test/test_formatter.py
+++ b/test/test_formatter.py
@@ -340,6 +340,8 @@ class TestFormatter(unittest.TestCase):
self._run_test("{'foobar'[:3]}", value)
self._run_test("{z|'foo'}" , value)
self._run_test("{z|''|'foo'}" , value)
+ self._run_test("{z|''}" , "")
+ self._run_test("{''|''}" , "")
self._run_test("{_lit[foo]}" , value)
self._run_test("{_lit[foo]!u}" , value.upper())
diff --git a/test/test_postprocessor.py b/test/test_postprocessor.py
index 554a51e..c00144e 100644
--- a/test/test_postprocessor.py
+++ b/test/test_postprocessor.py
@@ -102,10 +102,10 @@ class BasePostprocessorTest(unittest.TestCase):
pp = postprocessor.find(self.__class__.__name__[:-4].lower())
return pp(self.job, options)
- def _trigger(self, events=None, *args):
+ def _trigger(self, events=None):
for event in (events or ("prepare", "file")):
for callback in self.job.hooks[event]:
- callback(self.pathfmt, *args)
+ callback(self.pathfmt)
class ClassifyTest(BasePostprocessorTest):
@@ -579,6 +579,40 @@ class MtimeTest(BasePostprocessorTest):
self.assertEqual(self.pathfmt.kwdict["_mtime"], 315532800)
+class PythonTest(BasePostprocessorTest):
+
+ def test_module(self):
+ path = os.path.join(self.dir.name, "module.py")
+ self._write_module(path)
+
+ sys.path.insert(0, self.dir.name)
+ try:
+ self._create({"function": "module:calc"}, {"_value": 123})
+ finally:
+ del sys.path[0]
+
+ self.assertNotIn("_result", self.pathfmt.kwdict)
+ self._trigger()
+ self.assertEqual(self.pathfmt.kwdict["_result"], 246)
+
+ def test_path(self):
+ path = os.path.join(self.dir.name, "module.py")
+ self._write_module(path)
+
+ self._create({"function": path + ":calc"}, {"_value": 12})
+
+ self.assertNotIn("_result", self.pathfmt.kwdict)
+ self._trigger()
+ self.assertEqual(self.pathfmt.kwdict["_result"], 24)
+
+ def _write_module(self, path):
+ with open(path, "w") as fp:
+ fp.write("""
+def calc(kwdict):
+ kwdict["_result"] = kwdict["_value"] * 2
+""")
+
+
class ZipTest(BasePostprocessorTest):
def test_zip_default(self):
@@ -645,7 +679,7 @@ class ZipTest(BasePostprocessorTest):
self.assertEqual(len(pp.zfile.NameToInfo), 4)
# close file
- self._trigger(("finalize",), 0)
+ self._trigger(("finalize",))
# reopen to check persistence
with zipfile.ZipFile(pp.zfile.filename) as file:
@@ -678,7 +712,7 @@ class ZipTest(BasePostprocessorTest):
self._trigger()
# close file
- self._trigger(("finalize",), 0)
+ self._trigger(("finalize",))
self.assertEqual(pp.zfile.write.call_count, 3)
for call in pp.zfile.write.call_args_list:
diff --git a/test/test_results.py b/test/test_results.py
index 3c7d284..4fb22c7 100644
--- a/test/test_results.py
+++ b/test/test_results.py
@@ -1,7 +1,7 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
-# Copyright 2015-2022 Mike Fährmann
+# Copyright 2015-2023 Mike Fährmann
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 2 as
@@ -15,10 +15,12 @@ import re
import json
import hashlib
import datetime
+import collections
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from gallery_dl import \
extractor, util, job, config, exception, formatter # noqa E402
+from test import results # noqa E402
# temporary issues, etc.
@@ -46,28 +48,40 @@ class TestExtractorResults(unittest.TestCase):
for url, exc in cls._skipped:
print('- {} ("{}")'.format(url, exc))
- def _run_test(self, extr, url, result):
- if result:
- if "options" in result:
- for key, value in result["options"]:
- key = key.split(".")
- config.set(key[:-1], key[-1], value)
- if "range" in result:
- config.set((), "image-range" , result["range"])
- config.set((), "chapter-range", result["range"])
- content = "content" in result
+ def assertRange(self, value, range, msg=None):
+ if range.step > 1:
+ self.assertIn(value, range, msg=msg)
else:
+ self.assertLessEqual(value, range.stop, msg=msg)
+ self.assertGreaterEqual(value, range.start, msg=msg)
+
+ def _run_test(self, result):
+ result.pop("#comment", None)
+ only_matching = (len(result) <= 3)
+
+ if only_matching:
content = False
+ else:
+ if "#options" in result:
+ for key, value in result["#options"].items():
+ key = key.split(".")
+ config.set(key[:-1], key[-1], value)
+ if "#range" in result:
+ config.set((), "image-range" , result["#range"])
+ config.set((), "chapter-range", result["#range"])
+ content = ("#sha1_content" in result)
- tjob = ResultJob(url, content=content)
- self.assertEqual(extr, tjob.extractor.__class__)
+ tjob = ResultJob(result["#url"], content=content)
+ self.assertEqual(result["#class"], tjob.extractor.__class__, "#class")
- if not result:
+ if only_matching:
return
- if "exception" in result:
- with self.assertRaises(result["exception"]):
+
+ if "#exception" in result:
+ with self.assertRaises(result["#exception"], msg="#exception"):
tjob.run()
return
+
try:
tjob.run()
except exception.StopExtraction:
@@ -76,64 +90,85 @@ class TestExtractorResults(unittest.TestCase):
exc = str(exc)
if re.match(r"'5\d\d ", exc) or \
re.search(r"\bRead timed out\b", exc):
- self._skipped.append((url, exc))
+ self._skipped.append((result["#url"], exc))
self.skipTest(exc)
raise
- if result.get("archive", True):
+ if result.get("#archive", True):
self.assertEqual(
len(set(tjob.archive_list)),
len(tjob.archive_list),
- "archive-id uniqueness",
- )
+ msg="archive-id uniqueness")
if tjob.queue:
# test '_extractor' entries
for url, kwdict in zip(tjob.url_list, tjob.kwdict_list):
if "_extractor" in kwdict:
extr = kwdict["_extractor"].from_url(url)
- if extr is None and not result.get("extractor", True):
+ if extr is None and not result.get("#extractor", True):
continue
self.assertIsInstance(extr, kwdict["_extractor"])
self.assertEqual(extr.url, url)
else:
# test 'extension' entries
for kwdict in tjob.kwdict_list:
- self.assertIn("extension", kwdict)
+ self.assertIn("extension", kwdict, msg="#extension")
# test extraction results
- if "url" in result:
- self.assertEqual(result["url"], tjob.url_hash.hexdigest())
+ if "#sha1_url" in result:
+ self.assertEqual(
+ result["#sha1_url"],
+ tjob.url_hash.hexdigest(),
+ msg="#sha1_url")
- if "content" in result:
- expected = result["content"]
+ if "#sha1_content" in result:
+ expected = result["#sha1_content"]
digest = tjob.content_hash.hexdigest()
if isinstance(expected, str):
- self.assertEqual(digest, expected, "content")
- else: # assume iterable
- self.assertIn(digest, expected, "content")
-
- if "keyword" in result:
- expected = result["keyword"]
- if isinstance(expected, dict):
- for kwdict in tjob.kwdict_list:
- self._test_kwdict(kwdict, expected)
- else: # assume SHA1 hash
- self.assertEqual(expected, tjob.kwdict_hash.hexdigest())
-
- if "count" in result:
- count = result["count"]
+ self.assertEqual(expected, digest, msg="#sha1_content")
+ else: # iterable
+ self.assertIn(digest, expected, msg="#sha1_content")
+
+ if "#sha1_metadata" in result:
+ self.assertEqual(
+ result["#sha1_metadata"],
+ tjob.kwdict_hash.hexdigest(),
+ "#sha1_metadata")
+
+ if "#count" in result:
+ count = result["#count"]
+ len_urls = len(tjob.url_list)
if isinstance(count, str):
- self.assertRegex(count, r"^ *(==|!=|<|<=|>|>=) *\d+ *$")
- expr = "{} {}".format(len(tjob.url_list), count)
+ self.assertRegex(
+ count, r"^ *(==|!=|<|<=|>|>=) *\d+ *$", msg="#count")
+ expr = "{} {}".format(len_urls, count)
self.assertTrue(eval(expr), msg=expr)
+ elif isinstance(count, range):
+ self.assertRange(len_urls, count, msg="#count")
else: # assume integer
- self.assertEqual(len(tjob.url_list), count)
+ self.assertEqual(len_urls, count, msg="#count")
+
+ if "#pattern" in result:
+ self.assertGreater(len(tjob.url_list), 0, msg="#pattern")
+ pattern = result["#pattern"]
+ if isinstance(pattern, str):
+ for url in tjob.url_list:
+ self.assertRegex(url, pattern, msg="#pattern")
+ else:
+ for url, pat in zip(tjob.url_list, pattern):
+ self.assertRegex(url, pat, msg="#pattern")
- if "pattern" in result:
- self.assertGreater(len(tjob.url_list), 0)
- for url in tjob.url_list:
- self.assertRegex(url, result["pattern"])
+ if "#urls" in result:
+ expected = result["#urls"]
+ if isinstance(expected, str):
+ self.assertEqual(tjob.url_list[0], expected, msg="#urls")
+ else:
+ self.assertSequenceEqual(tjob.url_list, expected, msg="#urls")
+
+ metadata = {k: v for k, v in result.items() if k[0] != "#"}
+ if metadata:
+ for kwdict in tjob.kwdict_list:
+ self._test_kwdict(kwdict, metadata)
def _test_kwdict(self, kwdict, tests):
for key, test in tests.items():
@@ -141,13 +176,15 @@ class TestExtractorResults(unittest.TestCase):
key = key[1:]
if key not in kwdict:
continue
- self.assertIn(key, kwdict)
+ self.assertIn(key, kwdict, msg=key)
value = kwdict[key]
if isinstance(test, dict):
self._test_kwdict(value, test)
elif isinstance(test, type):
self.assertIsInstance(value, test, msg=key)
+ elif isinstance(test, range):
+ self.assertRange(value, test, msg=key)
elif isinstance(test, list):
subtest = False
for idx, item in enumerate(test):
@@ -188,6 +225,8 @@ class ResultJob(job.DownloadJob):
if content:
self.fileobj = TestPathfmt(self.content_hash)
+ else:
+ self._update_content = lambda url, kwdict: None
self.format_directory = TestFormatter(
"".join(self.extractor.directory_fmt)).format_map
@@ -195,6 +234,7 @@ class ResultJob(job.DownloadJob):
self.extractor.filename_fmt).format_map
def run(self):
+ self._init()
for msg in self.extractor:
self.dispatch(msg)
@@ -234,10 +274,17 @@ class ResultJob(job.DownloadJob):
self.archive_hash.update(archive_id.encode())
def _update_content(self, url, kwdict):
- if self.content:
- scheme = url.partition(":")[0]
- self.fileobj.kwdict = kwdict
- self.get_downloader(scheme).download(url, self.fileobj)
+ self.fileobj.kwdict = kwdict
+
+ downloader = self.get_downloader(url.partition(":")[0])
+ if downloader.download(url, self.fileobj):
+ return
+
+ for num, url in enumerate(kwdict.get("_fallback") or (), 1):
+ self.log.warning("Trying fallback URL #%d", num)
+ downloader = self.get_downloader(url.partition(":")[0])
+ if downloader.download(url, self.fileobj):
+ return
class TestPathfmt():
@@ -322,10 +369,11 @@ def setup_test_config():
config.set(("extractor", "mangoxo") , "username", "LiQiang3")
config.set(("extractor", "mangoxo") , "password", "5zbQF10_5u25259Ma")
- for category in ("danbooru", "atfbooru", "aibooru", "e621", "e926",
+ for category in ("danbooru", "atfbooru", "aibooru", "booruvar",
+ "e621", "e926", "e6ai",
"instagram", "twitter", "subscribestar", "deviantart",
"inkbunny", "tapas", "pillowfort", "mangadex",
- "vipergirls", "gfycat"):
+ "vipergirls"):
config.set(("extractor", category), "username", None)
config.set(("extractor", "mastodon.social"), "access-token",
@@ -351,39 +399,40 @@ def setup_test_config():
def generate_tests():
"""Dynamically generate extractor unittests"""
- def _generate_test(extr, tcase):
+ def _generate_method(result):
def test(self):
- url, result = tcase
- print("\n", url, sep="")
- self._run_test(extr, url, result)
+ print("\n" + result["#url"])
+ self._run_test(result)
return test
# enable selective testing for direct calls
- if __name__ == '__main__' and len(sys.argv) > 1:
- categories = sys.argv[1:]
- negate = False
- if categories[0].lower() == "all":
- categories = ()
- negate = True
- elif categories[0].lower() == "broken":
- categories = BROKEN
+ if __name__ == "__main__" and len(sys.argv) > 1:
+ category, _, subcategory = sys.argv[1].partition(":")
del sys.argv[1:]
+
+ if category.startswith("+"):
+ basecategory = category[1:].lower()
+ tests = [t for t in results.all()
+ if t["#category"][0].lower() == basecategory]
+ else:
+ tests = results.category(category)
+
+ if subcategory:
+ tests = [t for t in tests if t["#category"][-1] == subcategory]
else:
- categories = BROKEN
- negate = True
- if categories:
- print("skipping:", ", ".join(categories))
- fltr = util.build_extractor_filter(categories, negate=negate)
+ tests = results.all()
# add 'test_...' methods
- for extr in filter(fltr, extractor.extractors()):
- name = "test_" + extr.__name__ + "_"
- for num, tcase in enumerate(extr._get_tests(), 1):
- test = _generate_test(extr, tcase)
- test.__name__ = name + str(num)
- setattr(TestExtractorResults, test.__name__, test)
+ enum = collections.defaultdict(int)
+ for result in tests:
+ name = "{1}_{2}".format(*result["#category"])
+ enum[name] += 1
+
+ method = _generate_method(result)
+ method.__name__ = "test_{}_{}".format(name, enum[name])
+ setattr(TestExtractorResults, method.__name__, method)
generate_tests()
-if __name__ == '__main__':
- unittest.main(warnings='ignore')
+if __name__ == "__main__":
+ unittest.main(warnings="ignore")