diff options
Diffstat (limited to 'gallery_dl/archive.py')
| -rw-r--r-- | gallery_dl/archive.py | 189 |
1 files changed, 165 insertions, 24 deletions
diff --git a/gallery_dl/archive.py b/gallery_dl/archive.py index 5f05bbf..edecb10 100644 --- a/gallery_dl/archive.py +++ b/gallery_dl/archive.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -# Copyright 2024 Mike Fährmann +# Copyright 2024-2025 Mike Fährmann # # This program is free software; you can redistribute it and/or modify # it under the terms of the GNU General Public License version 2 as @@ -9,50 +9,94 @@ """Download Archives""" import os -import sqlite3 -from . import formatter +import logging +from . import util, formatter + +log = logging.getLogger("archive") + + +def connect(path, prefix, format, + table=None, mode=None, pragma=None, kwdict=None, cache_key=None): + keygen = formatter.parse(prefix + format).format_map + + if isinstance(path, str) and path.startswith( + ("postgres://", "postgresql://")): + if mode == "memory": + cls = DownloadArchivePostgresqlMemory + else: + cls = DownloadArchivePostgresql + else: + path = util.expand_path(path) + if kwdict is not None and "{" in path: + path = formatter.parse(path).format_map(kwdict) + if mode == "memory": + cls = DownloadArchiveMemory + else: + cls = DownloadArchive + + if kwdict is not None and table: + table = formatter.parse(table).format_map(kwdict) + + return cls(path, keygen, table, pragma, cache_key) + + +def sanitize(name): + return '"' + name.replace('"', "_") + '"' class DownloadArchive(): + _sqlite3 = None + + def __init__(self, path, keygen, table=None, pragma=None, cache_key=None): + if self._sqlite3 is None: + DownloadArchive._sqlite3 = __import__("sqlite3") - def __init__(self, path, format_string, pragma=None, - cache_key="_archive_key"): try: - con = sqlite3.connect(path, timeout=60, check_same_thread=False) - except sqlite3.OperationalError: + con = self._sqlite3.connect( + path, timeout=60, check_same_thread=False) + except self._sqlite3.OperationalError: os.makedirs(os.path.dirname(path)) - con = sqlite3.connect(path, timeout=60, check_same_thread=False) + con = self._sqlite3.connect( + path, timeout=60, check_same_thread=False) con.isolation_level = None - self.keygen = formatter.parse(format_string).format_map + self.keygen = keygen self.connection = con self.close = con.close self.cursor = cursor = con.cursor() - self._cache_key = cache_key + self._cache_key = cache_key or "_archive_key" + + table = "archive" if table is None else sanitize(table) + self._stmt_select = ( + "SELECT 1 " + "FROM " + table + " " + "WHERE entry=? " + "LIMIT 1") + self._stmt_insert = ( + "INSERT OR IGNORE INTO " + table + " " + "(entry) VALUES (?)") if pragma: for stmt in pragma: cursor.execute("PRAGMA " + stmt) try: - cursor.execute("CREATE TABLE IF NOT EXISTS archive " + cursor.execute("CREATE TABLE IF NOT EXISTS " + table + " " "(entry TEXT PRIMARY KEY) WITHOUT ROWID") - except sqlite3.OperationalError: + except self._sqlite3.OperationalError: # fallback for missing WITHOUT ROWID support (#553) - cursor.execute("CREATE TABLE IF NOT EXISTS archive " + cursor.execute("CREATE TABLE IF NOT EXISTS " + table + " " "(entry TEXT PRIMARY KEY)") def add(self, kwdict): """Add item described by 'kwdict' to archive""" key = kwdict.get(self._cache_key) or self.keygen(kwdict) - self.cursor.execute( - "INSERT OR IGNORE INTO archive (entry) VALUES (?)", (key,)) + self.cursor.execute(self._stmt_insert, (key,)) def check(self, kwdict): """Return True if the item described by 'kwdict' exists in archive""" key = kwdict[self._cache_key] = self.keygen(kwdict) - self.cursor.execute( - "SELECT 1 FROM archive WHERE entry=? LIMIT 1", (key,)) + self.cursor.execute(self._stmt_select, (key,)) return self.cursor.fetchone() def finalize(self): @@ -61,9 +105,9 @@ class DownloadArchive(): class DownloadArchiveMemory(DownloadArchive): - def __init__(self, path, format_string, pragma=None, - cache_key="_archive_key"): - DownloadArchive.__init__(self, path, format_string, pragma, cache_key) + def __init__(self, path, keygen, table=None, pragma=None, cache_key=None): + DownloadArchive.__init__( + self, path, keygen, table, pragma, cache_key) self.keys = set() def add(self, kwdict): @@ -75,8 +119,7 @@ class DownloadArchiveMemory(DownloadArchive): key = kwdict[self._cache_key] = self.keygen(kwdict) if key in self.keys: return True - self.cursor.execute( - "SELECT 1 FROM archive WHERE entry=? LIMIT 1", (key,)) + self.cursor.execute(self._stmt_select, (key,)) return self.cursor.fetchone() def finalize(self): @@ -87,12 +130,110 @@ class DownloadArchiveMemory(DownloadArchive): with self.connection: try: cursor.execute("BEGIN") - except sqlite3.OperationalError: + except self._sqlite3.OperationalError: pass - stmt = "INSERT OR IGNORE INTO archive (entry) VALUES (?)" + stmt = self._stmt_insert if len(self.keys) < 100: for key in self.keys: cursor.execute(stmt, (key,)) else: cursor.executemany(stmt, ((key,) for key in self.keys)) + + +class DownloadArchivePostgresql(): + _psycopg = None + + def __init__(self, uri, keygen, table=None, pragma=None, cache_key=None): + if self._psycopg is None: + DownloadArchivePostgresql._psycopg = __import__("psycopg") + + self.connection = con = self._psycopg.connect(uri) + self.cursor = cursor = con.cursor() + self.close = con.close + self.keygen = keygen + self._cache_key = cache_key or "_archive_key" + + table = "archive" if table is None else sanitize(table) + self._stmt_select = ( + "SELECT true " + "FROM " + table + " " + "WHERE entry=%s " + "LIMIT 1") + self._stmt_insert = ( + "INSERT INTO " + table + " (entry) " + "VALUES (%s) " + "ON CONFLICT DO NOTHING") + + try: + cursor.execute("CREATE TABLE IF NOT EXISTS " + table + " " + "(entry TEXT PRIMARY KEY)") + con.commit() + except Exception as exc: + log.error("%s: %s when creating '%s' table: %s", + con, exc.__class__.__name__, table, exc) + con.rollback() + raise + + def add(self, kwdict): + key = kwdict.get(self._cache_key) or self.keygen(kwdict) + try: + self.cursor.execute(self._stmt_insert, (key,)) + self.connection.commit() + except Exception as exc: + log.error("%s: %s when writing entry: %s", + self.connection, exc.__class__.__name__, exc) + self.connection.rollback() + + def check(self, kwdict): + key = kwdict[self._cache_key] = self.keygen(kwdict) + try: + self.cursor.execute(self._stmt_select, (key,)) + return self.cursor.fetchone() + except Exception as exc: + log.error("%s: %s when checking entry: %s", + self.connection, exc.__class__.__name__, exc) + self.connection.rollback() + return False + + def finalize(self): + pass + + +class DownloadArchivePostgresqlMemory(DownloadArchivePostgresql): + + def __init__(self, path, keygen, table=None, pragma=None, cache_key=None): + DownloadArchivePostgresql.__init__( + self, path, keygen, table, pragma, cache_key) + self.keys = set() + + def add(self, kwdict): + self.keys.add( + kwdict.get(self._cache_key) or + self.keygen(kwdict)) + + def check(self, kwdict): + key = kwdict[self._cache_key] = self.keygen(kwdict) + if key in self.keys: + return True + try: + self.cursor.execute(self._stmt_select, (key,)) + return self.cursor.fetchone() + except Exception as exc: + log.error("%s: %s when checking entry: %s", + self.connection, exc.__class__.__name__, exc) + self.connection.rollback() + return False + + def finalize(self): + if not self.keys: + return + try: + self.cursor.executemany( + self._stmt_insert, + ((key,) for key in self.keys)) + self.connection.commit() + except Exception as exc: + log.error("%s: %s when writing entries: %s", + self.connection, exc.__class__.__name__, exc) + self.connection.rollback() |
