diff --git a/reddit/subreddit.py b/reddit/subreddit.py index e355d64..34d2612 100644 --- a/reddit/subreddit.py +++ b/reddit/subreddit.py @@ -7,8 +7,12 @@ from dotenv import load_dotenv from utils.console import print_step, print_substep -def get_subreddit_threads(): +def get_subreddit_threads(subreddit_): """ + Takes subreddit_ as parameter which defaults to None, but in this + case since it is None, it would raise ValueError, thus defaulting + to AskReddit. + Returns a list of threads from the AskReddit subreddit. """ @@ -27,7 +31,6 @@ def get_subreddit_threads(): passkey = os.getenv("REDDIT_PASSWORD") content = {} - reddit = praw.Reddit( client_id=os.getenv("REDDIT_CLIENT_ID"), client_secret=os.getenv("REDDIT_CLIENT_SECRET"), @@ -37,9 +40,10 @@ def get_subreddit_threads(): ) try: - subreddit = reddit.subreddit( - input("What subreddit would you like to pull from? ") - ) + if subreddit_ is None: + raise ValueError + + subreddit = reddit.subreddit(subreddit_) except ValueError: subreddit = reddit.subreddit("askreddit") print_substep("Subreddit not defined. Using AskReddit.")