diff --git a/main.py b/main.py index 890709c..ed053e1 100755 --- a/main.py +++ b/main.py @@ -11,6 +11,7 @@ from prawcore import ResponseException from reddit.subreddit import get_subreddit_threads from utils import settings from utils.cleanup import cleanup +from utils.checkpoint import run_step, save_checkpoint, load_checkpoint, clear_checkpoint, print_resume_status from utils.console import print_markdown, print_step, print_substep from utils.ffmpeg_install import ffmpeg_install from utils.id import extract_id @@ -48,21 +49,58 @@ reddit_object: Dict[str, str | list] def main(POST_ID=None) -> None: global reddit_id, reddit_object - # reddit_object = get_subreddit_threads(POST_ID) - # reddit_id = extract_id(reddit_object) - # print_substep(f"Thread ID is {reddit_id}", style="bold blue") - # length, number_of_comments = save_text_to_mp3(reddit_object) - # length = math.ceil(length) - # get_screenshots_of_reddit_posts(reddit_object, number_of_comments) - # bg_config = { - # "video": get_background_config("video"), - # "audio": get_background_config("audio"), - # } - # download_background_video(bg_config["video"]) - # download_background_audio(bg_config["audio"]) - # chop_background(bg_config, length, reddit_object) - # make_final_video(number_of_comments, length, reddit_object, bg_config) - print_step("Reddit pipeline is disabled. Uncomment main() body to re-enable.") + + # Step 1: Fetch Reddit threads (no checkpoint — reddit_id unknown yet) + reddit_object = get_subreddit_threads(POST_ID) + reddit_id = extract_id(reddit_object) + print_substep(f"Thread ID is {reddit_id}", style="bold blue") + save_checkpoint(reddit_id, "fetch_reddit", {"result": None}) + print_resume_status(reddit_id) + + # Step 2: Generate TTS audio + tts_result = run_step( + reddit_id, "generate_tts", + save_text_to_mp3, reddit_object, + ) + length, number_of_comments = tts_result[0], tts_result[1] + length = math.ceil(length) + + # Step 3: Take screenshots + run_step( + reddit_id, "take_screenshots", + get_screenshots_of_reddit_posts, reddit_object, number_of_comments, + ) + + # Step 4: Download background video & audio + bg_config = { + "video": get_background_config("video"), + "audio": get_background_config("audio"), + } + run_step( + reddit_id, "download_background", + _download_backgrounds, bg_config, + ) + + # Step 5: Chop background + run_step( + reddit_id, "chop_background", + chop_background, bg_config, length, reddit_object, + ) + + # Step 6: Make final video + run_step( + reddit_id, "make_final_video", + make_final_video, number_of_comments, length, reddit_object, bg_config, + ) + + # Pipeline complete — clear checkpoint + clear_checkpoint(reddit_id) + print_step("Pipeline completed successfully! Checkpoint cleared.") + + +def _download_backgrounds(bg_config): + download_background_video(bg_config["video"]) + download_background_audio(bg_config["audio"]) def run_many(times) -> None: diff --git a/utils/checkpoint.py b/utils/checkpoint.py new file mode 100644 index 0000000..87d8067 --- /dev/null +++ b/utils/checkpoint.py @@ -0,0 +1,106 @@ +import json +import time +from pathlib import Path +from typing import Any, Optional + +from utils.console import print_step, print_substep + +CHECKPOINT_DIR = Path("assets/temp") + + +def _checkpoint_path(reddit_id: str) -> Path: + return CHECKPOINT_DIR / reddit_id / "checkpoint.json" + + +def save_checkpoint(reddit_id: str, step: str, data: dict[str, Any]) -> None: + path = _checkpoint_path(reddit_id) + path.parent.mkdir(parents=True, exist_ok=True) + + checkpoint = load_checkpoint(reddit_id) or {} + checkpoint["reddit_id"] = reddit_id + checkpoint["last_step"] = step + checkpoint["updated_at"] = time.time() + checkpoint.setdefault("completed_steps", []) + if step not in checkpoint["completed_steps"]: + checkpoint["completed_steps"].append(step) + checkpoint[step] = data + + path.write_text(json.dumps(checkpoint, indent=2, default=str)) + + +def load_checkpoint(reddit_id: str) -> Optional[dict]: + path = _checkpoint_path(reddit_id) + if not path.exists(): + return None + try: + return json.loads(path.read_text()) + except (json.JSONDecodeError, OSError): + return None + + +def is_step_done(reddit_id: str, step: str) -> bool: + cp = load_checkpoint(reddit_id) + if not cp: + return False + return step in cp.get("completed_steps", []) + + +def get_step_data(reddit_id: str, step: str) -> Optional[dict]: + cp = load_checkpoint(reddit_id) + if not cp: + return None + return cp.get(step) + + +def clear_checkpoint(reddit_id: str) -> None: + path = _checkpoint_path(reddit_id) + if path.exists(): + path.unlink() + + +def print_resume_status(reddit_id: str) -> None: + cp = load_checkpoint(reddit_id) + if not cp: + return + done = cp.get("completed_steps", []) + print_substep(f"Resuming from checkpoint. Completed steps: {', '.join(done)}", style="bold yellow") + + +PIPELINE_STEPS = [ + "fetch_reddit", + "generate_tts", + "take_screenshots", + "download_background", + "chop_background", + "make_final_video", +] + + +def run_step(reddit_id: str, step: str, func, *args, max_retries: int = 3, **kwargs) -> Any: + if is_step_done(reddit_id, step): + data = get_step_data(reddit_id, step) + print_substep(f"Step '{step}' already done. Skipping.", style="bold green") + return data.get("result") if data else None + + last_error = None + for attempt in range(1, max_retries + 1): + try: + if attempt > 1: + print_substep(f"Retry {attempt}/{max_retries} for step '{step}'...", style="bold yellow") + result = func(*args, **kwargs) + save_checkpoint(reddit_id, step, {"result": result}) + print_substep(f"Step '{step}' completed successfully.", style="bold green") + return result + except KeyboardInterrupt: + raise + except Exception as e: + last_error = e + print_substep(f"Step '{step}' failed (attempt {attempt}/{max_retries}): {e}", style="bold red") + if attempt < max_retries: + wait = 2 ** attempt + print_substep(f"Waiting {wait}s before retry...", style="yellow") + time.sleep(wait) + + print_step(f"Step '{step}' failed after {max_retries} attempts.") + save_checkpoint(reddit_id, f"{step}_failed", {"error": str(last_error)}) + raise last_error