use pydantic model for templating and validation

pull/2341/head
electro199 3 months ago
parent 1f410033ea
commit 35d2ff2559

@ -91,9 +91,8 @@ if __name__ == "__main__":
ffmpeg_install()
directory = Path().absolute()
config = settings.check_toml(
f"{directory}/utils/.config.template.toml", f"{directory}/config.toml"
f"{directory}/config.toml"
)
config is False and sys.exit()
if (
not settings.config["settings"]["tts"]["tiktok_sessionid"]

@ -0,0 +1,417 @@
from typing import Annotated, Literal, Optional
from pydantic import BaseModel, Field, StringConstraints
class RedditCreds(BaseModel):
    """Login details for a Reddit "script" application and its account."""

    # 12-30 characters drawn from letters, digits and "-._~+/", optionally
    # followed by trailing "=" characters.
    client_id: Annotated[
        str,
        StringConstraints(
            pattern=r"^[-a-zA-Z0-9._~+/]+=*$", min_length=12, max_length=30
        ),
    ] = Field(..., description="The ID of your Reddit app of SCRIPT type")

    # Same alphabet as the client ID, but 20-40 characters long.
    client_secret: Annotated[
        str,
        StringConstraints(
            pattern=r"^[-a-zA-Z0-9._~+/]+=*$", min_length=20, max_length=40
        ),
    ] = Field(..., description="The SECRET of your Reddit app of SCRIPT type")

    # 3-20 letters, digits, "-" or "_".
    username: Annotated[
        str,
        StringConstraints(pattern=r"^[-_0-9a-zA-Z]+$", min_length=3, max_length=20),
    ] = Field(..., description="The username of your Reddit account")

    # Only a minimum length of 8 is enforced.
    password: Annotated[str, StringConstraints(min_length=8)] = Field(
        ..., description="The password of your Reddit account"
    )

    # Optional flag, off unless set.
    twofa: Optional[bool] = Field(
        default=False, description="Whether Reddit 2FA is enabled"
    )
class RedditThread(BaseModel):
    """Criteria used to pick the Reddit thread and filter its comments."""

    random: Optional[bool] = Field(
        False, description="If true, picks a random thread instead of asking for URL"
    )

    # The pattern is anchored at BOTH ends. With pydantic's default regex
    # engine a pattern is searched for anywhere in the string, so without the
    # leading "^" any value that merely ends in a valid name (e.g. "!!abc")
    # would be accepted.
    subreddit: Annotated[
        str,
        StringConstraints(min_length=3, max_length=20, pattern=r"^[_0-9a-zA-Z\+]+$"),
    ] = Field(..., description="Name(s) of subreddit(s), '+' separated")

    # The string constraint is attached to the inner ``str`` so that it is
    # unambiguously applied to string values only and ``None`` stays valid.
    post_id: Optional[
        Annotated[str, StringConstraints(pattern=r"^[+a-zA-Z0-9]*$")]
    ] = Field("", description="Specify a Reddit post ID if desired")

    # Comment length window, in characters.
    max_comment_length: Annotated[int, Field(ge=10, le=10000)] = Field(
        500, description="Max number of characters per comment"
    )
    min_comment_length: Annotated[int, Field(ge=0, le=10000)] = Field(
        1, description="Min number of characters per comment"
    )

    post_lang: Optional[str] = Field(
        "", description="Target language code for translation (e.g., 'es-cr')"
    )

    # At least 10 is required; 20 is used when the value is omitted.
    min_comments: Annotated[int, Field(ge=10)] = Field(
        20, description="Minimum number of comments required"
    )
class RedditThreadExtras(BaseModel):
    """Extra thread filter that carries only the minimum comment count."""

    # Accepts 10..999999 and falls back to 20 when omitted.
    min_comments: Annotated[
        int,
        Field(
            description="The minimum number of comments a post should have to be included. Default is 20.",
            examples=[29],
            default=20,
            ge=10,
            le=999999,
        ),
    ]
class AIConfig(BaseModel):
    """Settings for keyword-similarity sorting of Reddit threads."""

    # Feature switch; disabled unless set.
    ai_similarity_enabled: Annotated[
        bool,
        Field(
            description="Threads read from Reddit are sorted based on their similarity to the keywords given below.",
            default=False,
        ),
    ]

    # Comma-separated keywords or sentences; empty by default.
    ai_similarity_keywords: Annotated[
        str,
        Field(
            description="Every keyword or sentence, separated by commas, is used to sort Reddit threads based on similarity.",
            examples=["Elon Musk, Twitter, Stocks"],
            default="",
        ),
    ]
class SettingsTTS(BaseModel):
    """Text-to-speech settings: provider selection plus per-provider options.

    Every field has a default, so an empty mapping validates.
    """

    # "OpenAI" is accepted so that the openai_* fields below can actually be
    # selected as the active provider.
    # NOTE(review): the exact spelling is assumed to match the provider key
    # used by the TTS dispatcher - confirm against the caller.
    voice_choice: Annotated[
        Literal[
            "elevenlabs",
            "streamlabspolly",
            "tiktok",
            "googletranslate",
            "awspolly",
            "pyttsx",
            "OpenAI",
        ],
        Field(
            default="tiktok",
            description="The voice platform used for TTS generation.",
            examples=["tiktok"],
        ),
    ]
    random_voice: Annotated[
        bool,
        Field(
            default=True,
            description="Randomizes the voice used for each comment.",
            examples=[True],
        ),
    ]

    # --- ElevenLabs ---
    elevenlabs_voice_name: Annotated[
        Literal[
            "Adam", "Antoni", "Arnold", "Bella", "Domi", "Elli", "Josh", "Rachel", "Sam"
        ],
        Field(
            default="Bella",
            description="The voice used for ElevenLabs.",
            examples=["Bella"],
        ),
    ]
    elevenlabs_api_key: Annotated[
        str,
        Field(
            default="",
            description="ElevenLabs API key.",
            examples=["21f13f91f54d741e2ae27d2ab1b99d59"],
        ),
    ]

    # --- AWS Polly / Streamlabs Polly ---
    aws_polly_voice: Annotated[
        str,
        Field(
            default="Matthew",
            description="The voice used for AWS Polly.",
            examples=["Matthew"],
        ),
    ]
    streamlabs_polly_voice: Annotated[
        str,
        Field(
            default="Matthew",
            description="The voice used for Streamlabs Polly.",
            examples=["Matthew"],
        ),
    ]

    # --- TikTok ---
    tiktok_voice: Annotated[
        str,
        Field(
            default="en_us_001",
            description="The voice used for TikTok TTS.",
            examples=["en_us_006"],
        ),
    ]
    tiktok_sessionid: Annotated[
        str,
        Field(
            default="",
            description="TikTok sessionid needed for TikTok TTS.",
            examples=["c76bcc3a7625abcc27b508c7db457ff1"],
        ),
    ]

    # --- System (pyttsx) voices; both values are kept as strings ---
    python_voice: Annotated[
        str,
        Field(
            default="1",
            description="The index of the system TTS voices (starts from 0).",
            examples=["1"],
        ),
    ]
    py_voice_num: Annotated[
        str,
        Field(
            default="2",
            description="The number of system voices available.",
            examples=["2"],
        ),
    ]

    # The example is a float so that it matches the declared field type.
    silence_duration: Annotated[
        float,
        Field(
            default=0.3,
            description="Time in seconds between TTS comments.",
            examples=[0.1],
        ),
    ]
    no_emojis: Annotated[
        bool,
        Field(
            default=False,
            description="Whether to remove emojis from the comments.",
            examples=[False],
        ),
    ]

    # --- OpenAI ---
    openai_api_url: Annotated[
        str,
        Field(
            default="https://api.openai.com/v1/",
            description="The API endpoint URL for OpenAI TTS generation.",
            examples=["https://api.openai.com/v1/"],
        ),
    ]
    openai_api_key: Annotated[
        str,
        Field(
            default="",
            description="Your OpenAI API key for TTS generation.",
            examples=["sk-abc123def456..."],
        ),
    ]
    openai_voice_name: Annotated[
        Literal[
            "alloy", "ash", "coral", "echo", "fable", "onyx", "nova", "sage", "shimmer"
        ],
        Field(
            default="alloy",
            description="The voice used for OpenAI TTS generation.",
            examples=["alloy"],
        ),
    ]
    openai_model: Annotated[
        Literal["tts-1", "tts-1-hd"],
        Field(
            default="tts-1",
            description="The model variant used for OpenAI TTS generation.",
            examples=["tts-1"],
        ),
    ]
class SettingsBackground(BaseModel):
    """Background video, audio and thumbnail options for the rendered clip.

    Each default is declared twice, inside ``Field`` and as the assigned
    value; both declarations carry the same value.
    """

    # Surrounding whitespace is stripped from the two name fields.
    background_video: Annotated[
        str,
        Field(
            description="Sets the background for the video based on game name",
            examples=["rocket-league"],
            default="minecraft",
        ),
        StringConstraints(strip_whitespace=True),
    ] = "minecraft"

    background_audio: Annotated[
        str,
        Field(
            description="Sets the background audio for the video",
            examples=["chill-summer"],
            default="lofi",
        ),
        StringConstraints(strip_whitespace=True),
    ] = "lofi"

    # Volume is bounded to the inclusive range 0..1.
    background_audio_volume: Annotated[
        float,
        Field(
            description="Sets the volume of the background audio. If you don't want background audio, set it to 0.",
            examples=[0.05],
            default=0.15,
            ge=0,
            le=1,
        ),
    ] = 0.15

    enable_extra_audio: Annotated[
        bool,
        Field(
            description="Used if you want to render another video without background audio in a separate folder",
            default=False,
        ),
    ] = False

    # --- Thumbnail options ---
    background_thumbnail: Annotated[
        bool,
        Field(
            description="Generate a thumbnail for the video (put a thumbnail.png file in the assets/backgrounds directory.)",
            default=False,
        ),
    ] = False

    background_thumbnail_font_family: Annotated[
        str,
        Field(
            description="Font family for the thumbnail text",
            examples=["arial"],
            default="arial",
        ),
    ] = "arial"

    background_thumbnail_font_size: Annotated[
        int,
        Field(
            description="Font size in pixels for the thumbnail text",
            examples=[96],
            default=96,
        ),
    ] = 96

    background_thumbnail_font_color: Annotated[
        str,
        Field(
            description="Font color in RGB format for the thumbnail text",
            examples=["255,255,255"],
            default="255,255,255",
        ),
    ] = "255,255,255"
class Settings(BaseModel):
    """General rendering settings plus the nested TTS and background sections."""

    allow_nsfw: Annotated[
        bool,
        Field(
            description="Whether to allow NSFW content. True or False.",
            examples=[False],
            default=False,
        ),
    ]

    theme: Annotated[
        Literal["dark", "light", "transparent"],
        Field(
            description="Sets the Reddit theme. For story mode, 'transparent' is also allowed.",
            examples=["light"],
            default="dark",
        ),
    ]

    # Must be at least 1.
    times_to_run: Annotated[
        int,
        Field(
            description="Used if you want to run multiple times. Must be an int >= 1.",
            examples=[2],
            default=1,
            ge=1,
        ),
    ]

    # Bounded to the inclusive range 0.0..1.0.
    opacity: Annotated[
        float,
        Field(
            description="Sets the opacity of comments when overlaid over the background.",
            examples=[0.8],
            default=0.9,
            ge=0.0,
            le=1.0,
        ),
    ]

    # --- Story mode ---
    storymode: Annotated[
        bool,
        Field(
            description="Only read out title and post content. Great for story-based subreddits.",
            examples=[False],
            default=False,
        ),
    ]

    storymodemethod: Annotated[
        Literal[0, 1],
        Field(
            description="Style used for story mode: 0 = static image, 1 = fancy video.",
            examples=[1],
            default=1,
        ),
    ]

    storymode_max_length: Annotated[
        int,
        Field(
            description="Max length (in characters) of the story mode video.",
            examples=[1000],
            default=1000,
            ge=1,
        ),
    ]

    # --- Output geometry (no bounds are enforced on width/height) ---
    resolution_w: Annotated[
        int,
        Field(
            description="Sets the width in pixels of the final video.",
            examples=[1440],
            default=1080,
        ),
    ]

    resolution_h: Annotated[
        int,
        Field(
            description="Sets the height in pixels of the final video.",
            examples=[2560],
            default=1920,
        ),
    ]

    # Bounded to the inclusive range 0.1..2.0.
    zoom: Annotated[
        float,
        Field(
            description="Sets the browser zoom level. Useful for making text larger.",
            examples=[1.1],
            default=1.0,
            ge=0.1,
            le=2.0,
        ),
    ]

    channel_name: Annotated[
        str,
        Field(
            description="Sets the channel name for the video.",
            examples=["Reddit Stories"],
            default="Reddit Tales",
        ),
    ]

    # Nested sections; neither declares a default, so both are required.
    tts: SettingsTTS
    background: SettingsBackground
class Reddit(BaseModel):
    """The ``reddit`` section: account credentials and thread selection."""

    # Both sub-sections are required (no defaults are declared).
    creds: RedditCreds
    thread: RedditThread
class Config(BaseModel):
    """Root configuration model grouping the three top-level sections."""

    # Every section is required (no defaults are declared).
    reddit: Reddit
    ai: AIConfig
    settings: Settings

@ -1,170 +1,143 @@
import re
import sys
from pathlib import Path
from typing import Dict, Tuple
from typing import Any, Dict
import toml
from rich.console import Console
from utils.console import handle_input
from utils.config_model import Config
from utils.console import print_substep
console = Console()
config = dict # autocomplete
config: dict # autocomplete
from typing import Any
from pydantic import ValidationError, BaseModel
from pydantic_core import PydanticUndefined
def prompt_recursive(obj: BaseModel):
"""
Recursively prompt for missing or invalid fields in a Pydantic model instance 'obj'.
"""
for field_name, field in obj.model_fields.items():
value = getattr(obj, field_name, None)
# If field is a nested BaseModel, recurse into it
if hasattr(field.annotation, "model_fields"):
nested_obj = value or field.annotation.model_construct()
fixed_nested = prompt_recursive(nested_obj)
setattr(obj, field_name, fixed_nested)
continue
def crawl(obj: dict, func=lambda x, y: print(x, y, end="\n"), path=None):
if path is None: # path Default argument value is mutable
path = []
for key in obj.keys():
if type(obj[key]) is dict:
crawl(obj[key], func, path + [key])
# If the value is valid and not None, skip prompt
if value not in [None, "", [], {}]:
continue
func(path + [key], obj[key])
def check(value, checks, name):
    """Validate ``value`` against ``checks``; prompt for a new one if it fails.

    Args:
        value: The current value of the setting. An empty dict counts as
            invalid.
        checks: Mapping of rules. Keys read here: ``type`` (name of a callable
            used to coerce the value), ``options``, ``regex``, ``nmin`` and
            ``nmax`` for validation; ``example``, ``optional``,
            ``explanation``, ``default``, ``input_error`` and ``oob_error``
            are forwarded to the prompt.
        name: Setting name displayed in the prompt.

    Returns:
        The (possibly type-coerced) value when every rule passes, otherwise
        whatever ``handle_input`` returns.
    """
    invalid = value == {}

    if not invalid and "type" in checks:
        try:
            # NOTE(security): eval() on the type name is kept for
            # compatibility; it is only safe while ``checks`` comes from a
            # trusted source.
            value = eval(checks["type"])(value)  # fixme remove eval
        except Exception:
            # Narrowed from a bare ``except`` so Ctrl-C / SystemExit propagate.
            invalid = True

    # FAILSTATE Value is not one of the options
    if not invalid and "options" in checks and value not in checks["options"]:
        invalid = True

    # FAILSTATE Value doesn't match regex, or has regex but is not a string.
    if not invalid and "regex" in checks:
        if not isinstance(value, str) or re.match(checks["regex"], value) is None:
            invalid = True

    if not invalid:
        nmin = checks.get("nmin")
        nmax = checks.get("nmax")
        if nmin is not None or nmax is not None:
            # Iterables are bounded by length, everything else by value.
            size = len(value) if hasattr(value, "__iter__") else value
            if (nmin is not None and size < nmin) or (
                nmax is not None and size > nmax
            ):
                invalid = True

    if not invalid:
        return value

    example = (
        "[blue]Example: " + str(checks["example"]) + "\n" if "example" in checks else ""
    )
    kind = "Optional " if checks.get("optional") is True else "Non-optional "
    return handle_input(
        message=example
        + "[red]"
        + kind
        + "[#C0CAF5 bold]"
        + str(name)
        + "[#F7768E bold]=",
        extra_info=checks.get("explanation", ""),
        check_type=eval(checks.get("type", "False")),  # fixme remove eval
        default=checks.get("default", NotImplemented),
        match=checks.get("regex", ""),
        err_message=checks.get("input_error", "Incorrect input"),
        nmin=checks.get("nmin", None),
        nmax=checks.get("nmax", None),
        oob_error=checks.get(
            "oob_error", "Input out of bounds(Value too high/low/long/short)"
        ),
        options=checks.get("options", None),
        optional=checks.get("optional", False),
    )
description = field.description or ""
default_str = (
f" (default: {field.default})"
if (field.default is not None) or field.default == PydanticUndefined
else ""
)
prompt_msg = f"🧩 {field_name}\n 📘 {description}{default_str}\n ⚠️ Required: {field.is_required()}\n ❓ Enter value: "
while True:
user_input = input(prompt_msg).strip()
if not user_input:
if field.default is not None:
value_to_set = field.default
elif not field.is_required():
value_to_set = None
else:
print("⚠️ This field is required.")
continue
else:
# Convert input based on type, you can expand this logic
try:
value_to_set = parse_value(user_input, field.annotation)
except Exception as e:
print(f"⚠️ Invalid input: {e}")
continue
# Validate the assignment
try:
obj.__pydantic_validator__.validate_assignment(
obj, field_name, value_to_set
)
setattr(obj, field_name, value_to_set)
break
except ValidationError as ve:
for err in ve.errors():
print(f"{err['loc'][0]}: {err['msg']}")
def crawl_and_check(obj: dict, path: list, checks: "dict | None" = None, name=""):
    """Follow ``path`` through nested dict ``obj`` and validate the final value.

    Missing intermediate keys are created as empty dicts on the way down.

    Args:
        obj: The dict (or, at the end of the path, the value) being examined.
        path: Remaining keys to descend through; empty means ``obj`` itself
            is the value to validate.
        checks: Rules forwarded to ``check``; defaults to no rules.
        name: Key under which ``obj`` was found, forwarded to ``check``.

    Returns:
        The result of ``check`` when ``path`` is empty, otherwise ``obj`` with
        the entry at ``path[0]`` replaced by the recursively checked value.
    """
    # A fresh dict per call replaces the shared mutable default argument.
    if checks is None:
        checks = {}
    if not path:
        return check(obj, checks, name)
    key = path[0]
    if key not in obj:
        obj[key] = {}
    obj[key] = crawl_and_check(obj[key], path[1:], checks, key)
    return obj
def check_vars(path, checks):
    """Run ``crawl_and_check`` on the module-level ``config`` for ``path``."""
    # ``config`` is only read here, so no ``global`` declaration is needed.
    crawl_and_check(config, path=path, checks=checks)
def parse_value(raw: str, expected_type: type):
from typing import get_args, get_origin
origin = get_origin(expected_type)
args = get_args(expected_type)
def check_toml(template_file, config_file) -> Tuple[bool, Dict]:
global config
config = None
try:
template = toml.load(template_file)
except Exception as error:
console.print(f"[red bold]Encountered error when trying to to load {template_file}: {error}")
return False
try:
config = toml.load(config_file)
except toml.TomlDecodeError:
console.print(
f"""[blue]Couldn't read {config_file}.
Overwrite it?(y/n)"""
)
if not input().startswith("y"):
print("Unable to read config, and not allowed to overwrite it. Giving up.")
if expected_type == bool:
if raw.lower() in ("true", "yes", "1"):
return True
elif raw.lower() in ("false", "no", "0"):
return False
else:
try:
with open(config_file, "w") as f:
f.write("")
except:
console.print(
f"[red bold]Failed to overwrite {config_file}. Giving up.\nSuggestion: check {config_file} permissions for the user."
)
return False
except FileNotFoundError:
console.print(
f"""[blue]Couldn't find {config_file}
Creating it now."""
)
try:
with open(config_file, "x") as f:
f.write("")
config = {}
except:
console.print(
f"[red bold]Failed to write to {config_file}. Giving up.\nSuggestion: check the folder's permissions for the user."
)
return False
raise ValueError("Expected boolean value (true/false)")
elif expected_type == int:
return int(raw)
elif expected_type == float:
return float(raw)
elif expected_type == str:
return raw
elif origin == list and args:
return [parse_value(x.strip(), args[0]) for x in raw.split(",")]
else:
raise ValueError(f"Unsupported field type: {expected_type}")
def check_toml(config_file: str) -> Dict[str, Any]:
"""
Load the template and config TOML files.
Validate config with Pydantic.
If invalid, prompt for missing or invalid fields.
Save fixed config back.
Return the valid Config model.
"""
try:
config_dict = toml.load(config_file)
except Exception as e:
print(f"Failed to load config {config_file}: {e}")
config_dict = {}
console.print(
"""\
[blue bold]###############################
# #
# Checking TOML configuration #
# #
###############################
If you see any prompts, that means that you have unset/incorrectly set variables, please input the correct values.\
"""
)
crawl(template, check_vars)
with open(config_file, "w") as f:
toml.dump(config, f)
try:
config_instance = Config.model_validate(config_dict)
except ValidationError as e:
print("Config validation failed, will prompt for missing/invalid fields:")
print(e)
# Start from a clean model
config_instance = Config.model_construct()
# Update model with any valid partial data loaded from config
for k, v in config_dict.items():
if hasattr(config_instance, k):
setattr(config_instance, k, v)
# Prompt for missing or invalid fields recursively
config_instance = prompt_recursive(config_instance)
# Validate again to be sure
config_instance = Config.model_validate(config_instance.model_dump())
# Save fixed config back to file
with open(config_file, "w", encoding="utf-8") as f:
toml.dump(config_instance.model_dump(), f)
print(f"Updated config saved to {config_file}")
config = config_instance.model_dump()
return config
if __name__ == "__main__":
directory = Path().absolute()
check_toml(f"{directory}/utils/.config.template.toml", "config.toml")
check_toml("config.toml")
Loading…
Cancel
Save