moved to async webdriver

pull/963/head
Drugsosos 2 years ago
parent c617af98ce
commit ed8cd3cd09
No known key found for this signature in database
GPG Key ID: 8E35176FE617E28D

@ -1,4 +1,5 @@
#!/usr/bin/env python
from asyncio import run
import math
from subprocess import Popen
from os import name
@ -14,7 +15,7 @@ from video_creation.background import (
get_background_config,
)
from video_creation.final_video import make_final_video
from video_creation.screenshot_downloader import download_screenshots_of_reddit_posts
from video_creation.screenshot_downloader import Reddit
from video_creation.voices import save_text_to_mp3
VERSION = "2.2.9"
@ -35,24 +36,28 @@ print_markdown(
print_step(f"You are using V{VERSION} of the bot")
def main(POST_ID=None):
async def main(
POST_ID=None
):
cleanup()
reddit_object = get_subreddit_threads(POST_ID)
length, number_of_comments = save_text_to_mp3(reddit_object)
length = math.ceil(length)
download_screenshots_of_reddit_posts(reddit_object, number_of_comments)
reddit_screenshots = Reddit(reddit_object, number_of_comments)
browser = await reddit_screenshots.get_browser()
await reddit_screenshots.download_screenshots(browser)
bg_config = get_background_config()
download_background(bg_config)
chop_background_video(bg_config, length)
make_final_video(number_of_comments, length, reddit_object, bg_config)
def run_many(times):
async def run_many(times):
for x in range(1, times + 1):
print_step(
f'on the {x}{("th", "st", "nd", "rd", "th", "th", "th", "th","th", "th")[x%10]} iteration of {times}'
) # correct 1st 2nd 3rd 4th 5th....
main()
await main()
Popen("cls" if name == "nt" else "clear", shell=True).wait()
@ -61,7 +66,9 @@ if __name__ == "__main__":
config is False and exit()
try:
if config["settings"]["times_to_run"]:
run(
run_many(config["settings"]["times_to_run"])
)
elif len(config["reddit"]["thread"]["post_id"].split("+")) > 1:
for index, post_id in enumerate(config["reddit"]["thread"]["post_id"].split("+")):
@ -69,7 +76,9 @@ if __name__ == "__main__":
print_step(
f'on the {index}{("st" if index%10 == 1 else ("nd" if index%10 == 2 else ("rd" if index%10 == 3 else "th")))} post of {len(config["reddit"]["thread"]["post_id"].split("+"))}'
)
run(
main(post_id)
)
Popen("cls" if name == "nt" else "clear", shell=True).wait()
else:
main()

@ -9,7 +9,9 @@ from utils.subreddit import get_subreddit_undone
from utils.videos import check_done
def get_subreddit_threads(POST_ID: str):
def get_subreddit_threads(
POST_ID: str
):
"""
Returns a list of threads from the AskReddit subreddit.
"""
@ -87,6 +89,7 @@ def get_subreddit_threads(POST_ID: str):
content["thread_title"] = submission.title
content["thread_post"] = submission.selftext
content["thread_id"] = submission.id
content["is_nsfw"] = 'nsfw' in submission.whitelist_status
content["comments"] = []
for top_level_comment in submission.comments:

@ -2,10 +2,11 @@ boto3==1.24.24
botocore==1.27.24
gTTS==2.2.4
moviepy==1.0.3
playwright==1.23.0
praw==7.6.0
pytube==12.1.0
requests==2.28.1
rich==12.4.4
toml==0.10.2
translators==5.3.1
pyppeteer==1.0.2
attrs==21.4.0

@ -1,107 +1,283 @@
import json
from asyncio import as_completed
from pathlib import Path
from typing import Dict
from utils import settings
from playwright.async_api import async_playwright # pylint: disable=unused-import
# do not remove the above line
from pyppeteer import launch
from pyppeteer.page import Page as PageCls
from pyppeteer.browser import Browser as BrowserCls
from pyppeteer.element_handle import ElementHandle as ElementHandleCls
from pyppeteer.errors import TimeoutError as BrowserTimeoutError
from playwright.sync_api import sync_playwright, ViewportSize
from rich.progress import track
import translators as ts
from utils.console import print_step, print_substep
storymode = False
from attr import attrs, attrib
from typing import TypeVar, Optional, Callable, Union
_function = TypeVar('_function', bound=Callable[..., object])
_exceptions = TypeVar('_exceptions', bound=Optional[Union[type, tuple, list]])
@attrs
class ExceptionDecorator:
__exception: Optional[_exceptions] = attrib(default=None)
__default_exception: _exceptions = attrib(default=BrowserTimeoutError)
def __attrs_post_init__(self):
if not self.__exception:
self.__exception = self.__default_exception
def __call__(
self,
func: _function,
):
async def wrapper(*args, **kwargs):
try:
obj_to_return = await func(*args, **kwargs)
return obj_to_return
except Exception as caughtException:
import logging
if isinstance(self.__exception, type):
if not type(caughtException) == self.__exception:
logging.basicConfig(filename='.webdriver.log', filemode='w', encoding='utf-8',
level=logging.DEBUG)
logging.error(f'unexpected error - {caughtException}')
else:
if not type(caughtException) in self.__exception:
logging.error(f'unexpected error - {caughtException}')
return wrapper
def catch_exception(
func: Optional[_function],
exception: Optional[_exceptions] = None,
) -> ExceptionDecorator | _function:
exceptor = ExceptionDecorator(exception)
if func:
exceptor = exceptor(func)
return exceptor
@attrs
class Browser:
# default_Viewport: dict = attrib(validator=instance_of(dict), default=dict())
#
# def __attrs_post_init__(self):
# if self.default_Viewport.__len__() == 0:
# self.default_Viewport['isLandscape'] = True
async def get_browser(
self,
) -> BrowserCls:
return await launch()
async def close_browser(
self,
browser: BrowserCls
) -> None:
await browser.close()
def download_screenshots_of_reddit_posts(reddit_object: dict, screenshot_num: int):
"""Downloads screenshots of reddit posts as seen on the web. Downloads to assets/temp/png
class Wait:
@staticmethod
@catch_exception
async def find_xpath(
page_instance: PageCls,
xpath: Optional[str] = None,
options: Optional[dict] = None,
) -> 'ElementHandleCls':
if options:
el = await page_instance.waitForXPath(xpath, options=options)
else:
el = await page_instance.waitForXPath(xpath)
return el
@catch_exception
async def click(
self,
page_instance: Optional[PageCls] = None,
xpath: Optional[str] = None,
find_options: Optional[dict] = None,
options: Optional[dict] = None,
el: Optional[ElementHandleCls] = None,
) -> None:
if not el:
el = await self.find_xpath(page_instance, xpath, find_options)
if options:
await el.click(options)
else:
await el.click()
@catch_exception
async def screenshot(
self,
page_instance: Optional[PageCls] = None,
xpath: Optional[str] = None,
options: Optional[dict] = None,
find_options: Optional[dict] = None,
el: Optional[ElementHandleCls] = None,
) -> None:
if not el:
el = await self.find_xpath(page_instance, xpath, find_options)
if options:
await el.screenshot(options)
else:
await el.screenshot()
@attrs(auto_attribs=True)
class Reddit(Browser, Wait):
"""
Args:
reddit_object (Dict): Reddit object received from reddit/subreddit.py
screenshot_num (int): Number of screenshots to download
"""
print_step("Downloading screenshots of reddit posts...")
reddit_object: dict
screenshot_num: int = attrib()
# ! Make sure the reddit screenshots folder exists
Path("assets/temp/png").mkdir(parents=True, exist_ok=True)
@screenshot_num.validator
def validate_screenshot_num(self, attribute, value):
if value <= 0:
raise ValueError('Check screenshot_num in config')
with sync_playwright() as p:
print_substep("Launching Headless Browser...")
async def dark_theme(
self,
page_instance: PageCls,
) -> None:
"""
Enables dark theme in Reddit
"""
browser = p.chromium.launch()
context = browser.new_context()
await self.click(
page_instance,
'//*[contains(@class, \'header-user-dropdown\')]',
{'timeout': 5000},
)
# It's normal not to find it, sometimes there is none :shrug:
await self.click(
page_instance,
'//*[contains(text(), \'Settings\')]/ancestor::button[1]',
{'timeout': 5000},
)
await self.click(
page_instance,
'//*[contains(text(), \'Dark Mode\')]/ancestor::button[1]',
{'timeout': 5000},
)
# Closes settings
await self.click(
page_instance,
'//*[contains(@class, \'header-user-dropdown\')]',
{'timeout': 5000},
)
async def download_screenshots(
self,
browser: BrowserCls
):
"""
Downloads screenshots of reddit posts as seen on the web. Downloads to assets/temp/png
"""
print_step('Downloading screenshots of reddit posts...')
# ! Make sure the reddit screenshots folder exists
Path('assets/temp/png').mkdir(parents=True, exist_ok=True)
print_substep('Launching Headless Browser...')
if settings.config["settings"]["theme"] == "dark":
cookie_file = open("./video_creation/data/cookie-dark-mode.json", encoding="utf-8")
else:
cookie_file = open("./video_creation/data/cookie-light-mode.json", encoding="utf-8")
cookies = json.load(cookie_file)
context.add_cookies(cookies) # load preference cookies
# Get the thread screenshot
page = context.new_page()
page.goto(reddit_object["thread_url"], timeout=0)
page.set_viewport_size(ViewportSize(width=1920, height=1080))
if page.locator('[data-testid="content-gate"]').is_visible():
reddit_main = await browser.newPage()
await reddit_main.goto(self.reddit_object['thread_url'])
if settings.config['settings']['theme'] == 'dark':
await self.dark_theme(reddit_main)
if self.reddit_object['is_nsfw']:
# This means the post is NSFW and requires to click the proceed button.
print_substep("Post is NSFW. You are spicy...")
page.locator('[data-testid="content-gate"] button').click()
page.locator(
'[data-click-id="text"] button'
).click() # Remove "Click to see nsfw" Button in Screenshot
print_substep('Post is NSFW. You are spicy...')
await self.click(
reddit_main,
'//button[contains(text(), \'Yes\')]',
{'timeout': 5000},
)
await self.click(
reddit_main,
'//button[contains(text(), \'nsfw\')]',
{'timeout': 5000},
)
# translate code
if settings.config["reddit"]["thread"]["post_lang"]:
print_substep("Translating post...")
if settings.config['reddit']['thread']['post_lang']:
print_substep('Translating post...')
texts_in_tl = ts.google(
reddit_object["thread_title"],
to_language=settings.config["reddit"]["thread"]["post_lang"],
self.reddit_object['thread_title'],
to_language=settings.config['reddit']['thread']['post_lang'],
)
page.evaluate(
"tl_content => document.querySelector('[data-test-id=\"post-content\"] > div:nth-child(3) > div > div').textContent = tl_content",
await reddit_main.evaluate(
"tl_content => document.querySelector('[data-test-id=\"post-content\"] > div:nth-child(3) > div > "
"div').textContent = tl_content",
texts_in_tl,
)
else:
print_substep("Skipping translation...")
page.locator('[data-test-id="post-content"]').screenshot(path="assets/temp/png/title.png")
if storymode:
page.locator('[data-click-id="text"]').screenshot(
path="assets/temp/png/story_content.png"
await self.screenshot(
reddit_main,
f'//*[contains(@id, \'t3_{self.reddit_object["thread_id"]}\')]',
{'path': f'assets/temp/png/title.png'},
)
else:
for idx, comment in enumerate(
track(reddit_object["comments"], "Downloading screenshots...")
):
# Stop if we have reached the screenshot_num
if idx >= screenshot_num:
break
if page.locator('[data-testid="content-gate"]').is_visible():
page.locator('[data-testid="content-gate"] button').click()
page.goto(f'https://reddit.com{comment["comment_url"]}', timeout=0)
async def collect_comment(
comment_obj: dict,
filename_idx: int,
):
comment_page = await browser.newPage()
await comment_page.goto(f'https://reddit.com{comment_obj["comment_url"]}')
# translate code
if settings.config["reddit"]["thread"]["post_lang"]:
comment_tl = ts.google(
comment["comment_body"],
comment_obj["comment_body"],
to_language=settings.config["reddit"]["thread"]["post_lang"],
)
page.evaluate(
'([tl_content, tl_id]) => document.querySelector(`#t1_${tl_id} > div:nth-child(2) > div > div[data-testid="comment"] > div`).textContent = tl_content',
[comment_tl, comment["comment_id"]],
await comment_page.evaluate(
'([tl_content, tl_id]) => document.querySelector(`#t1_${tl_id} > div:nth-child(2) > div > div['
'data-testid="comment"] > div`).textContent = tl_content',
[comment_tl, comment_obj["comment_id"]],
)
page.locator(f"#t1_{comment['comment_id']}").screenshot(
path=f"assets/temp/png/comment_{idx}.png"
await self.screenshot(
comment_page,
f'//*[contains(@id, \'t1_{comment_obj["comment_id"]}\')]',
{'path': f'assets/temp/png/comment_{filename_idx}.png'},
)
print_substep("Screenshots downloaded Successfully.", style="bold green")
async_tasks_primary = [
collect_comment(comment, idx) for idx, comment in
enumerate(self.reddit_object['comments'])
if idx < self.screenshot_num
]
for task in track(
as_completed(async_tasks_primary),
description='Downloading screenshots...',
total=async_tasks_primary.__len__(),
):
await task
print_substep('Screenshots downloaded Successfully.', style='bold green')

Loading…
Cancel
Save