CHORE: syntactical changes + some pythonic changes and future improvements

pull/2080/head
Jason 1 year ago
parent 6b474b4b50
commit fc9a166d39

@ -4,7 +4,7 @@ import sys
from os import name from os import name
from pathlib import Path from pathlib import Path
from subprocess import Popen from subprocess import Popen
from typing import NoReturn from typing import NoReturn, Dict
from prawcore import ResponseException from prawcore import ResponseException
@ -13,7 +13,7 @@ from utils import settings
from utils.cleanup import cleanup from utils.cleanup import cleanup
from utils.console import print_markdown, print_step, print_substep from utils.console import print_markdown, print_step, print_substep
from utils.ffmpeg_install import ffmpeg_install from utils.ffmpeg_install import ffmpeg_install
from utils.id import id from utils.id import extract_id
from utils.version import checkversion from utils.version import checkversion
from video_creation.background import ( from video_creation.background import (
chop_background, chop_background,
@ -42,11 +42,14 @@ print_markdown(
) )
checkversion(__VERSION__) checkversion(__VERSION__)
reddit_id: str
reddit_object: Dict[str, str | list]
def main(POST_ID=None) -> None: def main(POST_ID=None) -> None:
global redditid, reddit_object global reddit_id, reddit_object
reddit_object = get_subreddit_threads(POST_ID) reddit_object = get_subreddit_threads(POST_ID)
redditid = id(reddit_object) reddit_id = extract_id(reddit_object)
print_substep(f"Thread ID is {reddit_id}", style="bold blue")
length, number_of_comments = save_text_to_mp3(reddit_object) length, number_of_comments = save_text_to_mp3(reddit_object)
length = math.ceil(length) length = math.ceil(length)
get_screenshots_of_reddit_posts(reddit_object, number_of_comments) get_screenshots_of_reddit_posts(reddit_object, number_of_comments)
@ -64,15 +67,15 @@ def run_many(times) -> None:
for x in range(1, times + 1): for x in range(1, times + 1):
print_step( print_step(
f'on the {x}{("th", "st", "nd", "rd", "th", "th", "th", "th", "th", "th")[x % 10]} iteration of {times}' f'on the {x}{("th", "st", "nd", "rd", "th", "th", "th", "th", "th", "th")[x % 10]} iteration of {times}'
) # correct 1st 2nd 3rd 4th 5th.... )
main() main()
Popen("cls" if name == "nt" else "clear", shell=True).wait() Popen("cls" if name == "nt" else "clear", shell=True).wait()
def shutdown() -> NoReturn: def shutdown() -> NoReturn:
if "redditid" in globals(): if "reddit_id" in globals():
print_markdown("## Clearing temp files") print_markdown("## Clearing temp files")
cleanup(redditid) cleanup(reddit_id)
print("Exiting...") print("Exiting...")
sys.exit() sys.exit()

@ -12,7 +12,7 @@ def mean_pooling(model_output, attention_mask):
) )
# This function sort the given threads based on their total similarity with the given keywords # This function sorts the given threads based on their total similarity with the given keywords
def sort_by_similarity(thread_objects, keywords): def sort_by_similarity(thread_objects, keywords):
# Initialize tokenizer + model. # Initialize tokenizer + model.
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2") tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
@ -34,7 +34,7 @@ def sort_by_similarity(thread_objects, keywords):
threads_embeddings = model(**encoded_threads) threads_embeddings = model(**encoded_threads)
threads_embeddings = mean_pooling(threads_embeddings, encoded_threads["attention_mask"]) threads_embeddings = mean_pooling(threads_embeddings, encoded_threads["attention_mask"])
# Keywords inference # Keyword inference
encoded_keywords = tokenizer(keywords, padding=True, truncation=True, return_tensors="pt") encoded_keywords = tokenizer(keywords, padding=True, truncation=True, return_tensors="pt")
with torch.no_grad(): with torch.no_grad():
keywords_embeddings = model(**encoded_keywords) keywords_embeddings = model(**encoded_keywords)
@ -53,7 +53,7 @@ def sort_by_similarity(thread_objects, keywords):
similarity_scores, indices = torch.sort(total_scores, descending=True) similarity_scores, indices = torch.sort(total_scores, descending=True)
threads_sentences = np.array(threads_sentences)[indices.numpy()] # threads_sentences = np.array(threads_sentences)[indices.numpy()]
thread_objects = np.array(thread_objects)[indices.numpy()].tolist() thread_objects = np.array(thread_objects)[indices.numpy()].tolist()

@ -102,7 +102,7 @@ def handle_input(
user_input = input("").strip() user_input = input("").strip()
if check_type is not False: if check_type is not False:
try: try:
isinstance(eval(user_input), check_type) isinstance(eval(user_input), check_type) # fixme: remove eval
return check_type(user_input) return check_type(user_input)
except: except:
console.print( console.print(

@ -28,8 +28,8 @@ def ffmpeg_install_windows():
for root, dirs, files in os.walk(ffmpeg_extracted_folder, topdown=False): for root, dirs, files in os.walk(ffmpeg_extracted_folder, topdown=False):
for file in files: for file in files:
os.remove(os.path.join(root, file)) os.remove(os.path.join(root, file))
for dir in dirs: for directory in dirs:
os.rmdir(os.path.join(root, dir)) os.rmdir(os.path.join(root, directory))
os.rmdir(ffmpeg_extracted_folder) os.rmdir(ffmpeg_extracted_folder)
# Extract FFmpeg # Extract FFmpeg
@ -110,7 +110,7 @@ def ffmpeg_install():
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, stderr=subprocess.PIPE,
) )
except FileNotFoundError as e: except FileNotFoundError:
# Check if there's ffmpeg.exe in the current directory # Check if there's ffmpeg.exe in the current directory
if os.path.exists("./ffmpeg.exe"): if os.path.exists("./ffmpeg.exe"):
print( print(

@ -25,7 +25,9 @@ def get_checks():
# Get current config (from config.toml) as dict # Get current config (from config.toml) as dict
def get_config(obj: dict, done={}): def get_config(obj: dict, done=None):
if done is None:
done = {}
for key in obj.keys(): for key in obj.keys():
if not isinstance(obj[key], dict): if not isinstance(obj[key], dict):
done[key] = obj[key] done[key] = obj[key]
@ -44,13 +46,13 @@ def check(value, checks):
if not incorrect and "type" in checks: if not incorrect and "type" in checks:
try: try:
value = eval(checks["type"])(value) value = eval(checks["type"])(value) # fixme remove eval
except Exception: except Exception:
incorrect = True incorrect = True
if ( if (
not incorrect and "options" in checks and value not in checks["options"] not incorrect and "options" in checks and value not in checks["options"]
): # FAILSTATE Value is not one of the options ): # FAILSTATE Value isn't one of the options
incorrect = True incorrect = True
if ( if (
not incorrect not incorrect
@ -59,7 +61,7 @@ def check(value, checks):
(isinstance(value, str) and re.match(checks["regex"], value) is None) (isinstance(value, str) and re.match(checks["regex"], value) is None)
or not isinstance(value, str) or not isinstance(value, str)
) )
): # FAILSTATE Value doesn't match regex, or has regex but is not a string. ): # FAILSTATE Value doesn't match regular expression, or has regular expression but isn't a string.
incorrect = True incorrect = True
if ( if (
@ -88,17 +90,17 @@ def check(value, checks):
return value return value
# Modify settings (after form is submitted) # Modify settings (after the form is submitted)
def modify_settings(data: dict, config_load, checks: dict): def modify_settings(data: dict, config_load, checks: dict):
# Modify config settings # Modify config settings
def modify_config(obj: dict, name: str, value: any): def modify_config(obj: dict, config_name: str, value: any):
for key in obj.keys(): for key in obj.keys():
if name == key: if config_name == key:
obj[key] = value obj[key] = value
elif not isinstance(obj[key], dict): elif not isinstance(obj[key], dict):
continue continue
else: else:
modify_config(obj[key], name, value) modify_config(obj[key], config_name, value)
# Remove empty/incorrect key-value pairs # Remove empty/incorrect key-value pairs
data = {key: value for key, value in data.items() if value and key in checks.keys()} data = {key: value for key, value in data.items() if value and key in checks.keys()}
@ -158,7 +160,7 @@ def add_background(youtube_uri, filename, citation, position):
youtube_uri = f"https://www.youtube.com/watch?v={regex.group(1)}" youtube_uri = f"https://www.youtube.com/watch?v={regex.group(1)}"
# Check if position is valid # Check if the position is valid
if position == "" or position == "center": if position == "" or position == "center":
position = "center" position = "center"
@ -178,7 +180,7 @@ def add_background(youtube_uri, filename, citation, position):
filename = filename.replace(" ", "_") filename = filename.replace(" ", "_")
# Check if background doesn't already exist # Check if the background doesn't already exist
with open("utils/backgrounds.json", "r", encoding="utf-8") as backgrounds: with open("utils/backgrounds.json", "r", encoding="utf-8") as backgrounds:
data = json.load(backgrounds) data = json.load(backgrounds)

@ -1,12 +1,14 @@
import re import re
from typing import Optional
from utils.console import print_substep from utils.console import print_substep
def id(reddit_obj: dict): def extract_id(reddit_obj: dict, field: Optional[str] = "thread_id"):
""" """
This function takes a reddit object and returns the post id This function takes a reddit object and returns the post id
""" """
id = re.sub(r"[^\w\s-]", "", reddit_obj["thread_id"]) if field not in reddit_obj.keys():
print_substep(f"Thread ID is {id}", style="bold blue") raise ValueError(f"Field '{field}' not found in reddit object")
return id reddit_id = re.sub(r"[^\w\s-]", "", reddit_obj[field])
return reddit_id

@ -7,6 +7,7 @@ from rich.progress import track
from TTS.engine_wrapper import process_text from TTS.engine_wrapper import process_text
from utils.fonts import getheight, getsize from utils.fonts import getheight, getsize
from utils.id import extract_id
def draw_multiple_line_text( def draw_multiple_line_text(
@ -58,18 +59,16 @@ def imagemaker(theme, reddit_obj: dict, txtclr, padding=5, transparent=False) ->
Render Images for video Render Images for video
""" """
texts = reddit_obj["thread_post"] texts = reddit_obj["thread_post"]
id = re.sub(r"[^\w\s-]", "", reddit_obj["thread_id"]) reddit_id = extract_id(reddit_obj)
if transparent: if transparent:
font = ImageFont.truetype(os.path.join("fonts", "Roboto-Bold.ttf"), 100) font = ImageFont.truetype(os.path.join("fonts", "Roboto-Bold.ttf"), 100)
else: else:
font = ImageFont.truetype(os.path.join("fonts", "Roboto-Regular.ttf"), 100) font = ImageFont.truetype(os.path.join("fonts", "Roboto-Regular.ttf"), 100)
size = (1920, 1080) size = (1920, 1080)
image = Image.new("RGBA", size, theme)
for idx, text in track(enumerate(texts), "Rendering Image"): for idx, text in track(enumerate(texts), "Rendering Image"):
image = Image.new("RGBA", size, theme) image = Image.new("RGBA", size, theme)
text = process_text(text, False) text = process_text(text, False)
draw_multiple_line_text(image, text, font, txtclr, padding, wrap=30, transparent=transparent) draw_multiple_line_text(image, text, font, txtclr, padding, wrap=30, transparent=transparent)
image.save(f"assets/temp/{id}/png/img{idx}.png") image.save(f"assets/temp/{reddit_id}/png/img{idx}.png")

@ -30,7 +30,7 @@ def check(value, checks, name):
incorrect = True incorrect = True
if not incorrect and "type" in checks: if not incorrect and "type" in checks:
try: try:
value = eval(checks["type"])(value) value = eval(checks["type"])(value) # fixme remove eval
except: except:
incorrect = True incorrect = True
@ -78,7 +78,7 @@ def check(value, checks, name):
+ str(name) + str(name)
+ "[#F7768E bold]=", + "[#F7768E bold]=",
extra_info=get_check_value("explanation", ""), extra_info=get_check_value("explanation", ""),
check_type=eval(get_check_value("type", "False")), check_type=eval(get_check_value("type", "False")), # fixme remove eval
default=get_check_value("default", NotImplemented), default=get_check_value("default", NotImplemented),
match=get_check_value("regex", ""), match=get_check_value("regex", ""),
err_message=get_check_value("input_error", "Incorrect input"), err_message=get_check_value("input_error", "Incorrect input"),

@ -19,6 +19,7 @@ from utils import settings
from utils.cleanup import cleanup from utils.cleanup import cleanup
from utils.console import print_step, print_substep from utils.console import print_step, print_substep
from utils.fonts import getheight from utils.fonts import getheight
from utils.id import extract_id
from utils.thumbnail import create_thumbnail from utils.thumbnail import create_thumbnail
from utils.videos import save_data from utils.videos import save_data
@ -204,7 +205,7 @@ def make_final_video(
opacity = settings.config["settings"]["opacity"] opacity = settings.config["settings"]["opacity"]
reddit_id = re.sub(r"[^\w\s-]", "", reddit_obj["thread_id"]) reddit_id = extract_id(reddit_obj)
allowOnlyTTSFolder: bool = ( allowOnlyTTSFolder: bool = (
settings.config["settings"]["background"]["enable_extra_audio"] settings.config["settings"]["background"]["enable_extra_audio"]
@ -343,8 +344,8 @@ def make_final_video(
) )
current_time += audio_clips_durations[i] current_time += audio_clips_durations[i]
title = re.sub(r"[^\w\s-]", "", reddit_obj["thread_title"]) title = extract_id(reddit_obj, "thread_title")
idx = re.sub(r"[^\w\s-]", "", reddit_obj["thread_id"]) idx = extract_id(reddit_obj)
title_thumb = reddit_obj["thread_title"] title_thumb = reddit_obj["thread_title"]
filename = f"{name_normalize(title)[:251]}" filename = f"{name_normalize(title)[:251]}"

@ -34,7 +34,7 @@ def get_screenshots_of_reddit_posts(reddit_object: dict, screenshot_num: int):
# ! Make sure the reddit screenshots folder exists # ! Make sure the reddit screenshots folder exists
Path(f"assets/temp/{reddit_id}/png").mkdir(parents=True, exist_ok=True) Path(f"assets/temp/{reddit_id}/png").mkdir(parents=True, exist_ok=True)
# set the theme and disable non-essential cookies # set the theme and turn off non-essential cookies
if settings.config["settings"]["theme"] == "dark": if settings.config["settings"]["theme"] == "dark":
cookie_file = open("./video_creation/data/cookie-dark-mode.json", encoding="utf-8") cookie_file = open("./video_creation/data/cookie-dark-mode.json", encoding="utf-8")
bgcolor = (33, 33, 36, 255) bgcolor = (33, 33, 36, 255)
@ -60,7 +60,6 @@ def get_screenshots_of_reddit_posts(reddit_object: dict, screenshot_num: int):
transparent = False transparent = False
if storymode and settings.config["settings"]["storymodemethod"] == 1: if storymode and settings.config["settings"]["storymodemethod"] == 1:
# for idx,item in enumerate(reddit_object["thread_post"]):
print_substep("Generating images...") print_substep("Generating images...")
return imagemaker( return imagemaker(
theme=bgcolor, theme=bgcolor,

Loading…
Cancel
Save