diff options
Diffstat (limited to 'gallery_dl/extractor/reddit.py')
| -rw-r--r-- | gallery_dl/extractor/reddit.py | 313 |
1 files changed, 313 insertions, 0 deletions
diff --git a/gallery_dl/extractor/reddit.py b/gallery_dl/extractor/reddit.py new file mode 100644 index 0000000..0c5a924 --- /dev/null +++ b/gallery_dl/extractor/reddit.py @@ -0,0 +1,313 @@ +# -*- coding: utf-8 -*- + +# Copyright 2017-2019 Mike Fährmann +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License version 2 as +# published by the Free Software Foundation. + +"""Extract images from subreddits at https://www.reddit.com/""" + +from .common import Extractor, Message +from .. import text, util, extractor, exception +from ..cache import cache +import datetime +import time + + +class RedditExtractor(Extractor): + """Base class for reddit extractors""" + category = "reddit" + + def __init__(self, match): + Extractor.__init__(self, match) + self.api = RedditAPI(self) + self.max_depth = int(self.config("recursion", 0)) + self._visited = set() + + def items(self): + subre = RedditSubmissionExtractor.pattern + submissions = self.submissions() + depth = 0 + + yield Message.Version, 1 + with extractor.blacklist( + util.SPECIAL_EXTRACTORS, [RedditSubredditExtractor]): + while True: + extra = [] + for url, data in self._urls(submissions): + if url[0] == "#": + continue + if url[0] == "/": + url = "https://www.reddit.com" + url + + match = subre.match(url) + if match: + extra.append(match.group(1)) + else: + yield Message.Queue, text.unescape(url), data + + if not extra or depth == self.max_depth: + return + depth += 1 + submissions = ( + self.api.submission(sid) for sid in extra + if sid not in self._visited + ) + + def submissions(self): + """Return an iterable containing all (submission, comments) tuples""" + + def _urls(self, submissions): + for submission, comments in submissions: + self._visited.add(submission["id"]) + + if not submission["is_self"]: + yield submission["url"], submission + + for url in text.extract_iter( + submission["selftext_html"] or "", ' href="', '"'): + yield url, submission + + for comment in comments: + for url in text.extract_iter( + comment["body_html"] or "", ' href="', '"'): + yield url, comment + + +class RedditSubredditExtractor(RedditExtractor): + """Extractor for images from subreddits on reddit.com""" + subcategory = "subreddit" + pattern = (r"(?:https?://)?(?:\w+\.)?reddit\.com/r/([^/?&#]+)" + r"(/[a-z]+)?/?" + r"(?:\?.*?(?:\bt=([a-z]+))?)?$") + test = ( + ("https://www.reddit.com/r/lavaporn/"), + ("https://www.reddit.com/r/lavaporn/top/?sort=top&t=month"), + ("https://old.reddit.com/r/lavaporn/"), + ("https://np.reddit.com/r/lavaporn/"), + ("https://m.reddit.com/r/lavaporn/"), + ) + + def __init__(self, match): + RedditExtractor.__init__(self, match) + self.subreddit, self.order, self.timeframe = match.groups() + + def submissions(self): + subreddit = self.subreddit + (self.order or "") + params = {"t": self.timeframe} if self.timeframe else {} + return self.api.submissions_subreddit(subreddit, params) + + +class RedditSubmissionExtractor(RedditExtractor): + """Extractor for images from a submission on reddit.com""" + subcategory = "submission" + pattern = (r"(?:https?://)?(?:" + r"(?:\w+\.)?reddit\.com/r/[^/?&#]+/comments|" + r"redd\.it" + r")/([a-z0-9]+)") + test = ( + ("https://www.reddit.com/r/lavaporn/comments/2a00np/", { + "pattern": r"https?://i\.imgur\.com/AaAUCgy\.jpg", + }), + ("https://old.reddit.com/r/lavaporn/comments/2a00np/"), + ("https://np.reddit.com/r/lavaporn/comments/2a00np/"), + ("https://m.reddit.com/r/lavaporn/comments/2a00np/"), + ("https://redd.it/2a00np/"), + ) + + def __init__(self, match): + RedditExtractor.__init__(self, match) + self.submission_id = match.group(1) + + def submissions(self): + return (self.api.submission(self.submission_id),) + + +class RedditImageExtractor(Extractor): + """Extractor for reddit-hosted images""" + category = "reddit" + subcategory = "image" + archive_fmt = "{filename}" + pattern = (r"(?:https?://)?i\.redd(?:\.it|ituploads\.com)" + r"/[^/?&#]+(?:\?[^#]*)?") + test = ( + ("https://i.redd.it/upjtjcx2npzz.jpg", { + "url": "0de614900feef103e580b632190458c0b62b641a", + "content": "cc9a68cf286708d5ce23c68e79cd9cf7826db6a3", + }), + (("https://i.reddituploads.com/0f44f1b1fca2461f957c713d9592617d" + "?fit=max&h=1536&w=1536&s=e96ce7846b3c8e1f921d2ce2671fb5e2"), { + "url": "f24f25efcedaddeec802e46c60d77ef975dc52a5", + "content": "541dbcc3ad77aa01ee21ca49843c5e382371fae7", + }), + ) + + def items(self): + data = text.nameext_from_url(self.url) + yield Message.Version, 1 + yield Message.Directory, data + yield Message.Url, self.url, data + + +class RedditAPI(): + """Minimal interface for the reddit API""" + CLIENT_ID = "6N9uN0krSDE-ig" + USER_AGENT = "Python:gallery-dl:0.8.4 (by /u/mikf1)" + + def __init__(self, extractor): + self.extractor = extractor + self.comments = extractor.config("comments", 500) + self.morecomments = extractor.config("morecomments", False) + self.refresh_token = extractor.config("refresh-token") + self.log = extractor.log + + client_id = extractor.config("client-id", self.CLIENT_ID) + user_agent = extractor.config("user-agent", self.USER_AGENT) + + if (client_id == self.CLIENT_ID) ^ (user_agent == self.USER_AGENT): + self.client_id = None + self.log.warning( + "Conflicting values for 'client-id' and 'user-agent': " + "override either both or none of them.") + else: + self.client_id = client_id + extractor.session.headers["User-Agent"] = user_agent + + def submission(self, submission_id): + """Fetch the (submission, comments)=-tuple for a submission id""" + endpoint = "/comments/" + submission_id + "/.json" + link_id = "t3_" + submission_id if self.morecomments else None + submission, comments = self._call(endpoint, {"limit": self.comments}) + return (submission["data"]["children"][0]["data"], + self._flatten(comments, link_id)) + + def submissions_subreddit(self, subreddit, params): + """Collect all (submission, comments)-tuples of a subreddit""" + endpoint = "/r/" + subreddit + "/.json" + params["limit"] = 100 + return self._pagination(endpoint, params) + + def morechildren(self, link_id, children): + """Load additional comments from a submission""" + endpoint = "/api/morechildren" + params = {"link_id": link_id, "api_type": "json"} + index, done = 0, False + while not done: + if len(children) - index < 100: + done = True + params["children"] = ",".join(children[index:index + 100]) + index += 100 + + data = self._call(endpoint, params)["json"] + for thing in data["data"]["things"]: + if thing["kind"] == "more": + children.extend(thing["data"]["children"]) + else: + yield thing["data"] + + def authenticate(self): + """Authenticate the application by requesting an access token""" + access_token = self._authenticate_impl(self.refresh_token) + self.extractor.session.headers["Authorization"] = access_token + + @cache(maxage=3600, keyarg=1) + def _authenticate_impl(self, refresh_token=None): + """Actual authenticate implementation""" + url = "https://www.reddit.com/api/v1/access_token" + if refresh_token: + self.log.info("Refreshing private access token") + data = {"grant_type": "refresh_token", + "refresh_token": refresh_token} + else: + self.log.info("Requesting public access token") + data = {"grant_type": ("https://oauth.reddit.com/" + "grants/installed_client"), + "device_id": "DO_NOT_TRACK_THIS_DEVICE"} + response = self.extractor.request( + url, method="POST", data=data, auth=(self.client_id, "")) + if response.status_code != 200: + raise exception.AuthenticationError('"{} ({})"'.format( + response.json().get("message"), response.status_code)) + return "Bearer " + response.json()["access_token"] + + def _call(self, endpoint, params): + url = "https://oauth.reddit.com" + endpoint + params["raw_json"] = 1 + self.authenticate() + response = self.extractor.request( + url, params=params, expect=range(400, 500)) + remaining = response.headers.get("x-ratelimit-remaining") + if remaining and float(remaining) < 2: + wait = int(response.headers["x-ratelimit-reset"]) + self.log.info("Waiting %d seconds for ratelimit reset", wait) + time.sleep(wait) + data = response.json() + if "error" in data: + if data["error"] == 403: + raise exception.AuthorizationError() + if data["error"] == 404: + raise exception.NotFoundError() + raise Exception(data["message"]) + return data + + def _pagination(self, endpoint, params, _empty=()): + date_fmt = self.extractor.config("date-format", "%Y-%m-%dT%H:%M:%S") + date_min = self._parse_datetime("date-min", 0, date_fmt) + date_max = self._parse_datetime("date-max", 253402210800, date_fmt) + + id_min = self._parse_id("id-min", 0) + id_max = self._parse_id("id-max", 2147483647) + + while True: + data = self._call(endpoint, params)["data"] + + for submission in data["children"]: + submission = submission["data"] + if (date_min <= submission["created_utc"] <= date_max and + id_min <= self._decode(submission["id"]) <= id_max): + if submission["num_comments"] and self.comments: + try: + yield self.submission(submission["id"]) + except exception.AuthorizationError: + pass + else: + yield submission, _empty + + if not data["after"]: + return + params["after"] = data["after"] + + def _flatten(self, comments, link_id=None): + extra = [] + queue = comments["data"]["children"] + while queue: + comment = queue.pop(0) + if comment["kind"] == "more": + if link_id: + extra.extend(comment["data"]["children"]) + continue + comment = comment["data"] + yield comment + if comment["replies"]: + queue += comment["replies"]["data"]["children"] + if link_id and extra: + yield from self.morechildren(link_id, extra) + + def _parse_datetime(self, key, default, fmt): + ts = self.extractor.config(key, default) + if isinstance(ts, str): + try: + ts = int(datetime.datetime.strptime(ts, fmt).timestamp()) + except ValueError as exc: + self.log.warning("Unable to parse '%s': %s", key, exc) + ts = default + return ts + + def _parse_id(self, key, default): + sid = self.extractor.config(key) + return self._decode(sid.rpartition("_")[2].lower()) if sid else default + + @staticmethod + def _decode(sid): + return util.bdecode(sid, "0123456789abcdefghijklmnopqrstuvwxyz") |
