parent
1f410033ea
commit
35d2ff2559
@ -0,0 +1,417 @@
|
|||||||
|
from typing import Annotated, Literal, Optional
|
||||||
|
from pydantic import BaseModel, Field, StringConstraints
|
||||||
|
|
||||||
|
|
||||||
|
class RedditCreds(BaseModel):
|
||||||
|
client_id: Annotated[
|
||||||
|
str,
|
||||||
|
StringConstraints(
|
||||||
|
min_length=12, max_length=30, pattern=r"^[-a-zA-Z0-9._~+/]+=*$"
|
||||||
|
),
|
||||||
|
] = Field(..., description="The ID of your Reddit app of SCRIPT type")
|
||||||
|
|
||||||
|
client_secret: Annotated[
|
||||||
|
str,
|
||||||
|
StringConstraints(
|
||||||
|
min_length=20, max_length=40, pattern=r"^[-a-zA-Z0-9._~+/]+=*$"
|
||||||
|
),
|
||||||
|
] = Field(..., description="The SECRET of your Reddit app of SCRIPT type")
|
||||||
|
|
||||||
|
username: Annotated[
|
||||||
|
str, StringConstraints(min_length=3, max_length=20, pattern=r"^[-_0-9a-zA-Z]+$")
|
||||||
|
] = Field(..., description="The username of your Reddit account")
|
||||||
|
|
||||||
|
password: Annotated[str, StringConstraints(min_length=8)] = Field(
|
||||||
|
..., description="The password of your Reddit account"
|
||||||
|
)
|
||||||
|
|
||||||
|
twofa: Optional[bool] = Field(False, description="Whether Reddit 2FA is enabled")
|
||||||
|
|
||||||
|
|
||||||
|
class RedditThread(BaseModel):
|
||||||
|
random: Optional[bool] = Field(
|
||||||
|
False, description="If true, picks a random thread instead of asking for URL"
|
||||||
|
)
|
||||||
|
|
||||||
|
subreddit: Annotated[
|
||||||
|
str, StringConstraints(min_length=3, max_length=20, pattern=r"[_0-9a-zA-Z\+]+$")
|
||||||
|
] = Field(..., description="Name(s) of subreddit(s), '+' separated")
|
||||||
|
|
||||||
|
post_id: Annotated[Optional[str], StringConstraints(pattern=r"^[+a-zA-Z0-9]*$")] = (
|
||||||
|
Field("", description="Specify a Reddit post ID if desired")
|
||||||
|
)
|
||||||
|
|
||||||
|
max_comment_length: Annotated[int, Field(ge=10, le=10000)] = Field(
|
||||||
|
500, description="Max number of characters per comment"
|
||||||
|
)
|
||||||
|
|
||||||
|
min_comment_length: Annotated[int, Field(ge=0, le=10000)] = Field(
|
||||||
|
1, description="Min number of characters per comment"
|
||||||
|
)
|
||||||
|
|
||||||
|
post_lang: Optional[str] = Field(
|
||||||
|
"", description="Target language code for translation (e.g., 'es-cr')"
|
||||||
|
)
|
||||||
|
|
||||||
|
min_comments: Annotated[int, Field(ge=10)] = Field(
|
||||||
|
20, description="Minimum number of comments required"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RedditThreadExtras(BaseModel):
|
||||||
|
min_comments: Annotated[
|
||||||
|
int,
|
||||||
|
Field(
|
||||||
|
default=20,
|
||||||
|
ge=10,
|
||||||
|
le=999999,
|
||||||
|
description="The minimum number of comments a post should have to be included. Default is 20.",
|
||||||
|
examples=[29],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class AIConfig(BaseModel):
|
||||||
|
ai_similarity_enabled: Annotated[
|
||||||
|
bool,
|
||||||
|
Field(
|
||||||
|
default=False,
|
||||||
|
description="Threads read from Reddit are sorted based on their similarity to the keywords given below.",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
ai_similarity_keywords: Annotated[
|
||||||
|
str,
|
||||||
|
Field(
|
||||||
|
default="",
|
||||||
|
description="Every keyword or sentence, separated by commas, is used to sort Reddit threads based on similarity.",
|
||||||
|
examples=["Elon Musk, Twitter, Stocks"],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class SettingsTTS(BaseModel):
|
||||||
|
voice_choice: Annotated[
|
||||||
|
Literal[
|
||||||
|
"elevenlabs",
|
||||||
|
"streamlabspolly",
|
||||||
|
"tiktok",
|
||||||
|
"googletranslate",
|
||||||
|
"awspolly",
|
||||||
|
"pyttsx",
|
||||||
|
],
|
||||||
|
Field(
|
||||||
|
default="tiktok",
|
||||||
|
description="The voice platform used for TTS generation.",
|
||||||
|
examples=["tiktok"],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
random_voice: Annotated[
|
||||||
|
bool,
|
||||||
|
Field(
|
||||||
|
default=True,
|
||||||
|
description="Randomizes the voice used for each comment.",
|
||||||
|
examples=[True],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
elevenlabs_voice_name: Annotated[
|
||||||
|
Literal[
|
||||||
|
"Adam", "Antoni", "Arnold", "Bella", "Domi", "Elli", "Josh", "Rachel", "Sam"
|
||||||
|
],
|
||||||
|
Field(
|
||||||
|
default="Bella",
|
||||||
|
description="The voice used for ElevenLabs.",
|
||||||
|
examples=["Bella"],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
elevenlabs_api_key: Annotated[
|
||||||
|
str,
|
||||||
|
Field(
|
||||||
|
default="",
|
||||||
|
description="ElevenLabs API key.",
|
||||||
|
examples=["21f13f91f54d741e2ae27d2ab1b99d59"],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
aws_polly_voice: Annotated[
|
||||||
|
str,
|
||||||
|
Field(
|
||||||
|
default="Matthew",
|
||||||
|
description="The voice used for AWS Polly.",
|
||||||
|
examples=["Matthew"],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
streamlabs_polly_voice: Annotated[
|
||||||
|
str,
|
||||||
|
Field(
|
||||||
|
default="Matthew",
|
||||||
|
description="The voice used for Streamlabs Polly.",
|
||||||
|
examples=["Matthew"],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
tiktok_voice: Annotated[
|
||||||
|
str,
|
||||||
|
Field(
|
||||||
|
default="en_us_001",
|
||||||
|
description="The voice used for TikTok TTS.",
|
||||||
|
examples=["en_us_006"],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
tiktok_sessionid: Annotated[
|
||||||
|
str,
|
||||||
|
Field(
|
||||||
|
default="",
|
||||||
|
description="TikTok sessionid needed for TikTok TTS.",
|
||||||
|
examples=["c76bcc3a7625abcc27b508c7db457ff1"],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
python_voice: Annotated[
|
||||||
|
str,
|
||||||
|
Field(
|
||||||
|
default="1",
|
||||||
|
description="The index of the system TTS voices (starts from 0).",
|
||||||
|
examples=["1"],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
py_voice_num: Annotated[
|
||||||
|
str,
|
||||||
|
Field(
|
||||||
|
default="2",
|
||||||
|
description="The number of system voices available.",
|
||||||
|
examples=["2"],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
silence_duration: Annotated[
|
||||||
|
float,
|
||||||
|
Field(
|
||||||
|
default=0.3,
|
||||||
|
description="Time in seconds between TTS comments.",
|
||||||
|
examples=["0.1"],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
no_emojis: Annotated[
|
||||||
|
bool,
|
||||||
|
Field(
|
||||||
|
default=False,
|
||||||
|
description="Whether to remove emojis from the comments.",
|
||||||
|
examples=[False],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
openai_api_url: Annotated[
|
||||||
|
str,
|
||||||
|
Field(
|
||||||
|
default="https://api.openai.com/v1/",
|
||||||
|
description="The API endpoint URL for OpenAI TTS generation.",
|
||||||
|
examples=["https://api.openai.com/v1/"],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
openai_api_key: Annotated[
|
||||||
|
str,
|
||||||
|
Field(
|
||||||
|
default="",
|
||||||
|
description="Your OpenAI API key for TTS generation.",
|
||||||
|
examples=["sk-abc123def456..."],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
openai_voice_name: Annotated[
|
||||||
|
Literal[
|
||||||
|
"alloy", "ash", "coral", "echo", "fable", "onyx", "nova", "sage", "shimmer"
|
||||||
|
],
|
||||||
|
Field(
|
||||||
|
default="alloy",
|
||||||
|
description="The voice used for OpenAI TTS generation.",
|
||||||
|
examples=["alloy"],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
openai_model: Annotated[
|
||||||
|
Literal["tts-1", "tts-1-hd"],
|
||||||
|
Field(
|
||||||
|
default="tts-1",
|
||||||
|
description="The model variant used for OpenAI TTS generation.",
|
||||||
|
examples=["tts-1"],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class SettingsBackground(BaseModel):
|
||||||
|
background_video: Annotated[
|
||||||
|
str,
|
||||||
|
Field(
|
||||||
|
default="minecraft",
|
||||||
|
description="Sets the background for the video based on game name",
|
||||||
|
examples=["rocket-league"],
|
||||||
|
),
|
||||||
|
StringConstraints(strip_whitespace=True),
|
||||||
|
] = "minecraft"
|
||||||
|
|
||||||
|
background_audio: Annotated[
|
||||||
|
str,
|
||||||
|
Field(
|
||||||
|
default="lofi",
|
||||||
|
description="Sets the background audio for the video",
|
||||||
|
examples=["chill-summer"],
|
||||||
|
),
|
||||||
|
StringConstraints(strip_whitespace=True),
|
||||||
|
] = "lofi"
|
||||||
|
|
||||||
|
background_audio_volume: Annotated[
|
||||||
|
float,
|
||||||
|
Field(
|
||||||
|
default=0.15,
|
||||||
|
ge=0,
|
||||||
|
le=1,
|
||||||
|
description="Sets the volume of the background audio. If you don't want background audio, set it to 0.",
|
||||||
|
examples=[0.05],
|
||||||
|
),
|
||||||
|
] = 0.15
|
||||||
|
|
||||||
|
enable_extra_audio: Annotated[
|
||||||
|
bool,
|
||||||
|
Field(
|
||||||
|
default=False,
|
||||||
|
description="Used if you want to render another video without background audio in a separate folder",
|
||||||
|
),
|
||||||
|
] = False
|
||||||
|
|
||||||
|
background_thumbnail: Annotated[
|
||||||
|
bool,
|
||||||
|
Field(
|
||||||
|
default=False,
|
||||||
|
description="Generate a thumbnail for the video (put a thumbnail.png file in the assets/backgrounds directory.)",
|
||||||
|
),
|
||||||
|
] = False
|
||||||
|
|
||||||
|
background_thumbnail_font_family: Annotated[
|
||||||
|
str,
|
||||||
|
Field(
|
||||||
|
default="arial",
|
||||||
|
description="Font family for the thumbnail text",
|
||||||
|
examples=["arial"],
|
||||||
|
),
|
||||||
|
] = "arial"
|
||||||
|
|
||||||
|
background_thumbnail_font_size: Annotated[
|
||||||
|
int,
|
||||||
|
Field(
|
||||||
|
default=96,
|
||||||
|
description="Font size in pixels for the thumbnail text",
|
||||||
|
examples=[96],
|
||||||
|
),
|
||||||
|
] = 96
|
||||||
|
|
||||||
|
background_thumbnail_font_color: Annotated[
|
||||||
|
str,
|
||||||
|
Field(
|
||||||
|
default="255,255,255",
|
||||||
|
description="Font color in RGB format for the thumbnail text",
|
||||||
|
examples=["255,255,255"],
|
||||||
|
),
|
||||||
|
] = "255,255,255"
|
||||||
|
|
||||||
|
|
||||||
|
class Settings(BaseModel):
|
||||||
|
allow_nsfw: Annotated[
|
||||||
|
bool,
|
||||||
|
Field(
|
||||||
|
default=False,
|
||||||
|
description="Whether to allow NSFW content. True or False.",
|
||||||
|
examples=[False],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
theme: Annotated[
|
||||||
|
Literal["dark", "light", "transparent"],
|
||||||
|
Field(
|
||||||
|
default="dark",
|
||||||
|
description="Sets the Reddit theme. For story mode, 'transparent' is also allowed.",
|
||||||
|
examples=["light"],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
times_to_run: Annotated[
|
||||||
|
int,
|
||||||
|
Field(
|
||||||
|
default=1,
|
||||||
|
ge=1,
|
||||||
|
description="Used if you want to run multiple times. Must be an int >= 1.",
|
||||||
|
examples=[2],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
opacity: Annotated[
|
||||||
|
float,
|
||||||
|
Field(
|
||||||
|
default=0.9,
|
||||||
|
ge=0.0,
|
||||||
|
le=1.0,
|
||||||
|
description="Sets the opacity of comments when overlaid over the background.",
|
||||||
|
examples=[0.8],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
storymode: Annotated[
|
||||||
|
bool,
|
||||||
|
Field(
|
||||||
|
default=False,
|
||||||
|
description="Only read out title and post content. Great for story-based subreddits.",
|
||||||
|
examples=[False],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
storymodemethod: Annotated[
|
||||||
|
Literal[0, 1],
|
||||||
|
Field(
|
||||||
|
default=1,
|
||||||
|
description="Style used for story mode: 0 = static image, 1 = fancy video.",
|
||||||
|
examples=[1],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
storymode_max_length: Annotated[
|
||||||
|
int,
|
||||||
|
Field(
|
||||||
|
default=1000,
|
||||||
|
ge=1,
|
||||||
|
description="Max length (in characters) of the story mode video.",
|
||||||
|
examples=[1000],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
resolution_w: Annotated[
|
||||||
|
int,
|
||||||
|
Field(
|
||||||
|
default=1080,
|
||||||
|
description="Sets the width in pixels of the final video.",
|
||||||
|
examples=[1440],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
resolution_h: Annotated[
|
||||||
|
int,
|
||||||
|
Field(
|
||||||
|
default=1920,
|
||||||
|
description="Sets the height in pixels of the final video.",
|
||||||
|
examples=[2560],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
zoom: Annotated[
|
||||||
|
float,
|
||||||
|
Field(
|
||||||
|
default=1.0,
|
||||||
|
ge=0.1,
|
||||||
|
le=2.0,
|
||||||
|
description="Sets the browser zoom level. Useful for making text larger.",
|
||||||
|
examples=[1.1],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
channel_name: Annotated[
|
||||||
|
str,
|
||||||
|
Field(
|
||||||
|
default="Reddit Tales",
|
||||||
|
description="Sets the channel name for the video.",
|
||||||
|
examples=["Reddit Stories"],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
tts: SettingsTTS
|
||||||
|
background: SettingsBackground
|
||||||
|
|
||||||
|
|
||||||
|
class Reddit(BaseModel):
|
||||||
|
creds: RedditCreds
|
||||||
|
thread: RedditThread
|
||||||
|
|
||||||
|
|
||||||
|
class Config(BaseModel):
|
||||||
|
reddit: Reddit
|
||||||
|
ai: AIConfig
|
||||||
|
settings: Settings
|
@ -1,170 +1,143 @@
|
|||||||
import re
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Tuple
|
from typing import Any, Dict
|
||||||
|
|
||||||
import toml
|
import toml
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
|
|
||||||
from utils.console import handle_input
|
from utils.config_model import Config
|
||||||
|
from utils.console import print_substep
|
||||||
|
|
||||||
console = Console()
|
console = Console()
|
||||||
config = dict # autocomplete
|
config: dict # autocomplete
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import ValidationError, BaseModel
|
||||||
|
from pydantic_core import PydanticUndefined
|
||||||
|
|
||||||
|
|
||||||
|
def prompt_recursive(obj: BaseModel):
|
||||||
|
"""
|
||||||
|
Recursively prompt for missing or invalid fields in a Pydantic model instance 'obj'.
|
||||||
|
"""
|
||||||
|
for field_name, field in obj.model_fields.items():
|
||||||
|
value = getattr(obj, field_name, None)
|
||||||
|
# If field is a nested BaseModel, recurse into it
|
||||||
|
if hasattr(field.annotation, "model_fields"):
|
||||||
|
nested_obj = value or field.annotation.model_construct()
|
||||||
|
fixed_nested = prompt_recursive(nested_obj)
|
||||||
|
setattr(obj, field_name, fixed_nested)
|
||||||
|
continue
|
||||||
|
|
||||||
def crawl(obj: dict, func=lambda x, y: print(x, y, end="\n"), path=None):
|
# If the value is valid and not None, skip prompt
|
||||||
if path is None: # path Default argument value is mutable
|
if value not in [None, "", [], {}]:
|
||||||
path = []
|
|
||||||
for key in obj.keys():
|
|
||||||
if type(obj[key]) is dict:
|
|
||||||
crawl(obj[key], func, path + [key])
|
|
||||||
continue
|
continue
|
||||||
func(path + [key], obj[key])
|
|
||||||
|
|
||||||
|
|
||||||
def check(value, checks, name):
|
|
||||||
def get_check_value(key, default_result):
|
|
||||||
return checks[key] if key in checks else default_result
|
|
||||||
|
|
||||||
incorrect = False
|
|
||||||
if value == {}:
|
|
||||||
incorrect = True
|
|
||||||
if not incorrect and "type" in checks:
|
|
||||||
try:
|
|
||||||
value = eval(checks["type"])(value) # fixme remove eval
|
|
||||||
except:
|
|
||||||
incorrect = True
|
|
||||||
|
|
||||||
if (
|
|
||||||
not incorrect and "options" in checks and value not in checks["options"]
|
|
||||||
): # FAILSTATE Value is not one of the options
|
|
||||||
incorrect = True
|
|
||||||
if (
|
|
||||||
not incorrect
|
|
||||||
and "regex" in checks
|
|
||||||
and (
|
|
||||||
(isinstance(value, str) and re.match(checks["regex"], value) is None)
|
|
||||||
or not isinstance(value, str)
|
|
||||||
)
|
|
||||||
): # FAILSTATE Value doesn't match regex, or has regex but is not a string.
|
|
||||||
incorrect = True
|
|
||||||
|
|
||||||
if (
|
|
||||||
not incorrect
|
|
||||||
and not hasattr(value, "__iter__")
|
|
||||||
and (
|
|
||||||
("nmin" in checks and checks["nmin"] is not None and value < checks["nmin"])
|
|
||||||
or ("nmax" in checks and checks["nmax"] is not None and value > checks["nmax"])
|
|
||||||
)
|
|
||||||
):
|
|
||||||
incorrect = True
|
|
||||||
if (
|
|
||||||
not incorrect
|
|
||||||
and hasattr(value, "__iter__")
|
|
||||||
and (
|
|
||||||
("nmin" in checks and checks["nmin"] is not None and len(value) < checks["nmin"])
|
|
||||||
or ("nmax" in checks and checks["nmax"] is not None and len(value) > checks["nmax"])
|
|
||||||
)
|
|
||||||
):
|
|
||||||
incorrect = True
|
|
||||||
|
|
||||||
if incorrect:
|
|
||||||
value = handle_input(
|
|
||||||
message=(
|
|
||||||
(("[blue]Example: " + str(checks["example"]) + "\n") if "example" in checks else "")
|
|
||||||
+ "[red]"
|
|
||||||
+ ("Non-optional ", "Optional ")["optional" in checks and checks["optional"] is True]
|
|
||||||
)
|
|
||||||
+ "[#C0CAF5 bold]"
|
|
||||||
+ str(name)
|
|
||||||
+ "[#F7768E bold]=",
|
|
||||||
extra_info=get_check_value("explanation", ""),
|
|
||||||
check_type=eval(get_check_value("type", "False")), # fixme remove eval
|
|
||||||
default=get_check_value("default", NotImplemented),
|
|
||||||
match=get_check_value("regex", ""),
|
|
||||||
err_message=get_check_value("input_error", "Incorrect input"),
|
|
||||||
nmin=get_check_value("nmin", None),
|
|
||||||
nmax=get_check_value("nmax", None),
|
|
||||||
oob_error=get_check_value(
|
|
||||||
"oob_error", "Input out of bounds(Value too high/low/long/short)"
|
|
||||||
),
|
|
||||||
options=get_check_value("options", None),
|
|
||||||
optional=get_check_value("optional", False),
|
|
||||||
)
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
description = field.description or ""
|
||||||
|
default_str = (
|
||||||
|
f" (default: {field.default})"
|
||||||
|
if (field.default is not None) or field.default == PydanticUndefined
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
prompt_msg = f"🧩 {field_name}\n 📘 {description}{default_str}\n ⚠️ Required: {field.is_required()}\n ❓ Enter value: "
|
||||||
|
|
||||||
|
while True:
|
||||||
|
user_input = input(prompt_msg).strip()
|
||||||
|
if not user_input:
|
||||||
|
if field.default is not None:
|
||||||
|
value_to_set = field.default
|
||||||
|
elif not field.is_required():
|
||||||
|
value_to_set = None
|
||||||
|
else:
|
||||||
|
print("⚠️ This field is required.")
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
# Convert input based on type, you can expand this logic
|
||||||
|
try:
|
||||||
|
value_to_set = parse_value(user_input, field.annotation)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"⚠️ Invalid input: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Validate the assignment
|
||||||
|
try:
|
||||||
|
obj.__pydantic_validator__.validate_assignment(
|
||||||
|
obj, field_name, value_to_set
|
||||||
|
)
|
||||||
|
setattr(obj, field_name, value_to_set)
|
||||||
|
break
|
||||||
|
except ValidationError as ve:
|
||||||
|
for err in ve.errors():
|
||||||
|
print(f"❌ {err['loc'][0]}: {err['msg']}")
|
||||||
|
|
||||||
def crawl_and_check(obj: dict, path: list, checks: dict = {}, name=""):
|
|
||||||
if len(path) == 0:
|
|
||||||
return check(obj, checks, name)
|
|
||||||
if path[0] not in obj.keys():
|
|
||||||
obj[path[0]] = {}
|
|
||||||
obj[path[0]] = crawl_and_check(obj[path[0]], path[1:], checks, path[0])
|
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
|
||||||
def check_vars(path, checks):
|
def parse_value(raw: str, expected_type: type):
|
||||||
global config
|
from typing import get_args, get_origin
|
||||||
crawl_and_check(config, path, checks)
|
|
||||||
|
|
||||||
|
origin = get_origin(expected_type)
|
||||||
|
args = get_args(expected_type)
|
||||||
|
|
||||||
def check_toml(template_file, config_file) -> Tuple[bool, Dict]:
|
if expected_type == bool:
|
||||||
global config
|
if raw.lower() in ("true", "yes", "1"):
|
||||||
config = None
|
return True
|
||||||
try:
|
elif raw.lower() in ("false", "no", "0"):
|
||||||
template = toml.load(template_file)
|
|
||||||
except Exception as error:
|
|
||||||
console.print(f"[red bold]Encountered error when trying to to load {template_file}: {error}")
|
|
||||||
return False
|
|
||||||
try:
|
|
||||||
config = toml.load(config_file)
|
|
||||||
except toml.TomlDecodeError:
|
|
||||||
console.print(
|
|
||||||
f"""[blue]Couldn't read {config_file}.
|
|
||||||
Overwrite it?(y/n)"""
|
|
||||||
)
|
|
||||||
if not input().startswith("y"):
|
|
||||||
print("Unable to read config, and not allowed to overwrite it. Giving up.")
|
|
||||||
return False
|
return False
|
||||||
else:
|
else:
|
||||||
try:
|
raise ValueError("Expected boolean value (true/false)")
|
||||||
with open(config_file, "w") as f:
|
elif expected_type == int:
|
||||||
f.write("")
|
return int(raw)
|
||||||
except:
|
elif expected_type == float:
|
||||||
console.print(
|
return float(raw)
|
||||||
f"[red bold]Failed to overwrite {config_file}. Giving up.\nSuggestion: check {config_file} permissions for the user."
|
elif expected_type == str:
|
||||||
)
|
return raw
|
||||||
return False
|
elif origin == list and args:
|
||||||
except FileNotFoundError:
|
return [parse_value(x.strip(), args[0]) for x in raw.split(",")]
|
||||||
console.print(
|
else:
|
||||||
f"""[blue]Couldn't find {config_file}
|
raise ValueError(f"Unsupported field type: {expected_type}")
|
||||||
Creating it now."""
|
|
||||||
)
|
|
||||||
try:
|
def check_toml(config_file: str) -> Dict[str, Any]:
|
||||||
with open(config_file, "x") as f:
|
"""
|
||||||
f.write("")
|
Load the template and config TOML files.
|
||||||
config = {}
|
Validate config with Pydantic.
|
||||||
except:
|
If invalid, prompt for missing or invalid fields.
|
||||||
console.print(
|
Save fixed config back.
|
||||||
f"[red bold]Failed to write to {config_file}. Giving up.\nSuggestion: check the folder's permissions for the user."
|
Return the valid Config model.
|
||||||
)
|
"""
|
||||||
return False
|
try:
|
||||||
|
config_dict = toml.load(config_file)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Failed to load config {config_file}: {e}")
|
||||||
|
config_dict = {}
|
||||||
|
|
||||||
console.print(
|
try:
|
||||||
"""\
|
config_instance = Config.model_validate(config_dict)
|
||||||
[blue bold]###############################
|
except ValidationError as e:
|
||||||
# #
|
print("Config validation failed, will prompt for missing/invalid fields:")
|
||||||
# Checking TOML configuration #
|
print(e)
|
||||||
# #
|
# Start from a clean model
|
||||||
###############################
|
config_instance = Config.model_construct()
|
||||||
If you see any prompts, that means that you have unset/incorrectly set variables, please input the correct values.\
|
# Update model with any valid partial data loaded from config
|
||||||
"""
|
for k, v in config_dict.items():
|
||||||
)
|
if hasattr(config_instance, k):
|
||||||
crawl(template, check_vars)
|
setattr(config_instance, k, v)
|
||||||
with open(config_file, "w") as f:
|
|
||||||
toml.dump(config, f)
|
# Prompt for missing or invalid fields recursively
|
||||||
|
config_instance = prompt_recursive(config_instance)
|
||||||
|
|
||||||
|
# Validate again to be sure
|
||||||
|
config_instance = Config.model_validate(config_instance.model_dump())
|
||||||
|
|
||||||
|
# Save fixed config back to file
|
||||||
|
with open(config_file, "w", encoding="utf-8") as f:
|
||||||
|
toml.dump(config_instance.model_dump(), f)
|
||||||
|
print(f"Updated config saved to {config_file}")
|
||||||
|
config = config_instance.model_dump()
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
directory = Path().absolute()
|
directory = Path().absolute()
|
||||||
check_toml(f"{directory}/utils/.config.template.toml", "config.toml")
|
check_toml("config.toml")
|
Loading…
Reference in new issue