diff --git a/reddit/subreddit.py b/reddit/subreddit.py
index 11b93af..5b87430 100644
--- a/reddit/subreddit.py
+++ b/reddit/subreddit.py
@@ -10,6 +10,7 @@ from utils.console import print_step, print_substep
 from utils.subreddit import get_subreddit_undone
 from utils.videos import check_done
 from utils.voice import sanitize_text
+from utils.ai_methods import sort_by_similarity
 
 
 def get_subreddit_threads(POST_ID: str):
@@ -49,6 +50,7 @@ def get_subreddit_threads(POST_ID: str):
 
     # Ask user for subreddit input
     print_step("Getting subreddit threads...")
+    similarity_score = 0
     if not settings.config["reddit"]["thread"][
         "subreddit"
     ]:  # note to user. you can have multiple subreddits via reddit.subreddit("redditdev+learnpython")
@@ -75,6 +77,15 @@ def get_subreddit_threads(POST_ID: str):
         and len(str(settings.config["reddit"]["thread"]["post_id"]).split("+")) == 1
     ):
         submission = reddit.submission(id=settings.config["reddit"]["thread"]["post_id"])
+    elif settings.config["ai"]["ai_similarity_enabled"]:  # AI sorting based on similarity to keywords
+        threads = subreddit.hot(limit=50)
+        keywords = settings.config["ai"]["ai_similarity_keywords"].split(",")
+        keywords = [keyword.strip() for keyword in keywords]
+        # Reformat the keywords for printing
+        keywords_print = ", ".join(keywords)
+        print(f"Sorting threads by similarity to the given keywords: {keywords_print}")
+        threads, similarity_scores = sort_by_similarity(threads, keywords)
+        submission, similarity_score = get_subreddit_undone(threads, subreddit, similarity_scores=similarity_scores)
     else:
         threads = subreddit.hot(limit=25)
         submission = get_subreddit_undone(threads, subreddit)
@@ -89,6 +100,8 @@ def get_subreddit_threads(POST_ID: str):
     print_substep(f"Thread has {upvotes} upvotes", style="bold blue")
     print_substep(f"Thread has a upvote ratio of {ratio}%", style="bold blue")
     print_substep(f"Thread has {num_comments} comments", style="bold blue")
+    if similarity_score:
+        print_substep(f"Thread has a similarity score of {round(similarity_score * 100)}%", style="bold blue")
 
     content["thread_url"] = f"https://reddit.com{submission.permalink}"
     content["thread_title"] = submission.title
diff --git a/requirements.txt b/requirements.txt
index 259e056..6b1a786 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -11,3 +11,5 @@ toml==0.10.2
 translators==5.3.1
 pyttsx3==2.90
 Pillow~=9.3.0
+torch==1.12.1
+transformers==4.25.1
\ No newline at end of file
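The reddit/subreddit.py hunks above are the heart of the change. Stripped of the surrounding config plumbing, the new branch reduces to the flow below; a minimal sketch, assuming a working PRAW `subreddit` handle and the `[ai]` config table introduced in the template diff that follows (the `pick_submission` wrapper is hypothetical, everything inside it comes from the diff):

```python
# Illustrative sketch of the new AI-sorting branch, not the literal function body.
from utils import settings
from utils.ai_methods import sort_by_similarity
from utils.subreddit import get_subreddit_undone


def pick_submission(subreddit):
    # Pull a larger candidate pool than the default hot(limit=25):
    # similarity ranking is only useful with enough threads to reorder.
    threads = subreddit.hot(limit=50)
    # "Elon Musk, Twitter, Stocks" -> ["Elon Musk", "Twitter", "Stocks"]
    keywords = [k.strip() for k in settings.config["ai"]["ai_similarity_keywords"].split(",")]
    threads, scores = sort_by_similarity(threads, keywords)
    # With similarity_scores supplied, get_subreddit_undone returns a
    # (submission, score) pair instead of a bare submission.
    return get_subreddit_undone(threads, subreddit, similarity_scores=scores)
```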
diff --git a/utils/.config.template.toml b/utils/.config.template.toml
index adbaed0..ec9281e 100644
--- a/utils/.config.template.toml
+++ b/utils/.config.template.toml
@@ -14,6 +14,9 @@ max_comment_length = { default = 500, optional = false, nmin = 10, nmax = 10000,
 post_lang = { default = "", optional = true, explanation = "The language you would like to translate to.", example = "es-cr" }
 min_comments = { default = 20, optional = false, nmin = 15, type = "int", explanation = "The minimum number of comments a post should have to be included. default is 20", example = 29, oob_error = "the minimum number of comments should be between 15 and 999999" }
 
+[ai]
+ai_similarity_enabled = { optional = true, option = [true, false], default = false, type = "bool", explanation = "Threads read from Reddit are sorted based on their similarity to the keywords given below" }
+ai_similarity_keywords = { optional = true, type = "str", example = 'Elon Musk, Twitter, Stocks', explanation = "Each keyword or sentence, separated by a comma, is used to sort the Reddit threads by similarity" }
 
 [settings]
 allow_nsfw = { optional = false, type = "bool", default = false, example = false, options = [true, false,], explanation = "Whether to allow NSFW content, True or False" }
diff --git a/utils/ai_methods.py b/utils/ai_methods.py
new file mode 100644
index 0000000..244cfff
--- /dev/null
+++ b/utils/ai_methods.py
@@ -0,0 +1,52 @@
+import numpy as np
+import torch
+from transformers import AutoTokenizer, AutoModel
+
+
+# Mean pooling: take the attention mask into account so padding tokens
+# do not skew the average
+def mean_pooling(model_output, attention_mask):
+    token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
+    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+
+
+# Sorts the given threads by their total similarity to the given keywords
+def sort_by_similarity(thread_objects, keywords):
+    # Initialize tokenizer + model
+    tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
+    model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
+
+    # Materialize the generator into a list of Submission objects so the
+    # threads can be sorted by contextual similarity to the keywords
+    thread_objects = list(thread_objects)
+
+    threads_sentences = []
+    for thread in thread_objects:
+        threads_sentences.append(" ".join([thread.title, thread.selftext]))
+
+    # Thread inference
+    encoded_threads = tokenizer(threads_sentences, padding=True, truncation=True, return_tensors="pt")
+    with torch.no_grad():
+        threads_embeddings = model(**encoded_threads)
+    threads_embeddings = mean_pooling(threads_embeddings, encoded_threads["attention_mask"])
+
+    # Keyword inference
+    encoded_keywords = tokenizer(keywords, padding=True, truncation=True, return_tensors="pt")
+    with torch.no_grad():
+        keywords_embeddings = model(**encoded_keywords)
+    keywords_embeddings = mean_pooling(keywords_embeddings, encoded_keywords["attention_mask"])
+
+    # Compare every keyword embedding with every thread embedding and
+    # accumulate the cosine similarities per thread
+    total_scores = torch.zeros(threads_embeddings.shape[0])
+    cosine_similarity = torch.nn.CosineSimilarity(dim=1)
+    for keyword_embedding in keywords_embeddings:
+        keyword_embedding = keyword_embedding.repeat(threads_embeddings.shape[0], 1)
+        total_scores += cosine_similarity(keyword_embedding, threads_embeddings)
+
+    # Sort threads (and their scores) from most to least similar
+    similarity_scores, indices = torch.sort(total_scores, descending=True)
+    thread_objects = np.array(thread_objects)[indices.numpy()].tolist()
+
+    return thread_objects, similarity_scores
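To see what sort_by_similarity actually computes, here is the same embed-and-score pipeline run on plain strings instead of PRAW submissions; a standalone sketch using the same sentence-transformers/all-MiniLM-L6-v2 checkpoint (the example threads and keywords are made up):

```python
import torch
from transformers import AutoTokenizer, AutoModel

from utils.ai_methods import mean_pooling

tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")


def embed(sentences):
    # Tokenize, run the model, then mean-pool token embeddings
    # into a single 384-dim vector per sentence.
    encoded = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        output = model(**encoded)
    return mean_pooling(output, encoded["attention_mask"])


threads = ["Tesla stock jumps after earnings", "My cat learned to open doors"]
keywords = ["Elon Musk", "Stocks"]

thread_emb = embed(threads)    # shape: (2, 384)
keyword_emb = embed(keywords)  # shape: (2, 384)

# Each thread's score is the sum of its cosine similarities to every keyword,
# so scores range from -len(keywords) to +len(keywords).
cos = torch.nn.CosineSimilarity(dim=1)
scores = torch.zeros(len(threads))
for k in keyword_emb:
    scores += cos(k.repeat(len(threads), 1), thread_emb)

print(scores)  # the Tesla thread should score noticeably higher
```

Because the scores are summed over keywords rather than averaged, `round(similarity_score * 100)` in reddit/subreddit.py can exceed 100% when several keywords are configured.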
diff --git a/utils/subreddit.py b/utils/subreddit.py
index c386868..95bbbf6 100644
--- a/utils/subreddit.py
+++ b/utils/subreddit.py
@@ -3,9 +3,10 @@ from os.path import exists
 
 from utils import settings
 from utils.console import print_substep
+from utils.ai_methods import sort_by_similarity
 
 
-def get_subreddit_undone(submissions: list, subreddit, times_checked=0):
+def get_subreddit_undone(submissions: list, subreddit, times_checked=0, similarity_scores=None):
     """_summary_
 
     Args:
@@ -15,13 +16,19 @@ def get_subreddit_undone(submissions: list, subreddit, times_checked=0):
     Returns:
         Any: The submission that has not been done
     """
+    # Second attempt at getting a valid submission
+    if times_checked and settings.config["ai"]["ai_similarity_enabled"]:
+        print("Sorting based on similarity for a different time filter and thread limit...")
+        keywords = [k.strip() for k in settings.config["ai"]["ai_similarity_keywords"].split(",")]
+        submissions, similarity_scores = sort_by_similarity(submissions, keywords)
+
     # recursively checks if the top submission in the list was already done.
     if not exists("./video_creation/data/videos.json"):
         with open("./video_creation/data/videos.json", "w+") as f:
             json.dump([], f)
     with open("./video_creation/data/videos.json", "r", encoding="utf-8") as done_vids_raw:
         done_videos = json.load(done_vids_raw)
-    for submission in submissions:
+    for i, submission in enumerate(submissions):
         if already_done(done_videos, submission):
             continue
         if submission.over_18:
@@ -39,6 +46,8 @@
                 f'This post has under the specified minimum of comments ({settings.config["reddit"]["thread"]["min_comments"]}). Skipping...'
             )
             continue
+        if similarity_scores is not None:
+            return submission, similarity_scores[i].item()
         return submission
     print("all submissions have been done going by top submission order")
 VALID_TIME_FILTERS = [
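One thing reviewers may want to note: get_subreddit_undone now has two return shapes, selected by whether similarity_scores is passed. A sketch of the two call conventions as used in the diff (`threads` and `scores` here stand in for the values produced by sort_by_similarity):

```python
# Legacy call: no scores passed, a bare submission comes back.
submission = get_subreddit_undone(threads, subreddit)

# New call: with scores passed, a (submission, score) pair comes back.
# The enumerate index inside the function lines up with similarity_scores
# because sort_by_similarity returns threads and scores in the same sorted order.
submission, similarity_score = get_subreddit_undone(
    threads, subreddit, similarity_scores=scores
)
```

That index alignment is also why the retry path re-sorts `submissions` and reassigns `similarity_scores` together: reordering one without the other would attach the wrong score to each thread.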