summary refs log tree commit diff stats
path: root/gallery_dl/__init__.py
diff options
context:
space:
mode:
Diffstat (limited to 'gallery_dl/__init__.py')
-rw-r--r-- gallery_dl/__init__.py 93
1 files changed, 74 insertions, 19 deletions
diff --git a/gallery_dl/__init__.py b/gallery_dl/__init__.py
index 245dbf8..116ca5d 100644
--- a/gallery_dl/__init__.py
+++ b/gallery_dl/__init__.py
@@ -11,7 +11,7 @@ import logging
from . import version, config, option, output, extractor, job, util, exception
__author__ = "Mike Fährmann"
-__copyright__ = "Copyright 2014-2022 Mike Fährmann"
+__copyright__ = "Copyright 2014-2023 Mike Fährmann"
__license__ = "GPLv2"
__maintainer__ = "Mike Fährmann"
__email__ = "mike_faehrmann@web.de"
@@ -33,20 +33,24 @@ def progress(urls, pformat):
def main():
try:
- if sys.stdout and sys.stdout.encoding.lower() != "utf-8":
- output.replace_std_streams()
-
parser = option.build_parser()
args = parser.parse_args()
log = output.initialize_logging(args.loglevel)
# configuration
- if args.load_config:
+ if args.config_load:
config.load()
- if args.cfgfiles:
- config.load(args.cfgfiles, strict=True)
- if args.yamlfiles:
- config.load(args.yamlfiles, strict=True, fmt="yaml")
+ if args.configs_json:
+ config.load(args.configs_json, strict=True)
+ if args.configs_yaml:
+ import yaml
+ config.load(args.configs_yaml, strict=True, load=yaml.safe_load)
+ if args.configs_toml:
+ try:
+ import tomllib as toml
+ except ImportError:
+ import toml
+ config.load(args.configs_toml, strict=True, load=toml.loads)
if args.filename:
filename = args.filename
if filename == "/O":
@@ -77,6 +81,8 @@ def main():
for opts in args.options:
config.set(*opts)
+ output.configure_standard_streams()
+
# signals
signals = config.get((), "signals-ignore")
if signals:
@@ -105,20 +111,17 @@ def main():
output.ANSI = True
- # extractor modules
- modules = config.get(("extractor",), "modules")
- if modules is not None:
- if isinstance(modules, str):
- modules = modules.split(",")
- extractor.modules = modules
- extractor._module_iter = iter(modules)
-
# format string separator
separator = config.get((), "format-separator")
if separator:
from . import formatter
formatter._SEPARATOR = separator
+ # eval globals
+ path = config.get((), "globals")
+ if path:
+ util.GLOBALS = util.import_file(path).__dict__
+
# loglevels
output.configure_logging(args.loglevel)
if args.loglevel >= logging.ERROR:
@@ -128,7 +131,7 @@ def main():
import requests
extra = ""
- if getattr(sys, "frozen", False):
+ if util.EXECUTABLE:
extra = " - Executable"
else:
git_head = util.git_head()
@@ -147,6 +150,44 @@ def main():
log.debug("Configuration Files %s", config._files)
+ # extractor modules
+ modules = config.get(("extractor",), "modules")
+ if modules is not None:
+ if isinstance(modules, str):
+ modules = modules.split(",")
+ extractor.modules = modules
+
+ # external modules
+ if args.extractor_sources:
+ sources = args.extractor_sources
+ sources.append(None)
+ else:
+ sources = config.get(("extractor",), "module-sources")
+
+ if sources:
+ import os
+ modules = []
+
+ for source in sources:
+ if source:
+ path = util.expand_path(source)
+ try:
+ files = os.listdir(path)
+ modules.append(extractor._modules_path(path, files))
+ except Exception as exc:
+ log.warning("Unable to load modules from %s (%s: %s)",
+ path, exc.__class__.__name__, exc)
+ else:
+ modules.append(extractor._modules_internal())
+
+ if len(modules) > 1:
+ import itertools
+ extractor._module_iter = itertools.chain(*modules)
+ elif not modules:
+ extractor._module_iter = ()
+ else:
+ extractor._module_iter = iter(modules[0])
+
if args.list_modules:
extractor.modules.append("")
sys.stdout.write("\n".join(extractor.modules))
@@ -177,6 +218,10 @@ def main():
"Deleted %d %s from '%s'",
cnt, "entry" if cnt == 1 else "entries", cache._path(),
)
+
+ elif args.config_init:
+ return config.initialize()
+
else:
if not args.urls and not args.inputfiles:
parser.error(
@@ -220,9 +265,13 @@ def main():
pformat = config.get(("output",), "progress", True)
if pformat and len(urls) > 1 and args.loglevel < logging.ERROR:
urls = progress(urls, pformat)
+ else:
+ urls = iter(urls)
retval = 0
- for url in urls:
+ url = next(urls, None)
+
+ while url is not None:
try:
log.debug("Starting %s for '%s'", jobtype.__name__, url)
if isinstance(url, util.ExtendedUrl):
@@ -234,9 +283,15 @@ def main():
retval |= jobtype(url).run()
except exception.TerminateExtraction:
pass
+ except exception.RestartExtraction:
+ log.debug("Restarting '%s'", url)
+ continue
except exception.NoExtractorError:
log.error("Unsupported URL '%s'", url)
retval |= 64
+
+ url = next(urls, None)
+
return retval
except KeyboardInterrupt: