diff options
Diffstat (limited to 'gallery_dl/extractor/reddit.py')
| -rw-r--r-- | gallery_dl/extractor/reddit.py | 37 |
1 files changed, 26 insertions, 11 deletions
diff --git a/gallery_dl/extractor/reddit.py b/gallery_dl/extractor/reddit.py index d0232cc..2e3864a 100644 --- a/gallery_dl/extractor/reddit.py +++ b/gallery_dl/extractor/reddit.py @@ -222,20 +222,25 @@ class RedditAPI(): self.extractor = extractor self.comments = text.parse_int(extractor.config("comments", 0)) self.morecomments = extractor.config("morecomments", False) - self.refresh_token = extractor.config("refresh-token") self.log = extractor.log client_id = extractor.config("client-id", self.CLIENT_ID) user_agent = extractor.config("user-agent", self.USER_AGENT) if (client_id == self.CLIENT_ID) ^ (user_agent == self.USER_AGENT): - self.client_id = None - self.log.warning( + raise exception.StopExtraction( "Conflicting values for 'client-id' and 'user-agent': " "overwrite either both or none of them.") + + self.client_id = client_id + self.headers = {"User-Agent": user_agent} + + token = extractor.config("refresh-token") + if token is None or token == "cache": + key = "#" + self.client_id + self.refresh_token = _refresh_token_cache(key) else: - self.client_id = client_id - extractor.session.headers["User-Agent"] = user_agent + self.refresh_token = token def submission(self, submission_id): """Fetch the (submission, comments)=-tuple for a submission id""" @@ -277,13 +282,15 @@ class RedditAPI(): def authenticate(self): """Authenticate the application by requesting an access token""" - access_token = self._authenticate_impl(self.refresh_token) - self.extractor.session.headers["Authorization"] = access_token + self.headers["Authorization"] = \ + self._authenticate_impl(self.refresh_token) @cache(maxage=3600, keyarg=1) def _authenticate_impl(self, refresh_token=None): """Actual authenticate implementation""" url = "https://www.reddit.com/api/v1/access_token" + self.headers["Authorization"] = None + if refresh_token: self.log.info("Refreshing private access token") data = {"grant_type": "refresh_token", @@ -294,9 +301,9 @@ class RedditAPI(): "grants/installed_client"), "device_id": "DO_NOT_TRACK_THIS_DEVICE"} - auth = (self.client_id, "") response = self.extractor.request( - url, method="POST", data=data, auth=auth, fatal=False) + url, method="POST", headers=self.headers, + data=data, auth=(self.client_id, ""), fatal=False) data = response.json() if response.status_code != 200: @@ -307,9 +314,10 @@ class RedditAPI(): def _call(self, endpoint, params): url = "https://oauth.reddit.com" + endpoint - params["raw_json"] = 1 + params["raw_json"] = "1" self.authenticate() - response = self.extractor.request(url, params=params, fatal=None) + response = self.extractor.request( + url, params=params, headers=self.headers, fatal=None) remaining = response.headers.get("x-ratelimit-remaining") if remaining and float(remaining) < 2: @@ -380,3 +388,10 @@ class RedditAPI(): @staticmethod def _decode(sid): return util.bdecode(sid, "0123456789abcdefghijklmnopqrstuvwxyz") + + +@cache(maxage=100*365*24*3600, keyarg=0) +def _refresh_token_cache(token): + if token and token[0] == "#": + return None + return token |
