diff --git a/reddit/subreddit.py b/reddit/subreddit.py index 5f2ac5f..26fa07a 100644 --- a/reddit/subreddit.py +++ b/reddit/subreddit.py @@ -1,6 +1,5 @@ -import re - -import praw +from re import sub +from praw import Reddit from praw.models import MoreComments from prawcore.exceptions import ResponseException @@ -13,148 +12,131 @@ from utils.videos import check_done from utils.voice import sanitize_text -def get_subreddit_threads(POST_ID: str): - """ - Returns a list of threads from the AskReddit subreddit. - """ - +def get_reddit_instance(): + """Initialize and return a Reddit instance.""" print_substep("Logging into Reddit.") - - content = {} if settings.config["reddit"]["creds"]["2fa"]: print("\nEnter your two-factor authentication code from your authenticator app.\n") code = input("> ") print() - pw = settings.config["reddit"]["creds"]["password"] - passkey = f"{pw}:{code}" + password = settings.config["reddit"]["creds"]["password"] + passkey = f"{password}:{code}" else: passkey = settings.config["reddit"]["creds"]["password"] + username = settings.config["reddit"]["creds"]["username"] if str(username).casefold().startswith("u/"): username = username[2:] + try: - reddit = praw.Reddit( + return Reddit( client_id=settings.config["reddit"]["creds"]["client_id"], client_secret=settings.config["reddit"]["creds"]["client_secret"], user_agent="Accessing Reddit threads", username=username, - passkey=passkey, + password=passkey, check_for_async=False, ) except ResponseException as e: if e.response.status_code == 401: print("Invalid credentials - please check them in config.toml") - except: + except Exception: print("Something went wrong...") - # Ask user for subreddit input - print_step("Getting subreddit threads...") - similarity_score = 0 - if not settings.config["reddit"]["thread"][ - "subreddit" - ]: # note to user. you can have multiple subreddits via reddit.subreddit("redditdev+learnpython") + +def get_subreddit(reddit): + """Get the subreddit based on user input or config.""" + if not settings.config["reddit"]["thread"]["subreddit"]: try: - subreddit = reddit.subreddit( - re.sub(r"r\/", "", input("What subreddit would you like to pull from? ")) - # removes the r/ from the input - ) + subreddit_name = sub(r"r\/", "", input("What subreddit would you like to pull from? ")) + return reddit.subreddit(subreddit_name) except ValueError: - subreddit = reddit.subreddit("askreddit") print_substep("Subreddit not defined. Using AskReddit.") + return reddit.subreddit("askreddit") else: - sub = settings.config["reddit"]["thread"]["subreddit"] - print_substep(f"Using subreddit: r/{sub} from TOML config") - subreddit_choice = sub - if str(subreddit_choice).casefold().startswith("r/"): # removes the r/ from the input - subreddit_choice = subreddit_choice[2:] - subreddit = reddit.subreddit(subreddit_choice) - - if POST_ID: # would only be called if there are multiple queued posts - submission = reddit.submission(id=POST_ID) - - elif ( - settings.config["reddit"]["thread"]["post_id"] - and len(str(settings.config["reddit"]["thread"]["post_id"]).split("+")) == 1 - ): - submission = reddit.submission(id=settings.config["reddit"]["thread"]["post_id"]) - elif settings.config["ai"]["ai_similarity_enabled"]: # ai sorting based on comparison + subreddit_name = settings.config["reddit"]["thread"]["subreddit"] + print_substep(f"Using subreddit: r/{subreddit_name} from TOML config") + if str(subreddit_name).casefold().startswith("r/"): + subreddit_name = subreddit_name[2:] + return reddit.subreddit(subreddit_name) + + +def get_submission(reddit, subreddit, post_id): + """Retrieve a submission based on post ID or subreddit.""" + if post_id: + return reddit.submission(id=post_id) + + if settings.config["reddit"]["thread"]["post_id"] and len(str(settings.config["reddit"]["thread"]["post_id"]).split("+")) == 1: + return reddit.submission(id=settings.config["reddit"]["thread"]["post_id"]) + + if settings.config["ai"]["ai_similarity_enabled"]: threads = subreddit.hot(limit=50) - keywords = settings.config["ai"]["ai_similarity_keywords"].split(",") - keywords = [keyword.strip() for keyword in keywords] - # Reformat the keywords for printing - keywords_print = ", ".join(keywords) - print(f"Sorting threads by similarity to the given keywords: {keywords_print}") + keywords = [keyword.strip() for keyword in settings.config["ai"]["ai_similarity_keywords"].split(",")] + print(f"Sorting threads by similarity to the given keywords: {', '.join(keywords)}") threads, similarity_scores = sort_by_similarity(threads, keywords) - submission, similarity_score = get_subreddit_undone( - threads, subreddit, similarity_scores=similarity_scores - ) - else: - threads = subreddit.hot(limit=25) - submission = get_subreddit_undone(threads, subreddit) + return get_subreddit_undone(threads, subreddit, similarity_scores=similarity_scores) + + threads = subreddit.hot(limit=25) + return get_subreddit_undone(threads, subreddit) + + +def collect_comments(submission): + """Collect comments from a submission.""" + comments = [] + for top_level_comment in submission.comments: + if isinstance(top_level_comment, MoreComments): + continue + + if top_level_comment.body in ["[removed]", "[deleted]"]: + continue + + if not top_level_comment.stickied: + sanitized_text = sanitize_text(top_level_comment.body) + if sanitized_text and len(top_level_comment.body) <= int(settings.config["reddit"]["thread"]["max_comment_length"]): + if len(top_level_comment.body) >= int(settings.config["reddit"]["thread"]["min_comment_length"]): + if top_level_comment.author is not None: + comments.append({ + "comment_body": top_level_comment.body, + "comment_url": top_level_comment.permalink, + "comment_id": top_level_comment.id, + }) + return comments + + +def get_subreddit_threads(post_id): + """Main function to get subreddit threads.""" + reddit = get_reddit_instance() + if not reddit: + return + + print_step("Getting subreddit threads...") + subreddit = get_subreddit(reddit) + submission = get_submission(reddit, subreddit, post_id) if submission is None: - return get_subreddit_threads(POST_ID) # submission already done. rerun + return get_subreddit_threads(post_id) - elif not submission.num_comments and settings.config["settings"]["storymode"] == "false": + if not submission.num_comments and settings.config["settings"]["storymode"] == "false": print_substep("No comments found. Skipping.") exit() - submission = check_done(submission) # double-checking - - upvotes = submission.score - ratio = submission.upvote_ratio * 100 - num_comments = submission.num_comments - threadurl = f"https://new.reddit.com/{submission.permalink}" - - print_substep(f"Video will be: {submission.title} :thumbsup:", style="bold green") - print_substep(f"Thread url is: {threadurl} :thumbsup:", style="bold green") - print_substep(f"Thread has {upvotes} upvotes", style="bold blue") - print_substep(f"Thread has a upvote ratio of {ratio}%", style="bold blue") - print_substep(f"Thread has {num_comments} comments", style="bold blue") - if similarity_score: - print_substep( - f"Thread has a similarity score up to {round(similarity_score * 100)}%", - style="bold blue", - ) + submission = check_done(submission) + + content = { + "thread_url": f"https://new.reddit.com/{submission.permalink}", + "thread_title": submission.title, + "thread_id": submission.id, + "is_nsfw": submission.over_18, + "comments": [], + } - content["thread_url"] = threadurl - content["thread_title"] = submission.title - content["thread_id"] = submission.id - content["is_nsfw"] = submission.over_18 - content["comments"] = [] if settings.config["settings"]["storymode"]: if settings.config["settings"]["storymodemethod"] == 1: content["thread_post"] = posttextparser(submission.selftext) else: content["thread_post"] = submission.selftext else: - for top_level_comment in submission.comments: - if isinstance(top_level_comment, MoreComments): - continue - - if top_level_comment.body in ["[removed]", "[deleted]"]: - continue # # see https://github.com/JasonLovesDoggo/RedditVideoMakerBot/issues/78 - if not top_level_comment.stickied: - sanitised = sanitize_text(top_level_comment.body) - if not sanitised or sanitised == " ": - continue - if len(top_level_comment.body) <= int( - settings.config["reddit"]["thread"]["max_comment_length"] - ): - if len(top_level_comment.body) >= int( - settings.config["reddit"]["thread"]["min_comment_length"] - ): - if ( - top_level_comment.author is not None - and sanitize_text(top_level_comment.body) is not None - ): # if errors occur with this change to if not. - content["comments"].append( - { - "comment_body": top_level_comment.body, - "comment_url": top_level_comment.permalink, - "comment_id": top_level_comment.id, - } - ) + content["comments"] = collect_comments(submission) print_substep("Received subreddit threads Successfully.", style="bold green") - return content + return content \ No newline at end of file