You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
RedditVideoMakerBot/utils/settings.py

143 lines
4.7 KiB

import sys
from pathlib import Path
from typing import Any, Dict
import toml
from rich.console import Console
from utils.config_model import Config
from utils.console import print_substep
console = Console()
config: dict # autocomplete
from typing import Any
from pydantic import ValidationError, BaseModel
from pydantic_core import PydanticUndefined
def prompt_recursive(obj: BaseModel):
"""
Recursively prompt for missing or invalid fields in a Pydantic model instance 'obj'.
"""
for field_name, field in obj.model_fields.items():
value = getattr(obj, field_name, None)
# If field is a nested BaseModel, recurse into it
if hasattr(field.annotation, "model_fields"):
nested_obj = value or field.annotation.model_construct()
fixed_nested = prompt_recursive(nested_obj)
setattr(obj, field_name, fixed_nested)
continue
# If the value is valid and not None, skip prompt
if value not in [None, "", [], {}]:
continue
description = field.description or ""
default_str = (
f" (default: {field.default})"
if (field.default is not None) or field.default == PydanticUndefined
else ""
)
prompt_msg = f"🧩 {field_name}\n 📘 {description}{default_str}\n ⚠️ Required: {field.is_required()}\n ❓ Enter value: "
while True:
user_input = input(prompt_msg).strip()
if not user_input:
if field.default is not None:
value_to_set = field.default
elif not field.is_required():
value_to_set = None
else:
print("⚠️ This field is required.")
continue
else:
# Convert input based on type, you can expand this logic
try:
value_to_set = parse_value(user_input, field.annotation)
except Exception as e:
print(f"⚠️ Invalid input: {e}")
continue
# Validate the assignment
try:
obj.__pydantic_validator__.validate_assignment(
obj, field_name, value_to_set
)
setattr(obj, field_name, value_to_set)
break
except ValidationError as ve:
for err in ve.errors():
print(f"{err['loc'][0]}: {err['msg']}")
return obj
def parse_value(raw: str, expected_type: type):
from typing import get_args, get_origin
origin = get_origin(expected_type)
args = get_args(expected_type)
if expected_type == bool:
if raw.lower() in ("true", "yes", "1"):
return True
elif raw.lower() in ("false", "no", "0"):
return False
else:
raise ValueError("Expected boolean value (true/false)")
elif expected_type == int:
return int(raw)
elif expected_type == float:
return float(raw)
elif expected_type == str:
return raw
elif origin == list and args:
return [parse_value(x.strip(), args[0]) for x in raw.split(",")]
else:
raise ValueError(f"Unsupported field type: {expected_type}")
def check_toml(config_file: str) -> Dict[str, Any]:
"""
Load the template and config TOML files.
Validate config with Pydantic.
If invalid, prompt for missing or invalid fields.
Save fixed config back.
Return the valid Config model.
"""
try:
config_dict = toml.load(config_file)
except Exception as e:
print(f"Failed to load config {config_file}: {e}")
config_dict = {}
try:
config_instance = Config.model_validate(config_dict)
except ValidationError as e:
print("Config validation failed, will prompt for missing/invalid fields:")
print(e)
# Start from a clean model
config_instance = Config.model_construct()
# Update model with any valid partial data loaded from config
for k, v in config_dict.items():
if hasattr(config_instance, k):
setattr(config_instance, k, v)
# Prompt for missing or invalid fields recursively
config_instance = prompt_recursive(config_instance)
# Validate again to be sure
config_instance = Config.model_validate(config_instance.model_dump())
# Save fixed config back to file
with open(config_file, "w", encoding="utf-8") as f:
toml.dump(config_instance.model_dump(), f)
print(f"Updated config saved to {config_file}")
config = config_instance.model_dump()
return config
if __name__ == "__main__":
directory = Path().absolute()
check_toml("config.toml")