use pydanctic model for templating and validation

pull/2341/head
electro199 3 months ago
parent 1f410033ea
commit 35d2ff2559

@ -91,9 +91,8 @@ if __name__ == "__main__":
ffmpeg_install() ffmpeg_install()
directory = Path().absolute() directory = Path().absolute()
config = settings.check_toml( config = settings.check_toml(
f"{directory}/utils/.config.template.toml", f"{directory}/config.toml" f"{directory}/config.toml"
) )
config is False and sys.exit()
if ( if (
not settings.config["settings"]["tts"]["tiktok_sessionid"] not settings.config["settings"]["tts"]["tiktok_sessionid"]

@ -0,0 +1,417 @@
from typing import Annotated, Literal, Optional
from pydantic import BaseModel, Field, StringConstraints
class RedditCreds(BaseModel):
client_id: Annotated[
str,
StringConstraints(
min_length=12, max_length=30, pattern=r"^[-a-zA-Z0-9._~+/]+=*$"
),
] = Field(..., description="The ID of your Reddit app of SCRIPT type")
client_secret: Annotated[
str,
StringConstraints(
min_length=20, max_length=40, pattern=r"^[-a-zA-Z0-9._~+/]+=*$"
),
] = Field(..., description="The SECRET of your Reddit app of SCRIPT type")
username: Annotated[
str, StringConstraints(min_length=3, max_length=20, pattern=r"^[-_0-9a-zA-Z]+$")
] = Field(..., description="The username of your Reddit account")
password: Annotated[str, StringConstraints(min_length=8)] = Field(
..., description="The password of your Reddit account"
)
twofa: Optional[bool] = Field(False, description="Whether Reddit 2FA is enabled")
class RedditThread(BaseModel):
random: Optional[bool] = Field(
False, description="If true, picks a random thread instead of asking for URL"
)
subreddit: Annotated[
str, StringConstraints(min_length=3, max_length=20, pattern=r"[_0-9a-zA-Z\+]+$")
] = Field(..., description="Name(s) of subreddit(s), '+' separated")
post_id: Annotated[Optional[str], StringConstraints(pattern=r"^[+a-zA-Z0-9]*$")] = (
Field("", description="Specify a Reddit post ID if desired")
)
max_comment_length: Annotated[int, Field(ge=10, le=10000)] = Field(
500, description="Max number of characters per comment"
)
min_comment_length: Annotated[int, Field(ge=0, le=10000)] = Field(
1, description="Min number of characters per comment"
)
post_lang: Optional[str] = Field(
"", description="Target language code for translation (e.g., 'es-cr')"
)
min_comments: Annotated[int, Field(ge=10)] = Field(
20, description="Minimum number of comments required"
)
class RedditThreadExtras(BaseModel):
min_comments: Annotated[
int,
Field(
default=20,
ge=10,
le=999999,
description="The minimum number of comments a post should have to be included. Default is 20.",
examples=[29],
),
]
class AIConfig(BaseModel):
ai_similarity_enabled: Annotated[
bool,
Field(
default=False,
description="Threads read from Reddit are sorted based on their similarity to the keywords given below.",
),
]
ai_similarity_keywords: Annotated[
str,
Field(
default="",
description="Every keyword or sentence, separated by commas, is used to sort Reddit threads based on similarity.",
examples=["Elon Musk, Twitter, Stocks"],
),
]
class SettingsTTS(BaseModel):
voice_choice: Annotated[
Literal[
"elevenlabs",
"streamlabspolly",
"tiktok",
"googletranslate",
"awspolly",
"pyttsx",
],
Field(
default="tiktok",
description="The voice platform used for TTS generation.",
examples=["tiktok"],
),
]
random_voice: Annotated[
bool,
Field(
default=True,
description="Randomizes the voice used for each comment.",
examples=[True],
),
]
elevenlabs_voice_name: Annotated[
Literal[
"Adam", "Antoni", "Arnold", "Bella", "Domi", "Elli", "Josh", "Rachel", "Sam"
],
Field(
default="Bella",
description="The voice used for ElevenLabs.",
examples=["Bella"],
),
]
elevenlabs_api_key: Annotated[
str,
Field(
default="",
description="ElevenLabs API key.",
examples=["21f13f91f54d741e2ae27d2ab1b99d59"],
),
]
aws_polly_voice: Annotated[
str,
Field(
default="Matthew",
description="The voice used for AWS Polly.",
examples=["Matthew"],
),
]
streamlabs_polly_voice: Annotated[
str,
Field(
default="Matthew",
description="The voice used for Streamlabs Polly.",
examples=["Matthew"],
),
]
tiktok_voice: Annotated[
str,
Field(
default="en_us_001",
description="The voice used for TikTok TTS.",
examples=["en_us_006"],
),
]
tiktok_sessionid: Annotated[
str,
Field(
default="",
description="TikTok sessionid needed for TikTok TTS.",
examples=["c76bcc3a7625abcc27b508c7db457ff1"],
),
]
python_voice: Annotated[
str,
Field(
default="1",
description="The index of the system TTS voices (starts from 0).",
examples=["1"],
),
]
py_voice_num: Annotated[
str,
Field(
default="2",
description="The number of system voices available.",
examples=["2"],
),
]
silence_duration: Annotated[
float,
Field(
default=0.3,
description="Time in seconds between TTS comments.",
examples=["0.1"],
),
]
no_emojis: Annotated[
bool,
Field(
default=False,
description="Whether to remove emojis from the comments.",
examples=[False],
),
]
openai_api_url: Annotated[
str,
Field(
default="https://api.openai.com/v1/",
description="The API endpoint URL for OpenAI TTS generation.",
examples=["https://api.openai.com/v1/"],
),
]
openai_api_key: Annotated[
str,
Field(
default="",
description="Your OpenAI API key for TTS generation.",
examples=["sk-abc123def456..."],
),
]
openai_voice_name: Annotated[
Literal[
"alloy", "ash", "coral", "echo", "fable", "onyx", "nova", "sage", "shimmer"
],
Field(
default="alloy",
description="The voice used for OpenAI TTS generation.",
examples=["alloy"],
),
]
openai_model: Annotated[
Literal["tts-1", "tts-1-hd"],
Field(
default="tts-1",
description="The model variant used for OpenAI TTS generation.",
examples=["tts-1"],
),
]
class SettingsBackground(BaseModel):
background_video: Annotated[
str,
Field(
default="minecraft",
description="Sets the background for the video based on game name",
examples=["rocket-league"],
),
StringConstraints(strip_whitespace=True),
] = "minecraft"
background_audio: Annotated[
str,
Field(
default="lofi",
description="Sets the background audio for the video",
examples=["chill-summer"],
),
StringConstraints(strip_whitespace=True),
] = "lofi"
background_audio_volume: Annotated[
float,
Field(
default=0.15,
ge=0,
le=1,
description="Sets the volume of the background audio. If you don't want background audio, set it to 0.",
examples=[0.05],
),
] = 0.15
enable_extra_audio: Annotated[
bool,
Field(
default=False,
description="Used if you want to render another video without background audio in a separate folder",
),
] = False
background_thumbnail: Annotated[
bool,
Field(
default=False,
description="Generate a thumbnail for the video (put a thumbnail.png file in the assets/backgrounds directory.)",
),
] = False
background_thumbnail_font_family: Annotated[
str,
Field(
default="arial",
description="Font family for the thumbnail text",
examples=["arial"],
),
] = "arial"
background_thumbnail_font_size: Annotated[
int,
Field(
default=96,
description="Font size in pixels for the thumbnail text",
examples=[96],
),
] = 96
background_thumbnail_font_color: Annotated[
str,
Field(
default="255,255,255",
description="Font color in RGB format for the thumbnail text",
examples=["255,255,255"],
),
] = "255,255,255"
class Settings(BaseModel):
allow_nsfw: Annotated[
bool,
Field(
default=False,
description="Whether to allow NSFW content. True or False.",
examples=[False],
),
]
theme: Annotated[
Literal["dark", "light", "transparent"],
Field(
default="dark",
description="Sets the Reddit theme. For story mode, 'transparent' is also allowed.",
examples=["light"],
),
]
times_to_run: Annotated[
int,
Field(
default=1,
ge=1,
description="Used if you want to run multiple times. Must be an int >= 1.",
examples=[2],
),
]
opacity: Annotated[
float,
Field(
default=0.9,
ge=0.0,
le=1.0,
description="Sets the opacity of comments when overlaid over the background.",
examples=[0.8],
),
]
storymode: Annotated[
bool,
Field(
default=False,
description="Only read out title and post content. Great for story-based subreddits.",
examples=[False],
),
]
storymodemethod: Annotated[
Literal[0, 1],
Field(
default=1,
description="Style used for story mode: 0 = static image, 1 = fancy video.",
examples=[1],
),
]
storymode_max_length: Annotated[
int,
Field(
default=1000,
ge=1,
description="Max length (in characters) of the story mode video.",
examples=[1000],
),
]
resolution_w: Annotated[
int,
Field(
default=1080,
description="Sets the width in pixels of the final video.",
examples=[1440],
),
]
resolution_h: Annotated[
int,
Field(
default=1920,
description="Sets the height in pixels of the final video.",
examples=[2560],
),
]
zoom: Annotated[
float,
Field(
default=1.0,
ge=0.1,
le=2.0,
description="Sets the browser zoom level. Useful for making text larger.",
examples=[1.1],
),
]
channel_name: Annotated[
str,
Field(
default="Reddit Tales",
description="Sets the channel name for the video.",
examples=["Reddit Stories"],
),
]
tts: SettingsTTS
background: SettingsBackground
class Reddit(BaseModel):
creds: RedditCreds
thread: RedditThread
class Config(BaseModel):
reddit: Reddit
ai: AIConfig
settings: Settings

@ -1,170 +1,143 @@
import re import sys
from pathlib import Path from pathlib import Path
from typing import Dict, Tuple from typing import Any, Dict
import toml import toml
from rich.console import Console from rich.console import Console
from utils.console import handle_input from utils.config_model import Config
from utils.console import print_substep
console = Console() console = Console()
config = dict # autocomplete config: dict # autocomplete
from typing import Any
from pydantic import ValidationError, BaseModel
from pydantic_core import PydanticUndefined
def prompt_recursive(obj: BaseModel):
"""
Recursively prompt for missing or invalid fields in a Pydantic model instance 'obj'.
"""
for field_name, field in obj.model_fields.items():
value = getattr(obj, field_name, None)
# If field is a nested BaseModel, recurse into it
if hasattr(field.annotation, "model_fields"):
nested_obj = value or field.annotation.model_construct()
fixed_nested = prompt_recursive(nested_obj)
setattr(obj, field_name, fixed_nested)
continue
def crawl(obj: dict, func=lambda x, y: print(x, y, end="\n"), path=None): # If the value is valid and not None, skip prompt
if path is None: # path Default argument value is mutable if value not in [None, "", [], {}]:
path = []
for key in obj.keys():
if type(obj[key]) is dict:
crawl(obj[key], func, path + [key])
continue continue
func(path + [key], obj[key])
def check(value, checks, name):
def get_check_value(key, default_result):
return checks[key] if key in checks else default_result
incorrect = False
if value == {}:
incorrect = True
if not incorrect and "type" in checks:
try:
value = eval(checks["type"])(value) # fixme remove eval
except:
incorrect = True
if (
not incorrect and "options" in checks and value not in checks["options"]
): # FAILSTATE Value is not one of the options
incorrect = True
if (
not incorrect
and "regex" in checks
and (
(isinstance(value, str) and re.match(checks["regex"], value) is None)
or not isinstance(value, str)
)
): # FAILSTATE Value doesn't match regex, or has regex but is not a string.
incorrect = True
if (
not incorrect
and not hasattr(value, "__iter__")
and (
("nmin" in checks and checks["nmin"] is not None and value < checks["nmin"])
or ("nmax" in checks and checks["nmax"] is not None and value > checks["nmax"])
)
):
incorrect = True
if (
not incorrect
and hasattr(value, "__iter__")
and (
("nmin" in checks and checks["nmin"] is not None and len(value) < checks["nmin"])
or ("nmax" in checks and checks["nmax"] is not None and len(value) > checks["nmax"])
)
):
incorrect = True
if incorrect:
value = handle_input(
message=(
(("[blue]Example: " + str(checks["example"]) + "\n") if "example" in checks else "")
+ "[red]"
+ ("Non-optional ", "Optional ")["optional" in checks and checks["optional"] is True]
)
+ "[#C0CAF5 bold]"
+ str(name)
+ "[#F7768E bold]=",
extra_info=get_check_value("explanation", ""),
check_type=eval(get_check_value("type", "False")), # fixme remove eval
default=get_check_value("default", NotImplemented),
match=get_check_value("regex", ""),
err_message=get_check_value("input_error", "Incorrect input"),
nmin=get_check_value("nmin", None),
nmax=get_check_value("nmax", None),
oob_error=get_check_value(
"oob_error", "Input out of bounds(Value too high/low/long/short)"
),
options=get_check_value("options", None),
optional=get_check_value("optional", False),
)
return value
description = field.description or ""
default_str = (
f" (default: {field.default})"
if (field.default is not None) or field.default == PydanticUndefined
else ""
)
prompt_msg = f"🧩 {field_name}\n 📘 {description}{default_str}\n ⚠️ Required: {field.is_required()}\n ❓ Enter value: "
while True:
user_input = input(prompt_msg).strip()
if not user_input:
if field.default is not None:
value_to_set = field.default
elif not field.is_required():
value_to_set = None
else:
print("⚠️ This field is required.")
continue
else:
# Convert input based on type, you can expand this logic
try:
value_to_set = parse_value(user_input, field.annotation)
except Exception as e:
print(f"⚠️ Invalid input: {e}")
continue
# Validate the assignment
try:
obj.__pydantic_validator__.validate_assignment(
obj, field_name, value_to_set
)
setattr(obj, field_name, value_to_set)
break
except ValidationError as ve:
for err in ve.errors():
print(f"{err['loc'][0]}: {err['msg']}")
def crawl_and_check(obj: dict, path: list, checks: dict = {}, name=""):
if len(path) == 0:
return check(obj, checks, name)
if path[0] not in obj.keys():
obj[path[0]] = {}
obj[path[0]] = crawl_and_check(obj[path[0]], path[1:], checks, path[0])
return obj return obj
def check_vars(path, checks): def parse_value(raw: str, expected_type: type):
global config from typing import get_args, get_origin
crawl_and_check(config, path, checks)
origin = get_origin(expected_type)
args = get_args(expected_type)
def check_toml(template_file, config_file) -> Tuple[bool, Dict]: if expected_type == bool:
global config if raw.lower() in ("true", "yes", "1"):
config = None return True
try: elif raw.lower() in ("false", "no", "0"):
template = toml.load(template_file)
except Exception as error:
console.print(f"[red bold]Encountered error when trying to to load {template_file}: {error}")
return False
try:
config = toml.load(config_file)
except toml.TomlDecodeError:
console.print(
f"""[blue]Couldn't read {config_file}.
Overwrite it?(y/n)"""
)
if not input().startswith("y"):
print("Unable to read config, and not allowed to overwrite it. Giving up.")
return False return False
else: else:
try: raise ValueError("Expected boolean value (true/false)")
with open(config_file, "w") as f: elif expected_type == int:
f.write("") return int(raw)
except: elif expected_type == float:
console.print( return float(raw)
f"[red bold]Failed to overwrite {config_file}. Giving up.\nSuggestion: check {config_file} permissions for the user." elif expected_type == str:
) return raw
return False elif origin == list and args:
except FileNotFoundError: return [parse_value(x.strip(), args[0]) for x in raw.split(",")]
console.print( else:
f"""[blue]Couldn't find {config_file} raise ValueError(f"Unsupported field type: {expected_type}")
Creating it now."""
)
try: def check_toml(config_file: str) -> Dict[str, Any]:
with open(config_file, "x") as f: """
f.write("") Load the template and config TOML files.
config = {} Validate config with Pydantic.
except: If invalid, prompt for missing or invalid fields.
console.print( Save fixed config back.
f"[red bold]Failed to write to {config_file}. Giving up.\nSuggestion: check the folder's permissions for the user." Return the valid Config model.
) """
return False try:
config_dict = toml.load(config_file)
except Exception as e:
print(f"Failed to load config {config_file}: {e}")
config_dict = {}
console.print( try:
"""\ config_instance = Config.model_validate(config_dict)
[blue bold]############################### except ValidationError as e:
# # print("Config validation failed, will prompt for missing/invalid fields:")
# Checking TOML configuration # print(e)
# # # Start from a clean model
############################### config_instance = Config.model_construct()
If you see any prompts, that means that you have unset/incorrectly set variables, please input the correct values.\ # Update model with any valid partial data loaded from config
""" for k, v in config_dict.items():
) if hasattr(config_instance, k):
crawl(template, check_vars) setattr(config_instance, k, v)
with open(config_file, "w") as f:
toml.dump(config, f) # Prompt for missing or invalid fields recursively
config_instance = prompt_recursive(config_instance)
# Validate again to be sure
config_instance = Config.model_validate(config_instance.model_dump())
# Save fixed config back to file
with open(config_file, "w", encoding="utf-8") as f:
toml.dump(config_instance.model_dump(), f)
print(f"Updated config saved to {config_file}")
config = config_instance.model_dump()
return config return config
if __name__ == "__main__": if __name__ == "__main__":
directory = Path().absolute() directory = Path().absolute()
check_toml(f"{directory}/utils/.config.template.toml", "config.toml") check_toml("config.toml")
Loading…
Cancel
Save