From 35d2ff25595bd7a7258424b8bb5c45c1a94abdfe Mon Sep 17 00:00:00 2001
From: electro199 <109358640+electro199@users.noreply.github.com>
Date: Mon, 9 Jun 2025 16:09:02 +0500
Subject: [PATCH] use pydantic model for templating and validation

---
 main.py               |   3 +-
 utils/config_model.py | 417 ++++++++++++++++++++++++++++++++++++++++++
 utils/settings.py     | 265 ++++++++++++---------------
 3 files changed, 537 insertions(+), 148 deletions(-)
 create mode 100644 utils/config_model.py

diff --git a/main.py b/main.py
index 742fedf..99bbdbe 100755
--- a/main.py
+++ b/main.py
@@ -91,9 +91,8 @@ if __name__ == "__main__":
     ffmpeg_install()
     directory = Path().absolute()
     config = settings.check_toml(
-        f"{directory}/utils/.config.template.toml", f"{directory}/config.toml"
+        f"{directory}/config.toml"
     )
-    config is False and sys.exit()
 
     if (
         not settings.config["settings"]["tts"]["tiktok_sessionid"]
diff --git a/utils/config_model.py b/utils/config_model.py
new file mode 100644
index 0000000..6146d5f
--- /dev/null
+++ b/utils/config_model.py
@@ -0,0 +1,417 @@
+from typing import Annotated, Literal, Optional
+from pydantic import BaseModel, Field, StringConstraints
+
+
+class RedditCreds(BaseModel):
+    client_id: Annotated[
+        str,
+        StringConstraints(
+            min_length=12, max_length=30, pattern=r"^[-a-zA-Z0-9._~+/]+=*$"
+        ),
+    ] = Field(..., description="The ID of your Reddit app of SCRIPT type")
+
+    client_secret: Annotated[
+        str,
+        StringConstraints(
+            min_length=20, max_length=40, pattern=r"^[-a-zA-Z0-9._~+/]+=*$"
+        ),
+    ] = Field(..., description="The SECRET of your Reddit app of SCRIPT type")
+
+    username: Annotated[
+        str, StringConstraints(min_length=3, max_length=20, pattern=r"^[-_0-9a-zA-Z]+$")
+    ] = Field(..., description="The username of your Reddit account")
+
+    password: Annotated[str, StringConstraints(min_length=8)] = Field(
+        ..., description="The password of your Reddit account"
+    )
+
+    twofa: Optional[bool] = Field(False, description="Whether Reddit 2FA is enabled")
+
+
+class RedditThread(BaseModel):
+    random: Optional[bool] = Field(
+        False, description="If true, picks a random thread instead of asking for URL"
+    )
+
+    subreddit: Annotated[
+        str, StringConstraints(min_length=3, max_length=20, pattern=r"[_0-9a-zA-Z\+]+$")
+    ] = Field(..., description="Name(s) of subreddit(s), '+' separated")
+
+    post_id: Annotated[Optional[str], StringConstraints(pattern=r"^[+a-zA-Z0-9]*$")] = (
+        Field("", description="Specify a Reddit post ID if desired")
+    )
+
+    max_comment_length: Annotated[int, Field(ge=10, le=10000)] = Field(
+        500, description="Max number of characters per comment"
+    )
+
+    min_comment_length: Annotated[int, Field(ge=0, le=10000)] = Field(
+        1, description="Min number of characters per comment"
+    )
+
+    post_lang: Optional[str] = Field(
+        "", description="Target language code for translation (e.g., 'es-cr')"
+    )
+
+    min_comments: Annotated[int, Field(ge=10)] = Field(
+        20, description="Minimum number of comments required"
+    )
+
+
+class RedditThreadExtras(BaseModel):
+    min_comments: Annotated[
+        int,
+        Field(
+            default=20,
+            ge=10,
+            le=999999,
+            description="The minimum number of comments a post should have to be included. Default is 20.",
+            examples=[29],
+        ),
+    ]
+
+
+class AIConfig(BaseModel):
+    ai_similarity_enabled: Annotated[
+        bool,
+        Field(
+            default=False,
+            description="Threads read from Reddit are sorted based on their similarity to the keywords given below.",
+        ),
+    ]
+    ai_similarity_keywords: Annotated[
+        str,
+        Field(
+            default="",
+            description="Every keyword or sentence, separated by commas, is used to sort Reddit threads based on similarity.",
+            examples=["Elon Musk, Twitter, Stocks"],
+        ),
+    ]
+
+
+class SettingsTTS(BaseModel):
+    voice_choice: Annotated[
+        Literal[
+            "elevenlabs",
+            "streamlabspolly",
+            "tiktok",
+            "googletranslate",
+            "awspolly",
+            "pyttsx",
+        ],
+        Field(
+            default="tiktok",
+            description="The voice platform used for TTS generation.",
+            examples=["tiktok"],
+        ),
+    ]
+    random_voice: Annotated[
+        bool,
+        Field(
+            default=True,
+            description="Randomizes the voice used for each comment.",
+            examples=[True],
+        ),
+    ]
+    elevenlabs_voice_name: Annotated[
+        Literal[
+            "Adam", "Antoni", "Arnold", "Bella", "Domi", "Elli", "Josh", "Rachel", "Sam"
+        ],
+        Field(
+            default="Bella",
+            description="The voice used for ElevenLabs.",
+            examples=["Bella"],
+        ),
+    ]
+    elevenlabs_api_key: Annotated[
+        str,
+        Field(
+            default="",
+            description="ElevenLabs API key.",
+            examples=["21f13f91f54d741e2ae27d2ab1b99d59"],
+        ),
+    ]
+    aws_polly_voice: Annotated[
+        str,
+        Field(
+            default="Matthew",
+            description="The voice used for AWS Polly.",
+            examples=["Matthew"],
+        ),
+    ]
+    streamlabs_polly_voice: Annotated[
+        str,
+        Field(
+            default="Matthew",
+            description="The voice used for Streamlabs Polly.",
+            examples=["Matthew"],
+        ),
+    ]
+    tiktok_voice: Annotated[
+        str,
+        Field(
+            default="en_us_001",
+            description="The voice used for TikTok TTS.",
+            examples=["en_us_006"],
+        ),
+    ]
+    tiktok_sessionid: Annotated[
+        str,
+        Field(
+            default="",
+            description="TikTok sessionid needed for TikTok TTS.",
+            examples=["c76bcc3a7625abcc27b508c7db457ff1"],
+        ),
+    ]
+    python_voice: Annotated[
+        str,
+        Field(
+            default="1",
+            description="The index of the system TTS voices (starts from 0).",
+            examples=["1"],
+        ),
+    ]
+    py_voice_num: Annotated[
+        str,
+        Field(
+            default="2",
+            description="The number of system voices available.",
+            examples=["2"],
+        ),
+    ]
+    silence_duration: Annotated[
+        float,
+        Field(
+            default=0.3,
+            description="Time in seconds between TTS comments.",
+            examples=["0.1"],
+        ),
+    ]
+    no_emojis: Annotated[
+        bool,
+        Field(
+            default=False,
+            description="Whether to remove emojis from the comments.",
+            examples=[False],
+        ),
+    ]
+    openai_api_url: Annotated[
+        str,
+        Field(
+            default="https://api.openai.com/v1/",
+            description="The API endpoint URL for OpenAI TTS generation.",
+            examples=["https://api.openai.com/v1/"],
+        ),
+    ]
+    openai_api_key: Annotated[
+        str,
+        Field(
+            default="",
+            description="Your OpenAI API key for TTS generation.",
+            examples=["sk-abc123def456..."],
+        ),
+    ]
+    openai_voice_name: Annotated[
+        Literal[
+            "alloy", "ash", "coral", "echo", "fable", "onyx", "nova", "sage", "shimmer"
+        ],
+        Field(
+            default="alloy",
+            description="The voice used for OpenAI TTS generation.",
+            examples=["alloy"],
+        ),
+    ]
+    openai_model: Annotated[
+        Literal["tts-1", "tts-1-hd"],
+        Field(
+            default="tts-1",
+            description="The model variant used for OpenAI TTS generation.",
+            examples=["tts-1"],
+        ),
+    ]
+
+
+class SettingsBackground(BaseModel):
+    background_video: Annotated[
+        str,
+        Field(
+            default="minecraft",
+            description="Sets the background for the video based on game name",
examples=["rocket-league"], + ), + StringConstraints(strip_whitespace=True), + ] = "minecraft" + + background_audio: Annotated[ + str, + Field( + default="lofi", + description="Sets the background audio for the video", + examples=["chill-summer"], + ), + StringConstraints(strip_whitespace=True), + ] = "lofi" + + background_audio_volume: Annotated[ + float, + Field( + default=0.15, + ge=0, + le=1, + description="Sets the volume of the background audio. If you don't want background audio, set it to 0.", + examples=[0.05], + ), + ] = 0.15 + + enable_extra_audio: Annotated[ + bool, + Field( + default=False, + description="Used if you want to render another video without background audio in a separate folder", + ), + ] = False + + background_thumbnail: Annotated[ + bool, + Field( + default=False, + description="Generate a thumbnail for the video (put a thumbnail.png file in the assets/backgrounds directory.)", + ), + ] = False + + background_thumbnail_font_family: Annotated[ + str, + Field( + default="arial", + description="Font family for the thumbnail text", + examples=["arial"], + ), + ] = "arial" + + background_thumbnail_font_size: Annotated[ + int, + Field( + default=96, + description="Font size in pixels for the thumbnail text", + examples=[96], + ), + ] = 96 + + background_thumbnail_font_color: Annotated[ + str, + Field( + default="255,255,255", + description="Font color in RGB format for the thumbnail text", + examples=["255,255,255"], + ), + ] = "255,255,255" + + +class Settings(BaseModel): + allow_nsfw: Annotated[ + bool, + Field( + default=False, + description="Whether to allow NSFW content. True or False.", + examples=[False], + ), + ] + theme: Annotated[ + Literal["dark", "light", "transparent"], + Field( + default="dark", + description="Sets the Reddit theme. For story mode, 'transparent' is also allowed.", + examples=["light"], + ), + ] + times_to_run: Annotated[ + int, + Field( + default=1, + ge=1, + description="Used if you want to run multiple times. Must be an int >= 1.", + examples=[2], + ), + ] + opacity: Annotated[ + float, + Field( + default=0.9, + ge=0.0, + le=1.0, + description="Sets the opacity of comments when overlaid over the background.", + examples=[0.8], + ), + ] + storymode: Annotated[ + bool, + Field( + default=False, + description="Only read out title and post content. Great for story-based subreddits.", + examples=[False], + ), + ] + storymodemethod: Annotated[ + Literal[0, 1], + Field( + default=1, + description="Style used for story mode: 0 = static image, 1 = fancy video.", + examples=[1], + ), + ] + storymode_max_length: Annotated[ + int, + Field( + default=1000, + ge=1, + description="Max length (in characters) of the story mode video.", + examples=[1000], + ), + ] + resolution_w: Annotated[ + int, + Field( + default=1080, + description="Sets the width in pixels of the final video.", + examples=[1440], + ), + ] + resolution_h: Annotated[ + int, + Field( + default=1920, + description="Sets the height in pixels of the final video.", + examples=[2560], + ), + ] + zoom: Annotated[ + float, + Field( + default=1.0, + ge=0.1, + le=2.0, + description="Sets the browser zoom level. 
Useful for making text larger.", + examples=[1.1], + ), + ] + channel_name: Annotated[ + str, + Field( + default="Reddit Tales", + description="Sets the channel name for the video.", + examples=["Reddit Stories"], + ), + ] + tts: SettingsTTS + background: SettingsBackground + + +class Reddit(BaseModel): + creds: RedditCreds + thread: RedditThread + + +class Config(BaseModel): + reddit: Reddit + ai: AIConfig + settings: Settings \ No newline at end of file diff --git a/utils/settings.py b/utils/settings.py index 6b8242b..0c5ab14 100755 --- a/utils/settings.py +++ b/utils/settings.py @@ -1,170 +1,143 @@ -import re +import sys from pathlib import Path -from typing import Dict, Tuple +from typing import Any, Dict import toml from rich.console import Console -from utils.console import handle_input +from utils.config_model import Config +from utils.console import print_substep console = Console() -config = dict # autocomplete - +config: dict # autocomplete +from typing import Any + +from pydantic import ValidationError, BaseModel +from pydantic_core import PydanticUndefined + + +def prompt_recursive(obj: BaseModel): + """ + Recursively prompt for missing or invalid fields in a Pydantic model instance 'obj'. + """ + for field_name, field in obj.model_fields.items(): + value = getattr(obj, field_name, None) + # If field is a nested BaseModel, recurse into it + if hasattr(field.annotation, "model_fields"): + nested_obj = value or field.annotation.model_construct() + fixed_nested = prompt_recursive(nested_obj) + setattr(obj, field_name, fixed_nested) + continue -def crawl(obj: dict, func=lambda x, y: print(x, y, end="\n"), path=None): - if path is None: # path Default argument value is mutable - path = [] - for key in obj.keys(): - if type(obj[key]) is dict: - crawl(obj[key], func, path + [key]) + # If the value is valid and not None, skip prompt + if value not in [None, "", [], {}]: continue - func(path + [key], obj[key]) - - -def check(value, checks, name): - def get_check_value(key, default_result): - return checks[key] if key in checks else default_result - - incorrect = False - if value == {}: - incorrect = True - if not incorrect and "type" in checks: - try: - value = eval(checks["type"])(value) # fixme remove eval - except: - incorrect = True - - if ( - not incorrect and "options" in checks and value not in checks["options"] - ): # FAILSTATE Value is not one of the options - incorrect = True - if ( - not incorrect - and "regex" in checks - and ( - (isinstance(value, str) and re.match(checks["regex"], value) is None) - or not isinstance(value, str) - ) - ): # FAILSTATE Value doesn't match regex, or has regex but is not a string. 
-        incorrect = True
-
-    if (
-        not incorrect
-        and not hasattr(value, "__iter__")
-        and (
-            ("nmin" in checks and checks["nmin"] is not None and value < checks["nmin"])
-            or ("nmax" in checks and checks["nmax"] is not None and value > checks["nmax"])
-        )
-    ):
-        incorrect = True
-    if (
-        not incorrect
-        and hasattr(value, "__iter__")
-        and (
-            ("nmin" in checks and checks["nmin"] is not None and len(value) < checks["nmin"])
-            or ("nmax" in checks and checks["nmax"] is not None and len(value) > checks["nmax"])
-        )
-    ):
-        incorrect = True
-
-    if incorrect:
-        value = handle_input(
-            message=(
-                (("[blue]Example: " + str(checks["example"]) + "\n") if "example" in checks else "")
-                + "[red]"
-                + ("Non-optional ", "Optional ")["optional" in checks and checks["optional"] is True]
-            )
-            + "[#C0CAF5 bold]"
-            + str(name)
-            + "[#F7768E bold]=",
-            extra_info=get_check_value("explanation", ""),
-            check_type=eval(get_check_value("type", "False")),  # fixme remove eval
-            default=get_check_value("default", NotImplemented),
-            match=get_check_value("regex", ""),
-            err_message=get_check_value("input_error", "Incorrect input"),
-            nmin=get_check_value("nmin", None),
-            nmax=get_check_value("nmax", None),
-            oob_error=get_check_value(
-                "oob_error", "Input out of bounds(Value too high/low/long/short)"
-            ),
-            options=get_check_value("options", None),
-            optional=get_check_value("optional", False),
-        )
-    return value
+        description = field.description or ""
+        default_str = (
+            f" (default: {field.default})"
+            if field.default is not None and field.default is not PydanticUndefined
+            else ""
+        )
+        prompt_msg = f"🧩 {field_name}\n 📘 {description}{default_str}\n ⚠️ Required: {field.is_required()}\n ❓ Enter value: "
+
+        while True:
+            user_input = input(prompt_msg).strip()
+            if not user_input:
+                if field.default is not None:
+                    value_to_set = field.default
+                elif not field.is_required():
+                    value_to_set = None
+                else:
+                    print("⚠️ This field is required.")
+                    continue
+            else:
+                # Convert input based on type, you can expand this logic
+                try:
+                    value_to_set = parse_value(user_input, field.annotation)
+                except Exception as e:
+                    print(f"⚠️ Invalid input: {e}")
+                    continue
+
+            # Validate the assignment
+            try:
+                obj.__pydantic_validator__.validate_assignment(
+                    obj, field_name, value_to_set
+                )
+                setattr(obj, field_name, value_to_set)
+                break
+            except ValidationError as ve:
+                for err in ve.errors():
+                    print(f"❌ {err['loc'][0]}: {err['msg']}")
 
-def crawl_and_check(obj: dict, path: list, checks: dict = {}, name=""):
-    if len(path) == 0:
-        return check(obj, checks, name)
-    if path[0] not in obj.keys():
-        obj[path[0]] = {}
-    obj[path[0]] = crawl_and_check(obj[path[0]], path[1:], checks, path[0])
     return obj
 
 
-def check_vars(path, checks):
-    global config
-    crawl_and_check(config, path, checks)
+def parse_value(raw: str, expected_type: type):
+    from typing import get_args, get_origin
+
+    origin = get_origin(expected_type)
+    args = get_args(expected_type)
 
 
-def check_toml(template_file, config_file) -> Tuple[bool, Dict]:
-    global config
-    config = None
-    try:
-        template = toml.load(template_file)
-    except Exception as error:
-        console.print(f"[red bold]Encountered error when trying to to load {template_file}: {error}")
-        return False
-    try:
-        config = toml.load(config_file)
-    except toml.TomlDecodeError:
-        console.print(
-            f"""[blue]Couldn't read {config_file}.
-Overwrite it?(y/n)"""
-        )
-        if not input().startswith("y"):
-            print("Unable to read config, and not allowed to overwrite it. Giving up.")
+    if expected_type == bool:
+        if raw.lower() in ("true", "yes", "1"):
+            return True
+        elif raw.lower() in ("false", "no", "0"):
             return False
         else:
-            try:
-                with open(config_file, "w") as f:
-                    f.write("")
-            except:
-                console.print(
-                    f"[red bold]Failed to overwrite {config_file}. Giving up.\nSuggestion: check {config_file} permissions for the user."
-                )
-                return False
-    except FileNotFoundError:
-        console.print(
-            f"""[blue]Couldn't find {config_file}
-Creating it now."""
-        )
-        try:
-            with open(config_file, "x") as f:
-                f.write("")
-            config = {}
-        except:
-            console.print(
-                f"[red bold]Failed to write to {config_file}. Giving up.\nSuggestion: check the folder's permissions for the user."
-            )
-            return False
+            raise ValueError("Expected boolean value (true/false)")
+    elif expected_type == int:
+        return int(raw)
+    elif expected_type == float:
+        return float(raw)
+    elif expected_type == str:
+        return raw
+    elif origin == list and args:
+        return [parse_value(x.strip(), args[0]) for x in raw.split(",")]
+    else:
+        raise ValueError(f"Unsupported field type: {expected_type}")
+
+
+def check_toml(config_file: str) -> Dict[str, Any]:
+    """
+    Load the config TOML file.
+    Validate config with Pydantic.
+    If invalid, prompt for missing or invalid fields.
+    Save fixed config back.
+    Return the validated config as a dict.
+    """
+    try:
+        config_dict = toml.load(config_file)
+    except Exception as e:
+        print(f"Failed to load config {config_file}: {e}")
+        config_dict = {}
 
-    console.print(
-        """\
-[blue bold]###############################
-#                             #
-# Checking TOML configuration #
-#                             #
-###############################
-If you see any prompts, that means that you have unset/incorrectly set variables, please input the correct values.\
-"""
-    )
-    crawl(template, check_vars)
-    with open(config_file, "w") as f:
-        toml.dump(config, f)
+    try:
+        config_instance = Config.model_validate(config_dict)
+    except ValidationError as e:
+        print("Config validation failed, will prompt for missing/invalid fields:")
+        print(e)
+        # Start from a clean model
+        config_instance = Config.model_construct()
+        # Update model with any valid partial data loaded from config
+        for k, v in config_dict.items():
+            if hasattr(config_instance, k):
+                setattr(config_instance, k, v)
+
+        # Prompt for missing or invalid fields recursively
+        config_instance = prompt_recursive(config_instance)
+
+        # Validate again to be sure
+        config_instance = Config.model_validate(config_instance.model_dump())
+
+    # Save fixed config back to file
+    with open(config_file, "w", encoding="utf-8") as f:
+        toml.dump(config_instance.model_dump(), f)
+    print(f"Updated config saved to {config_file}")
+    config = config_instance.model_dump()
     return config
 
 
 if __name__ == "__main__":
     directory = Path().absolute()
-    check_toml(f"{directory}/utils/.config.template.toml", "config.toml")
+    check_toml("config.toml")
\ No newline at end of file