This repository has been archived by the owner on Dec 22, 2021. It is now read-only.

Refactor scrape.py
This refactor is incomplete but I'm looking to tidy this up sooner
rather than later
natanlao committed Mar 19, 2021
1 parent 565b32d commit 617bd50
Showing 4 changed files with 129 additions and 62 deletions.
18 changes: 16 additions & 2 deletions README.md
@@ -60,8 +60,22 @@ To authenticate to Reddit, you need to:
### Collect and analyze data

```console
$ python scripts/scrape.py wallstreetbets posts
$ python scripts/scrape.py wallstreetbets comments
$ python scripts/scrape.py wallstreetbets fetch-posts
$ python scripts/scrape.py wallstreetbets fetch-comments
$ python scripts/load.py wallstreetbets # creates wallstreetbets.db
$ python plot.py wallstreetbets.db all plot
```

## Lessons learned

The initial crawl took about two weeks due to the sheer number of comments and
Reddit's API rate limiting. The first time around, I neglected to save scores
for each comment, which forced a second pass. That second pass took more than
four weeks: in the first pass I could retrieve thousands of comments with a
single API call, but in the second I fetched each comment one by one, which was
far more expensive. By the time I realized this, I was too busy to fix the
problem and let the scraper continue running as written.

I could have avoided this problem by being greedier about what data I saved in
the first pass. (In my defense, I was running out of disk space on my laptop,
but I think I would have made the same mistake even if I hadn't been.)
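
For context on the cost of that second pass: Reddit's `/api/info` endpoint returns up to 100 items per request when queried by fullname, so backfilling scores for already-cached comments does not have to cost one API call per comment. The sketch below is illustrative only (the `backfill_scores` helper and its batching are not part of this repository); it assumes PRAW credentials are configured the same way the project's scraper configures them and that the cached base-36 comment IDs are already known.

```python
import itertools

import praw

# Illustrative sketch: batch-fetch current scores for already-cached comments.
reddit = praw.Reddit(user_agent='wsv v0.1 (https://github.com/natanlao/wsv)')


def backfill_scores(comment_ids):
    """Yield (comment_id, score) pairs, fetching up to 100 comments per request."""
    ids = iter(comment_ids)
    while True:
        batch = list(itertools.islice(ids, 100))  # /api/info accepts at most 100 fullnames
        if not batch:
            break
        fullnames = ['t1_' + cid for cid in batch]  # comment fullnames use the t1_ prefix
        for comment in reddit.info(fullnames=fullnames):
            yield comment.id, comment.score
```

In hindsight, recording `comment.score` during the first pass (as the refactored `Comment` tuple in scrape.py now does) would have avoided the second crawl entirely.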
26 changes: 15 additions & 11 deletions scripts/load.py
@@ -15,27 +15,31 @@ def __init__(self, db_path: str):
self.c = self.conn.cursor()

def init_db(self):
# TODO: Brittle coupling
self.c.execute('''CREATE TABLE posts
(id integer, title text, selftext text, created text)''')
(id integer, title text, selftext text, num_comments integer, score integer, created text)''')
self.c.execute('''CREATE TABLE comments
(id integer, body text, created text)''')
(id integer, body text, score integer, created text)''')
self.conn.commit()

def load(self, cache: RedditCache):
comments = ((
b36decode(comment['id']),
comment['body'],
timestamp_dbformat(comment['created'])
b36decode(comment.id),
comment.body,
comment.score,
timestamp_dbformat(comment.created)
) for comment in cache.comments)
self.c.executemany('INSERT INTO comments VALUES (?, ?, ?)', comments)
self.c.executemany('INSERT INTO comments VALUES (?, ?, ?, ?)', comments)

posts = ((
b36decode(post['id']),
post['title'],
post['selftext'],
timestamp_dbformat(post['created'])
b36decode(post.id),
post.title,
post.selftext,
post.num_comments,
post.score,
timestamp_dbformat(post.created)
) for post in cache.posts)
self.c.executemany('INSERT INTO posts VALUES (?, ?, ?, ?)', posts)
self.c.executemany('INSERT INTO posts VALUES (?, ?, ?, ?, ?, ?)', posts)

self.conn.commit()

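A note on the `# TODO: Brittle coupling` comment above: the column list in `CREATE TABLE`, the tuples built in `load()`, and the number of `?` placeholders all have to change in lockstep whenever a field is added, as the `score` and `num_comments` additions in this commit show. A minimal sketch of one way to centralize that, offered as an illustration rather than how this repository is structured (the `SCHEMAS` table and helper names are hypothetical):

```python
# Illustrative sketch: one schema definition drives both CREATE TABLE and INSERT.
SCHEMAS = {
    'posts': [('id', 'integer'), ('title', 'text'), ('selftext', 'text'),
              ('num_comments', 'integer'), ('score', 'integer'), ('created', 'text')],
    'comments': [('id', 'integer'), ('body', 'text'),
                 ('score', 'integer'), ('created', 'text')],
}


def create_table_sql(table: str) -> str:
    columns = ', '.join(f'{name} {sqltype}' for name, sqltype in SCHEMAS[table])
    return f'CREATE TABLE {table} ({columns})'


def insert_sql(table: str) -> str:
    placeholders = ', '.join('?' for _ in SCHEMAS[table])
    return f'INSERT INTO {table} VALUES ({placeholders})'
```

With something along these lines, `init_db()` and `load()` would both read the column list from one place, so adding a column touches only the schema table and the row-building generator.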
1 change: 1 addition & 0 deletions scripts/plot.py
@@ -72,6 +72,7 @@ def silver_gamestop_posts(db: Database, outpath: str):
fig.update_layout(plot_layout('posts', 'term'))
fig.write_image(outpath)


def silver_gme_comments(db: Database, outpath: str):
gme_x, gme_y = zip(*comments_freq_time(db, 'GME'))
fig = go.Figure(data=go.Bar(x=gme_x, y=gme_y, name='"GME"'))
146 changes: 97 additions & 49 deletions scripts/scrape.py
@@ -8,19 +8,20 @@
import json
import logging
import pathlib
from typing import Generator
from typing import Generator, NamedTuple

import praw
import prawcore.exceptions

log = logging.getLogger(__name__)
logging.basicConfig(level=logging.DEBUG)
logging.basicConfig(level=logging.INFO)
reddit = praw.Reddit(user_agent='wsv v0.1 (https://github.com/natanlao/wsv)')


class RedditCache:

def __init__(self, cache_dir: str, subreddit_name: str):
self.subreddit = reddit.subreddit(subreddit_name)
self.cache_dir = pathlib.Path(cache_dir)
self.posts_dir = self.cache_dir / subreddit_name / 'posts'
self.posts_dir.mkdir(parents=True, exist_ok=True)
@@ -35,16 +36,19 @@ def _yield_dir_json(path: pathlib.Path) -> Generator:

@property
def posts(self) -> Generator:
# This is leaky but I'm lazy
return self._yield_dir_json(self.posts_dir)
for item in self.posts_dir.glob('*.json'):
with item.open('r') as fh:
yield Post(**json.load(fh))

@property
def comments(self) -> Generator:
return self._yield_dir_json(self.comments_dir)
for item in self.comments_dir.glob('*.json'):
with item.open('r') as fh:
yield Comment(**json.load(fh))

@functools.cached_property
def num_posts(self) -> int:
return len(list(self.posts))
return sum(1 for _ in self.posts_dir.glob('*.json'))

@functools.cached_property
def seen_posts(self) -> collections.Counter[str]:
@@ -69,11 +73,10 @@ def cache_comments(self):
log.info('Saving comments for post %d/%d (%s)',
         post_num, self.num_posts, post.id)
for comment in fetch_comments_for_post(post.id):
with self.comment_path(comment['id']).open('w') as json_fp:
json.dump(comment, json_fp, indent=2)
comment.save(self)

def cache_posts(self):
sub = reddit.subreddit(self.subreddit_name)
sub = self.subreddit
listings = [sub.new(limit=None),
sub.top('day', limit=None),
sub.top('hour', limit=None),
@@ -88,57 +91,107 @@ def cache_posts(self):

for post in itertools.chain.from_iterable(listings):
log.info('Caching post ID %s', post.id)
with self.post_path(post.id).open('w') as post_fh:
try:
author = post.author.name
except AttributeError:
author = '[deleted]'
finally:
post = {
'author': author,
'created': post.created,
'edited': post.edited,
'id': post.id,
'permalink': post.permalink,
'selftext': post.selftext,
'title': post.title,
'url': post.url
}
json.dump(post, post_fh, indent=2)
Post.from_praw(post).save(self) # this feels wrong but too lazy to fix now

def update_posts(self):
for num, item in enumerate(self.posts_dir.glob('*.json'), 1):
with item.open('r') as fh:
submission = json.load(fh)
log.info('Caching post ID %s (%d/%d)', submission['id'], num, self.num_posts)
post = praw.models.Submission(reddit, id=submission['id'])
Post.from_praw(post).save(self)


class Post(NamedTuple):
author: str
created: int
edited: bool
id: str
num_comments: int
permalink: str
score: int
selftext: str
title: str
url: str

@classmethod
def from_praw(cls, submission: praw.models.Submission):
try:
author = submission.author.name
except AttributeError:
author = '[deleted]'
finally:
return cls(author=author,
created=submission.created,
edited=submission.edited,
id=submission.id,
num_comments=submission.num_comments,
permalink=submission.permalink,
score=submission.score,
selftext=submission.selftext,
title=submission.title,
url=submission.url)

def save(self, cache: RedditCache):
with cache.post_path(self.id).open('w') as post_fh:
json.dump(self._asdict(), post_fh)


class Comment(NamedTuple):
author: str
body: str
created: int
edited: bool
id: str
permalink: str
post_id: str
score: int

@classmethod
def from_praw(cls, comment: praw.models.Comment):
try:
author = comment.author.name
except AttributeError:
author = '[deleted]'
finally:
return cls(author=author,
body=comment.body,
created=comment.created,
edited=comment.edited,
id=comment.id,
permalink=comment.permalink,
post_id=comment.link_id,
score=comment.score)

def save(self, cache: RedditCache):
with cache.comment_path(self.id).open('w') as comment_fh:
json.dump(self._asdict(), comment_fh)


def fetch_comments_for_post(post_id: str) -> Generator:
post = praw.models.Submission(reddit, id=post_id)

try:
post.comments.replace_more(limit=None)
except prawcore.exceptions.TooLarge:
log.info('Post %s has too many comments', post_id)
post = praw.models.Submission(reddit, id=post_id)
post.comments.replace_more(limit=0) # TODO
finally:
for comment in post.comments.list():
yield Comment.from_praw(comment)

for comment in post.comments.list():
try:
author = comment.author.name
except AttributeError:
author = '[deleted]'
finally:
yield {
'author': author,
'body': comment.body,
'created': comment.created,
'edited': comment.edited,
'id': comment.id,
'permalink': comment.permalink,
'post_id': post.id
}

dispatch = {
'fetch-comments': lambda c: c.cache_comments(),
'fetch-posts': lambda c: c.cache_posts(),
'update-posts': lambda c: c.update_posts()
}

if __name__ == '__main__':
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('subreddit_name')
parser.add_argument('target',
choices=['comments', 'posts'],
choices=dispatch,
help='Comments are only collected from posts we have '
'already crawled; i.e., crawling comments before '
'posts will result in no comments being saved.')
@@ -148,9 +201,4 @@ def fetch_comments_for_post(post_id: str) -> Generator:
arguments = parser.parse_args()

cache = RedditCache(arguments.cache_dir, arguments.subreddit_name)
if arguments.target == 'posts':
cache.cache_posts()
elif arguments.target == 'comments':
cache.cache_comments()
else:
exit(1)
dispatch[arguments.target](cache)
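
For reference, the command surface after this commit looks like the following: the `update-posts` target re-fetches each cached post (for example, to backfill `score` and `num_comments`), and, as the help text notes, comments can only be collected for posts that are already cached. The cache directory argument is assumed to keep its previous default.

```console
$ python scripts/scrape.py wallstreetbets fetch-posts
$ python scripts/scrape.py wallstreetbets fetch-comments  # needs cached posts first
$ python scripts/scrape.py wallstreetbets update-posts    # re-fetch cached posts, e.g. to backfill scores
```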
