diff --git a/main.py b/main.py index 8ce8725..0ba8f27 100755 --- a/main.py +++ b/main.py @@ -1,4 +1,5 @@ #!/usr/bin/env python +from asyncio import run import math from subprocess import Popen from os import name @@ -14,7 +15,7 @@ from video_creation.background import ( get_background_config, ) from video_creation.final_video import make_final_video -from video_creation.screenshot_downloader import download_screenshots_of_reddit_posts +from video_creation.screenshot_downloader import Reddit from video_creation.voices import save_text_to_mp3 VERSION = "2.2.9" @@ -35,24 +36,28 @@ print_markdown( print_step(f"You are using V{VERSION} of the bot") -def main(POST_ID=None): +async def main( + POST_ID=None +): cleanup() reddit_object = get_subreddit_threads(POST_ID) length, number_of_comments = save_text_to_mp3(reddit_object) length = math.ceil(length) - download_screenshots_of_reddit_posts(reddit_object, number_of_comments) + reddit_screenshots = Reddit(reddit_object, number_of_comments) + browser = await reddit_screenshots.get_browser() + await reddit_screenshots.download_screenshots(browser) bg_config = get_background_config() download_background(bg_config) chop_background_video(bg_config, length) make_final_video(number_of_comments, length, reddit_object, bg_config) -def run_many(times): +async def run_many(times): for x in range(1, times + 1): print_step( f'on the {x}{("th", "st", "nd", "rd", "th", "th", "th", "th","th", "th")[x%10]} iteration of {times}' ) # correct 1st 2nd 3rd 4th 5th.... - main() + await main() Popen("cls" if name == "nt" else "clear", shell=True).wait() @@ -61,7 +66,9 @@ if __name__ == "__main__": config is False and exit() try: if config["settings"]["times_to_run"]: - run_many(config["settings"]["times_to_run"]) + run( + run_many(config["settings"]["times_to_run"]) + ) elif len(config["reddit"]["thread"]["post_id"].split("+")) > 1: for index, post_id in enumerate(config["reddit"]["thread"]["post_id"].split("+")): @@ -69,7 +76,9 @@ if __name__ == "__main__": print_step( f'on the {index}{("st" if index%10 == 1 else ("nd" if index%10 == 2 else ("rd" if index%10 == 3 else "th")))} post of {len(config["reddit"]["thread"]["post_id"].split("+"))}' ) - main(post_id) + run( + main(post_id) + ) Popen("cls" if name == "nt" else "clear", shell=True).wait() else: main() diff --git a/reddit/subreddit.py b/reddit/subreddit.py index b64a52a..2ce80ce 100644 --- a/reddit/subreddit.py +++ b/reddit/subreddit.py @@ -9,7 +9,9 @@ from utils.subreddit import get_subreddit_undone from utils.videos import check_done -def get_subreddit_threads(POST_ID: str): +def get_subreddit_threads( + POST_ID: str +): """ Returns a list of threads from the AskReddit subreddit. """ @@ -87,6 +89,7 @@ def get_subreddit_threads(POST_ID: str): content["thread_title"] = submission.title content["thread_post"] = submission.selftext content["thread_id"] = submission.id + content["is_nsfw"] = 'nsfw' in submission.whitelist_status content["comments"] = [] for top_level_comment in submission.comments: diff --git a/requirements.txt b/requirements.txt index 8b377c2..10ce9c3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,10 +2,11 @@ boto3==1.24.24 botocore==1.27.24 gTTS==2.2.4 moviepy==1.0.3 -playwright==1.23.0 praw==7.6.0 pytube==12.1.0 requests==2.28.1 rich==12.4.4 toml==0.10.2 translators==5.3.1 +pyppeteer==1.0.2 +attrs==21.4.0 diff --git a/video_creation/screenshot_downloader.py b/video_creation/screenshot_downloader.py index 6fb9ef4..526a17c 100644 --- a/video_creation/screenshot_downloader.py +++ b/video_creation/screenshot_downloader.py @@ -1,107 +1,283 @@ -import json +from asyncio import as_completed from pathlib import Path from typing import Dict + from utils import settings -from playwright.async_api import async_playwright # pylint: disable=unused-import -# do not remove the above line +from pyppeteer import launch +from pyppeteer.page import Page as PageCls +from pyppeteer.browser import Browser as BrowserCls +from pyppeteer.element_handle import ElementHandle as ElementHandleCls +from pyppeteer.errors import TimeoutError as BrowserTimeoutError -from playwright.sync_api import sync_playwright, ViewportSize from rich.progress import track import translators as ts from utils.console import print_step, print_substep -storymode = False +from attr import attrs, attrib +from typing import TypeVar, Optional, Callable, Union + +_function = TypeVar('_function', bound=Callable[..., object]) +_exceptions = TypeVar('_exceptions', bound=Optional[Union[type, tuple, list]]) + + +@attrs +class ExceptionDecorator: + __exception: Optional[_exceptions] = attrib(default=None) + __default_exception: _exceptions = attrib(default=BrowserTimeoutError) + + def __attrs_post_init__(self): + if not self.__exception: + self.__exception = self.__default_exception + + def __call__( + self, + func: _function, + ): + async def wrapper(*args, **kwargs): + try: + obj_to_return = await func(*args, **kwargs) + return obj_to_return + except Exception as caughtException: + import logging + + if isinstance(self.__exception, type): + if not type(caughtException) == self.__exception: + logging.basicConfig(filename='.webdriver.log', filemode='w', encoding='utf-8', + level=logging.DEBUG) + logging.error(f'unexpected error - {caughtException}') + else: + if not type(caughtException) in self.__exception: + logging.error(f'unexpected error - {caughtException}') + + return wrapper + + +def catch_exception( + func: Optional[_function], + exception: Optional[_exceptions] = None, +) -> ExceptionDecorator | _function: + exceptor = ExceptionDecorator(exception) + if func: + exceptor = exceptor(func) + return exceptor + +@attrs +class Browser: + # default_Viewport: dict = attrib(validator=instance_of(dict), default=dict()) + # + # def __attrs_post_init__(self): + # if self.default_Viewport.__len__() == 0: + # self.default_Viewport['isLandscape'] = True -def download_screenshots_of_reddit_posts(reddit_object: dict, screenshot_num: int): - """Downloads screenshots of reddit posts as seen on the web. Downloads to assets/temp/png + async def get_browser( + self, + ) -> BrowserCls: + return await launch() + async def close_browser( + self, + browser: BrowserCls + ) -> None: + await browser.close() + + +class Wait: + @staticmethod + @catch_exception + async def find_xpath( + page_instance: PageCls, + xpath: Optional[str] = None, + options: Optional[dict] = None, + ) -> 'ElementHandleCls': + if options: + el = await page_instance.waitForXPath(xpath, options=options) + else: + el = await page_instance.waitForXPath(xpath) + return el + + @catch_exception + async def click( + self, + page_instance: Optional[PageCls] = None, + xpath: Optional[str] = None, + find_options: Optional[dict] = None, + options: Optional[dict] = None, + el: Optional[ElementHandleCls] = None, + ) -> None: + if not el: + el = await self.find_xpath(page_instance, xpath, find_options) + if options: + await el.click(options) + else: + await el.click() + + @catch_exception + async def screenshot( + self, + page_instance: Optional[PageCls] = None, + xpath: Optional[str] = None, + options: Optional[dict] = None, + find_options: Optional[dict] = None, + el: Optional[ElementHandleCls] = None, + ) -> None: + if not el: + el = await self.find_xpath(page_instance, xpath, find_options) + if options: + await el.screenshot(options) + else: + await el.screenshot() + + +@attrs(auto_attribs=True) +class Reddit(Browser, Wait): + """ Args: reddit_object (Dict): Reddit object received from reddit/subreddit.py screenshot_num (int): Number of screenshots to download """ - print_step("Downloading screenshots of reddit posts...") + reddit_object: dict + screenshot_num: int = attrib() - # ! Make sure the reddit screenshots folder exists - Path("assets/temp/png").mkdir(parents=True, exist_ok=True) + @screenshot_num.validator + def validate_screenshot_num(self, attribute, value): + if value <= 0: + raise ValueError('Check screenshot_num in config') - with sync_playwright() as p: - print_substep("Launching Headless Browser...") + async def dark_theme( + self, + page_instance: PageCls, + ) -> None: + """ + Enables dark theme in Reddit + """ - browser = p.chromium.launch() - context = browser.new_context() + await self.click( + page_instance, + '//*[contains(@class, \'header-user-dropdown\')]', + {'timeout': 5000}, + ) + + # It's normal not to find it, sometimes there is none :shrug: + await self.click( + page_instance, + '//*[contains(text(), \'Settings\')]/ancestor::button[1]', + {'timeout': 5000}, + ) + + await self.click( + page_instance, + '//*[contains(text(), \'Dark Mode\')]/ancestor::button[1]', + {'timeout': 5000}, + ) + + # Closes settings + await self.click( + page_instance, + '//*[contains(@class, \'header-user-dropdown\')]', + {'timeout': 5000}, + ) + + async def download_screenshots( + self, + browser: BrowserCls + + ): + """ + Downloads screenshots of reddit posts as seen on the web. Downloads to assets/temp/png + """ + print_step('Downloading screenshots of reddit posts...') + + # ! Make sure the reddit screenshots folder exists + Path('assets/temp/png').mkdir(parents=True, exist_ok=True) + + print_substep('Launching Headless Browser...') - if settings.config["settings"]["theme"] == "dark": - cookie_file = open("./video_creation/data/cookie-dark-mode.json", encoding="utf-8") - else: - cookie_file = open("./video_creation/data/cookie-light-mode.json", encoding="utf-8") - cookies = json.load(cookie_file) - context.add_cookies(cookies) # load preference cookies # Get the thread screenshot - page = context.new_page() - page.goto(reddit_object["thread_url"], timeout=0) - page.set_viewport_size(ViewportSize(width=1920, height=1080)) - if page.locator('[data-testid="content-gate"]').is_visible(): + reddit_main = await browser.newPage() + await reddit_main.goto(self.reddit_object['thread_url']) + + if settings.config['settings']['theme'] == 'dark': + await self.dark_theme(reddit_main) + + if self.reddit_object['is_nsfw']: # This means the post is NSFW and requires to click the proceed button. - print_substep("Post is NSFW. You are spicy...") - page.locator('[data-testid="content-gate"] button').click() - page.locator( - '[data-click-id="text"] button' - ).click() # Remove "Click to see nsfw" Button in Screenshot + print_substep('Post is NSFW. You are spicy...') + await self.click( + reddit_main, + '//button[contains(text(), \'Yes\')]', + {'timeout': 5000}, + ) + + await self.click( + reddit_main, + '//button[contains(text(), \'nsfw\')]', + {'timeout': 5000}, + ) # translate code - if settings.config["reddit"]["thread"]["post_lang"]: - print_substep("Translating post...") + if settings.config['reddit']['thread']['post_lang']: + print_substep('Translating post...') texts_in_tl = ts.google( - reddit_object["thread_title"], - to_language=settings.config["reddit"]["thread"]["post_lang"], + self.reddit_object['thread_title'], + to_language=settings.config['reddit']['thread']['post_lang'], ) - page.evaluate( - "tl_content => document.querySelector('[data-test-id=\"post-content\"] > div:nth-child(3) > div > div').textContent = tl_content", + await reddit_main.evaluate( + "tl_content => document.querySelector('[data-test-id=\"post-content\"] > div:nth-child(3) > div > " + "div').textContent = tl_content", texts_in_tl, ) else: print_substep("Skipping translation...") - page.locator('[data-test-id="post-content"]').screenshot(path="assets/temp/png/title.png") + await self.screenshot( + reddit_main, + f'//*[contains(@id, \'t3_{self.reddit_object["thread_id"]}\')]', + {'path': f'assets/temp/png/title.png'}, + ) - if storymode: - page.locator('[data-click-id="text"]').screenshot( - path="assets/temp/png/story_content.png" - ) - else: - for idx, comment in enumerate( - track(reddit_object["comments"], "Downloading screenshots...") - ): - # Stop if we have reached the screenshot_num - if idx >= screenshot_num: - break - - if page.locator('[data-testid="content-gate"]').is_visible(): - page.locator('[data-testid="content-gate"] button').click() - - page.goto(f'https://reddit.com{comment["comment_url"]}', timeout=0) - - # translate code - - if settings.config["reddit"]["thread"]["post_lang"]: - comment_tl = ts.google( - comment["comment_body"], - to_language=settings.config["reddit"]["thread"]["post_lang"], - ) - page.evaluate( - '([tl_content, tl_id]) => document.querySelector(`#t1_${tl_id} > div:nth-child(2) > div > div[data-testid="comment"] > div`).textContent = tl_content', - [comment_tl, comment["comment_id"]], - ) - - page.locator(f"#t1_{comment['comment_id']}").screenshot( - path=f"assets/temp/png/comment_{idx}.png" + async def collect_comment( + comment_obj: dict, + filename_idx: int, + ): + comment_page = await browser.newPage() + await comment_page.goto(f'https://reddit.com{comment_obj["comment_url"]}') + + # translate code + if settings.config["reddit"]["thread"]["post_lang"]: + comment_tl = ts.google( + comment_obj["comment_body"], + to_language=settings.config["reddit"]["thread"]["post_lang"], ) + await comment_page.evaluate( + '([tl_content, tl_id]) => document.querySelector(`#t1_${tl_id} > div:nth-child(2) > div > div[' + 'data-testid="comment"] > div`).textContent = tl_content', + [comment_tl, comment_obj["comment_id"]], + ) + + await self.screenshot( + comment_page, + f'//*[contains(@id, \'t1_{comment_obj["comment_id"]}\')]', + {'path': f'assets/temp/png/comment_{filename_idx}.png'}, + ) + + async_tasks_primary = [ + collect_comment(comment, idx) for idx, comment in + enumerate(self.reddit_object['comments']) + if idx < self.screenshot_num + ] + + for task in track( + as_completed(async_tasks_primary), + description='Downloading screenshots...', + total=async_tasks_primary.__len__(), + ): + await task - print_substep("Screenshots downloaded Successfully.", style="bold green") + print_substep('Screenshots downloaded Successfully.', style='bold green')