summaryrefslogtreecommitdiffstats
path: root/gallery_dl/extractor/reddit.py
diff options
context:
space:
mode:
Diffstat (limited to 'gallery_dl/extractor/reddit.py')
-rw-r--r--gallery_dl/extractor/reddit.py37
1 files changed, 26 insertions, 11 deletions
diff --git a/gallery_dl/extractor/reddit.py b/gallery_dl/extractor/reddit.py
index d0232cc..2e3864a 100644
--- a/gallery_dl/extractor/reddit.py
+++ b/gallery_dl/extractor/reddit.py
@@ -222,20 +222,25 @@ class RedditAPI():
self.extractor = extractor
self.comments = text.parse_int(extractor.config("comments", 0))
self.morecomments = extractor.config("morecomments", False)
- self.refresh_token = extractor.config("refresh-token")
self.log = extractor.log
client_id = extractor.config("client-id", self.CLIENT_ID)
user_agent = extractor.config("user-agent", self.USER_AGENT)
if (client_id == self.CLIENT_ID) ^ (user_agent == self.USER_AGENT):
- self.client_id = None
- self.log.warning(
+ raise exception.StopExtraction(
"Conflicting values for 'client-id' and 'user-agent': "
"overwrite either both or none of them.")
+
+ self.client_id = client_id
+ self.headers = {"User-Agent": user_agent}
+
+ token = extractor.config("refresh-token")
+ if token is None or token == "cache":
+ key = "#" + self.client_id
+ self.refresh_token = _refresh_token_cache(key)
else:
- self.client_id = client_id
- extractor.session.headers["User-Agent"] = user_agent
+ self.refresh_token = token
def submission(self, submission_id):
"""Fetch the (submission, comments)=-tuple for a submission id"""
@@ -277,13 +282,15 @@ class RedditAPI():
def authenticate(self):
"""Authenticate the application by requesting an access token"""
- access_token = self._authenticate_impl(self.refresh_token)
- self.extractor.session.headers["Authorization"] = access_token
+ self.headers["Authorization"] = \
+ self._authenticate_impl(self.refresh_token)
@cache(maxage=3600, keyarg=1)
def _authenticate_impl(self, refresh_token=None):
"""Actual authenticate implementation"""
url = "https://www.reddit.com/api/v1/access_token"
+ self.headers["Authorization"] = None
+
if refresh_token:
self.log.info("Refreshing private access token")
data = {"grant_type": "refresh_token",
@@ -294,9 +301,9 @@ class RedditAPI():
"grants/installed_client"),
"device_id": "DO_NOT_TRACK_THIS_DEVICE"}
- auth = (self.client_id, "")
response = self.extractor.request(
- url, method="POST", data=data, auth=auth, fatal=False)
+ url, method="POST", headers=self.headers,
+ data=data, auth=(self.client_id, ""), fatal=False)
data = response.json()
if response.status_code != 200:
@@ -307,9 +314,10 @@ class RedditAPI():
def _call(self, endpoint, params):
url = "https://oauth.reddit.com" + endpoint
- params["raw_json"] = 1
+ params["raw_json"] = "1"
self.authenticate()
- response = self.extractor.request(url, params=params, fatal=None)
+ response = self.extractor.request(
+ url, params=params, headers=self.headers, fatal=None)
remaining = response.headers.get("x-ratelimit-remaining")
if remaining and float(remaining) < 2:
@@ -380,3 +388,10 @@ class RedditAPI():
@staticmethod
def _decode(sid):
return util.bdecode(sid, "0123456789abcdefghijklmnopqrstuvwxyz")
+
+
+@cache(maxage=100*365*24*3600, keyarg=0)
+def _refresh_token_cache(token):
+ if token and token[0] == "#":
+ return None
+ return token