diff --git a/video_creation/screenshot_downloader.py b/video_creation/screenshot_downloader.py index d3d32ef..82448d1 100644 --- a/video_creation/screenshot_downloader.py +++ b/video_creation/screenshot_downloader.py @@ -20,30 +20,40 @@ def download_screenshots_of_reddit_posts(reddit_object, screenshot_num, theme): with sync_playwright() as p: print_substep("Launching Headless Browser...") - browser = p.chromium.launch() - context = browser.new_context() - + browser, context = p.chromium.launch(), browser.new_context() + page = context.new_page() + if theme.casefold() == "dark": cookie_file = open('video_creation/cookies.json') cookies = json.load(cookie_file) context.add_cookies(cookies) - # Get the thread screenshot - page = context.new_page() - page.goto(reddit_object["thread_url"]) - page.set_viewport_size(ViewportSize(width=1920, height=1080)) - if page.locator('[data-testid="content-gate"]').is_visible(): - # This means the post is NSFW and requires to click the proceed button. + + save_thread_screenshot(page, reddit_object) + save_comments_screenshots(page, reddit_object, screenshot_num) + + print_substep("Screenshots downloaded Successfully.", + style="bold green") - print_substep("Post is NSFW. You are spicy...") - page.locator('[data-testid="content-gate"] button').click() - page.locator('[data-test-id="post-content"]').screenshot( - path="assets/png/title.png" - ) +def save_thread_screenshot(page, reddit_object): + page.goto(reddit_object["thread_url"]) + page.set_viewport_size(ViewportSize(width=1920, height=1080)) + if page.locator('[data-testid="content-gate"]').is_visible(): + # This means the post is NSFW and requires to click the proceed button. - for idx, comment in track( - enumerate(reddit_object["comments"]), "Downloading screenshots..." + print_substep("Post is NSFW. You are spicy...") + page.locator('[data-testid="content-gate"] button').click() + + page.locator('[data-test-id="post-content"]').screenshot( + path="assets/png/title.png" + ) + + return page + +def save_comments_screenshots(page, reddit_object, screenshot_num): + for idx, comment in track( + enumerate(reddit_object["comments"]), description="Downloading screenshots..." ): # Stop if we have reached the screenshot_num @@ -56,7 +66,4 @@ def download_screenshots_of_reddit_posts(reddit_object, screenshot_num, theme): page.goto(f'https://reddit.com{comment["comment_url"]}') page.locator(f"#t1_{comment['comment_id']}").screenshot( path=f"assets/png/comment_{idx}.png" - ) - - print_substep("Screenshots downloaded Successfully.", - style="bold green") + ) \ No newline at end of file