about | summary | refs | log | tree | commit | diff | stats
path: root/gallery_dl/extractor/civitai.py
diff options
context:
space:
mode:
author: Unit 193 <unit193@unit193.net> 2024-09-28 20:01:31 -0400
committer: Unit 193 <unit193@unit193.net> 2024-09-28 20:01:31 -0400
commit: 5a7a504aeab6c6de4b76e7f69a6dae5308bd6c0a (patch)
tree: 91d6f25b3b51e7acbd1f61278709eea080f4aa78 /gallery_dl/extractor/civitai.py
parent: 817f6c044037816056a8aadf6f95394f851df732 (diff)
parent: 1a457ed68769880ab7760e0746f0cbbd9ca00487 (diff)
Update upstream source from tag 'upstream/1.27.5'
Update to upstream version '1.27.5' with Debian dir 2ce58072136364618b4bed910cf33746af25cbd7
Diffstat (limited to 'gallery_dl/extractor/civitai.py')
-rw-r--r-- gallery_dl/extractor/civitai.py 490
1 files changed, 490 insertions, 0 deletions
diff --git a/gallery_dl/extractor/civitai.py b/gallery_dl/extractor/civitai.py
new file mode 100644
index 0000000..3e657d6
--- /dev/null
+++ b/gallery_dl/extractor/civitai.py
@@ -0,0 +1,490 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2024 Mike Fährmann
+#
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License version 2 as
+# published by the Free Software Foundation.
+
+"""Extractors for https://www.civitai.com/"""
+
+from .common import Extractor, Message
+from .. import text, util
+import itertools
+import time
+
# Regex prefix shared by every civitai URL pattern; the scheme is optional.
BASE_PATTERN = r"(?:https?://)?civitai\.com"
# Regex prefix of user-profile URLs; its capture group is the username.
USER_PATTERN = BASE_PATTERN + r"/user/([^/?#]+)"
+
+
class CivitaiExtractor(Extractor):
    """Common functionality shared by all civitai extractors"""
    category = "civitai"
    root = "https://civitai.com"
    directory_fmt = ("{category}", "{username|user[username]}", "images")
    filename_fmt = "{id}.{extension}"
    archive_fmt = "{hash}"
    request_interval = (0.5, 1.5)

    def _init(self):
        """Select the API backend and the image quality settings"""
        use_trpc = (self.config("api") == "trpc")
        self.log.debug("Using tRPC API" if use_trpc else "Using REST API")
        api_class = CivitaiTrpcAPI if use_trpc else CivitaiRestAPI
        self.api = api_class(self)

        # an unset or empty 'quality' option selects the original image
        quality = self.config("quality") or "original=true"
        if not isinstance(quality, str):
            # a non-string value is treated as a sequence of quality options
            quality = ",".join(quality)
        self._image_quality = quality
        self._image_ext = "png" if quality == "original=true" else "jpg"

    def items(self):
        """Yield queue messages for models, or else download messages
        for images

        When models() returns something truthy, every model is queued
        for CivitaiModelExtractor and images() is never consulted.
        """
        models = self.models()
        if models:
            queue_data = {"_extractor": CivitaiModelExtractor}
            for model in models:
                model_url = "{}/models/{}".format(self.root, model["id"])
                yield Message.Queue, model_url, queue_data
            return

        for image in self.images() or ():
            url = self._url(image)
            image["date"] = text.parse_datetime(
                image["createdAt"], "%Y-%m-%dT%H:%M:%S.%fZ")
            text.nameext_from_url(url, image)
            # set after nameext_from_url() so the configured extension wins
            image["extension"] = self._image_ext
            yield Message.Directory, image
            yield Message.Url, url, image

    def models(self):
        """Return the models to queue; none by default"""
        return ()

    def images(self):
        """Return the images to download; none by default"""
        return ()

    def _url(self, image):
        """Build the download URL of 'image' for the selected quality"""
        url = image["url"]

        if "/" in url:
            # replace the second-to-last path segment with the quality
            segments = url.rsplit("/", 2)
            segments[1] = self._image_quality
            return "/".join(segments)

        # 'url' contains no slash here and is used as a single path segment
        name = image.get("name")
        if not name:
            mime = image.get("mimeType") or self._image_ext
            subtype = mime.rpartition("/")[2]
            name = "{}.{}".format(image.get("id"), subtype)

        template = "https://image.civitai.com/xG1nkqKTMzGDvpLrqFT7WA/{}/{}/{}"
        return template.format(url, self._image_quality, name)
+
+
class CivitaiModelExtractor(CivitaiExtractor):
    """Extractor for the versions, files, and images of a civitai model"""
    subcategory = "model"
    directory_fmt = ("{category}", "{user[username]}",
                     "{model[id]}{model[name]:? //}",
                     "{version[id]}{version[name]:? //}")
    filename_fmt = "{filename}.{extension}"
    archive_fmt = "{file[hash]}"
    pattern = BASE_PATTERN + r"/models/(\d+)(?:/?\?modelVersionId=(\d+))?"
    example = "https://civitai.com/models/12345/TITLE"

    def items(self):
        """Yield a directory message and file messages per model version"""
        model_id, version_id = self.groups
        model = self.api.model(model_id)

        # the uploader is stored under either "user" or "creator";
        # detach it and the version list from the model metadata
        if "user" in model:
            user = model.pop("user")
        else:
            user = model.pop("creator")
        versions = model.pop("modelVersions")

        if version_id:
            # restrict the download to the version named in the URL
            version_id = int(version_id)
            for version in versions:
                if version["id"] == version_id:
                    break
            else:
                # not part of the model data; request it separately
                version = self.api.model_version(version_id)
            versions = (version,)

        for version in versions:
            version["date"] = text.parse_datetime(
                version["createdAt"], "%Y-%m-%dT%H:%M:%S.%fZ")

            data = {
                "model"  : model,
                "version": version,
                "user"   : user,
            }

            yield Message.Directory, data
            for file in self._extract_files(model, version, user):
                file.update(data)
                yield Message.Url, file["url"], file

    def _extract_files(self, model, version, user):
        """Return an iterable of file dicts for the configured 'files' types

        The 'files' option may be a comma-separated string or a sequence
        of type names ("model", "image", "gallery", with or without a
        trailing "s"). Without it, only images are returned. Whitespace
        around a name is ignored, empty entries are skipped, and an
        unknown name is reported with a warning instead of raising.
        """
        filetypes = self.config("files")
        if filetypes is None:
            return self._extract_files_image(model, version, user)

        generators = {
            "model"   : self._extract_files_model,
            "image"   : self._extract_files_image,
            "gallery" : self._extract_files_gallery,
            # "galleries" with its trailing "s" stripped
            "gallerie": self._extract_files_gallery,
        }
        if isinstance(filetypes, str):
            filetypes = filetypes.split(",")

        selected = []
        for ft in filetypes:
            key = ft.strip().rstrip("s")
            if not key:
                # e.g. the trailing comma in "model,"
                continue
            generator = generators.get(key)
            if generator is None:
                self.log.warning("Unknown file type '%s'", ft.strip())
                continue
            selected.append(generator)

        # each generator is only invoked when the chain reaches it
        return itertools.chain.from_iterable(
            generator(model, version, user)
            for generator in selected
        )

    def _extract_files_model(self, model, version, user):
        """Return one file dict per entry in version["files"]"""
        return [
            {
                "num"      : num,
                "file"     : file,
                "filename" : file["name"],
                "extension": "bin",
                "url"      : file["downloadUrl"],
                "_http_headers" : {
                    "Authorization": self.api.headers.get("Authorization")},
                "_http_validate": self._validate_file_model,
            }
            for num, file in enumerate(version["files"], 1)
        ]

    def _extract_files_image(self, model, version, user):
        """Return one file dict per image of 'version'

        Uses version["images"] when present and otherwise requests the
        images of this version from the API.
        """
        if "images" in version:
            images = version["images"]
        else:
            params = {
                "modelVersionId": version["id"],
                "prioritizedUserIds": [user["id"]],
                "period": "AllTime",
                "sort": "Most Reactions",
                "limit": 20,
                "pending": True,
            }
            images = self.api.images(params, defaults=False)

        return [
            text.nameext_from_url(file["url"], {
                "num" : num,
                "file": file,
                "url" : self._url(file),
            })
            for num, file in enumerate(images, 1)
        ]

    def _extract_files_gallery(self, model, version, user):
        """Yield one file dict per gallery image of 'version'"""
        images = self.api.images_gallery(model, version, user)
        for num, file in enumerate(images, 1):
            yield text.nameext_from_url(file["url"], {
                "num" : num,
                "file": file,
                "url" : self._url(file),
            })

    def _validate_file_model(self, response):
        """Return False and log a warning if 'response' is an HTML page

        Passed along as "_http_validate" with every model file.
        """
        if response.headers.get("Content-Type", "").startswith("text/html"):
            alert = text.extr(
                response.text, 'mantine-Alert-message">', "</div></div></div>")
            if alert:
                msg = "\"{}\" - 'api-key' required".format(
                    text.remove_html(alert))
            else:
                msg = "'api-key' required to download this file"
            self.log.warning(msg)
            return False
        return True
+
+
class CivitaiImageExtractor(CivitaiExtractor):
    """Extractor for a single civitai image"""
    subcategory = "image"
    pattern = BASE_PATTERN + r"/images/(\d+)"
    example = "https://civitai.com/images/12345"

    def images(self):
        """Return the image whose ID was captured from the URL"""
        image_id = self.groups[0]
        return self.api.image(image_id)
+
+
class CivitaiTagModelsExtractor(CivitaiExtractor):
    """Extractor for the models listed under a civitai tag"""
    subcategory = "tag-models"
    pattern = BASE_PATTERN + r"/(?:tag/|models\?tag=)([^/?&#]+)"
    example = "https://civitai.com/tag/TAG"

    def models(self):
        """Return the models matching the tag captured from the URL"""
        params = {"tag": text.unquote(self.groups[0])}
        return self.api.models(params)
+
+
class CivitaiTagImagesExtractor(CivitaiExtractor):
    """Extractor for the images listed under a civitai tag"""
    subcategory = "tag-images"
    pattern = BASE_PATTERN + r"/images\?tags=([^&#]+)"
    example = "https://civitai.com/images?tags=12345"

    def images(self):
        """Return the images matching the tag captured from the URL"""
        params = {"tag": text.unquote(self.groups[0])}
        return self.api.images(params)
+
+
class CivitaiSearchExtractor(CivitaiExtractor):
    """Extractor for civitai model search results"""
    subcategory = "search"
    pattern = BASE_PATTERN + r"/search/models\?([^#]+)"
    example = "https://civitai.com/search/models?query=QUERY"

    def models(self):
        """Return the models for the URL's query string parameters"""
        query = self.groups[0]
        return self.api.models(text.parse_query(query))
+
+
class CivitaiUserExtractor(CivitaiExtractor):
    """Extractor for a civitai user profile

    Only dispatches to the user-models and user-images extractors.
    """
    subcategory = "user"
    pattern = USER_PATTERN + r"/?(?:$|\?|#)"
    example = "https://civitai.com/user/USER"

    def initialize(self):
        """Intentionally do nothing"""
        # NOTE(review): presumably overridden because this extractor makes
        # no API requests of its own -- confirm against Extractor.initialize

    def items(self):
        """Dispatch to the extractors for this user's models and images"""
        user_url = "{}/user/{}/".format(self.root, self.groups[0])
        targets = (
            (CivitaiUserModelsExtractor, user_url + "models"),
            (CivitaiUserImagesExtractor, user_url + "images"),
        )
        default = ("user-models", "user-images")
        return self._dispatch_extractors(targets, default)
+
+
class CivitaiUserModelsExtractor(CivitaiExtractor):
    """Extractor for the models of a civitai user"""
    subcategory = "user-models"
    pattern = USER_PATTERN + r"/models/?(?:\?([^#]+))?"
    example = "https://civitai.com/user/USER/models"

    def models(self):
        """Return the user's models, filtered by the URL's query string"""
        username, query = self.groups[0], self.groups[1]
        params = text.parse_query(query)
        # the username from the URL path overrides any 'username' query value
        params["username"] = text.unquote(username)
        return self.api.models(params)
+
+
class CivitaiUserImagesExtractor(CivitaiExtractor):
    """Extractor for the images of a civitai user"""
    subcategory = "user-images"
    pattern = USER_PATTERN + r"/images/?(?:\?([^#]+))?"
    example = "https://civitai.com/user/USER/images"

    def images(self):
        """Return the user's images, filtered by the URL's query string"""
        username, query = self.groups[0], self.groups[1]
        params = text.parse_query(query)
        # the username from the URL path overrides any 'username' query value
        params["username"] = text.unquote(username)
        return self.api.images(params)
+
+
class CivitaiRestAPI():
    """Interface for the Civitai Public REST API

    https://developer.civitai.com/docs/api/public-rest
    """

    def __init__(self, extractor):
        """Read the 'api-key' and 'nsfw' options from 'extractor'"""
        self.extractor = extractor
        self.root = extractor.root + "/api"
        self.headers = {"Content-Type": "application/json"}

        api_key = extractor.config("api-key")
        if api_key:
            extractor.log.debug("Using api_key authentication")
            self.headers["Authorization"] = "Bearer " + api_key

        # None/True -> "X", other falsy values -> "Safe";
        # any other value is used unchanged
        nsfw = extractor.config("nsfw")
        if nsfw is None or nsfw is True:
            nsfw = "X"
        elif not nsfw:
            nsfw = "Safe"
        self.nsfw = nsfw

    def image(self, image_id):
        """Return an iterator over the results for a single image ID"""
        return self.images({
            "imageId": image_id,
        })

    def images(self, params, defaults=True):
        """Return an iterator over all images matching 'params'

        With 'defaults' enabled, the configured 'nsfw' value is added
        unless 'params' already has one. The 'defaults' parameter keeps
        this method call-compatible with CivitaiTrpcAPI.images(), which
        callers invoke as images(params, defaults=False).
        The caller's 'params' dict is never modified.
        """
        endpoint = "/v1/images"
        if defaults and "nsfw" not in params:
            params = dict(params)
            params["nsfw"] = self.nsfw
        return self._pagination(endpoint, params)

    def images_gallery(self, model, version, user):
        """Return an iterator over the images of a model version"""
        return self.images({
            "modelId"       : model["id"],
            "modelVersionId": version["id"],
        })

    def model(self, model_id):
        """Return the decoded JSON response for a model"""
        endpoint = "/v1/models/{}".format(model_id)
        return self._call(endpoint)

    def model_version(self, model_version_id):
        """Return the decoded JSON response for a model version"""
        endpoint = "/v1/model-versions/{}".format(model_version_id)
        return self._call(endpoint)

    def models(self, params):
        """Return an iterator over all models matching 'params'"""
        return self._pagination("/v1/models", params)

    def _call(self, endpoint, params=None):
        """Request 'endpoint' and return its decoded JSON response

        'endpoint' is either a path below the API root (leading "/")
        or a complete URL, as used for "nextPage" links.
        """
        if endpoint.startswith("/"):
            url = self.root + endpoint
        else:
            url = endpoint

        response = self.extractor.request(
            url, params=params, headers=self.headers)
        return response.json()

    def _pagination(self, endpoint, params):
        """Yield the items of every result page, following "nextPage" """
        while True:
            data = self._call(endpoint, params)
            yield from data["items"]

            try:
                endpoint = data["metadata"]["nextPage"]
            except KeyError:
                return
            # follow-up requests send no separate parameters
            params = None
+
+
class CivitaiTrpcAPI():
    """Interface for the Civitai TRPC API"""

    def __init__(self, extractor):
        """Read the 'api-key' and 'nsfw' options from 'extractor'"""
        self.extractor = extractor
        self.root = extractor.root + "/api/trpc/"
        self.headers = {
            "content-type"    : "application/json",
            "x-client-version": "5.0.94",
            "x-client-date"   : "",  # filled in by _call()
            "x-client"        : "web",
            "x-fingerprint"   : "undefined",
        }

        key = extractor.config("api-key")
        if key:
            extractor.log.debug("Using api_key authentication")
            self.headers["Authorization"] = "Bearer " + key

        # None/True -> 31, other falsy values -> 1;
        # any other value is used unchanged
        level = extractor.config("nsfw")
        if level is None or level is True:
            level = 31
        elif not level:
            level = 1
        self.nsfw = level

    def image(self, image_id):
        """Return a 1-tuple holding the data of a single image"""
        result = self._call("image.get", {"id": int(image_id)})
        return (result,)

    def images(self, params, defaults=True):
        """Return an iterator over all images matching 'params'

        Unless 'defaults' is disabled, 'params' is layered on top of a
        set of default parameters; otherwise it is sent as is.
        """
        endpoint = "image.getInfinite"
        if not defaults:
            return self._pagination(endpoint, params)

        merged = {
            "useIndex"     : True,
            "period"       : "AllTime",
            "sort"         : "Newest",
            "types"        : ["image"],
            "withMeta"     : False,  # Metadata Only
            "fromPlatform" : False,  # Made On-Site
            "browsingLevel": self.nsfw,
            "include"      : ["cosmetics"],
        }
        merged.update(params)
        return self._pagination(endpoint, merged)

    def images_gallery(self, model, version, user):
        """Yield the images of every gallery post of a model version"""
        query = {
            "period"        : "AllTime",
            "sort"          : "Newest",
            "modelVersionId": version["id"],
            "modelId"       : model["id"],
            "hidden"        : False,
            "limit"         : 50,
            "browsingLevel" : self.nsfw,
        }
        posts = self._pagination("image.getImagesAsPostsInfinite", query)
        for post in posts:
            yield from post["images"]

    def model(self, model_id):
        """Return the data of a model"""
        return self._call("model.getById", {"id": int(model_id)})

    def model_version(self, model_version_id):
        """Return the data of a model version"""
        return self._call(
            "modelVersion.getById", {"id": int(model_version_id)})

    def models(self, params, defaults=True):
        """Return an iterator over all models matching 'params'

        Unless 'defaults' is disabled, 'params' is layered on top of a
        set of default parameters; otherwise it is sent as is.
        """
        endpoint = "model.getAll"
        if not defaults:
            return self._pagination(endpoint, params)

        merged = {
            "period"            : "AllTime",
            "periodMode"        : "published",
            "sort"              : "Newest",
            "pending"           : False,
            "hidden"            : False,
            "followed"          : False,
            "earlyAccess"       : False,
            "fromPlatform"      : False,
            "supportsGeneration": False,
            "browsingLevel"     : self.nsfw,
        }
        merged.update(params)
        return self._pagination(endpoint, merged)

    def user(self, username):
        """Return a 1-tuple holding the creator data of 'username'"""
        result = self._call("user.getCreator", {"username": username})
        return (result,)

    def _call(self, endpoint, params):
        """Request 'endpoint' and return the 'json' payload of its result

        'params' is JSON-encoded and sent as the 'input' query parameter.
        """
        query = {"input": util.json_dumps({"json": params})}

        # the shared header dict gets a fresh millisecond timestamp
        # before every request
        self.headers["x-client-date"] = str(int(time.time() * 1000))

        response = self.extractor.request(
            self.root + endpoint, headers=self.headers, params=query)
        return response.json()["result"]["data"]["json"]

    def _pagination(self, endpoint, params):
        """Yield the items of every result page

        The cursor of the next page is written into 'params' itself.
        """
        while True:
            page = self._call(endpoint, params)
            yield from page["items"]

            # a missing or empty cursor marks the last page
            cursor = page.get("nextCursor")
            if not cursor:
                return
            params["cursor"] = cursor