added typing

pull/963/head
Drugsosos 2 years ago
parent eff38a5852
commit e19e532ea5
No known key found for this signature in database
GPG Key ID: 8E35176FE617E28D

@ -15,7 +15,7 @@ from video_creation.background import (
get_background_config,
)
from video_creation.final_video import make_final_video
from video_creation.screenshot_downloader import Reddit
from video_creation.screenshot_downloader import RedditScreenshot
from video_creation.voices import save_text_to_mp3
VERSION = "2.2.9"
@ -43,9 +43,7 @@ async def main(
reddit_object = get_subreddit_threads(POST_ID)
length, number_of_comments = save_text_to_mp3(reddit_object)
length = math.ceil(length)
reddit_screenshots = Reddit(reddit_object, number_of_comments)
browser = await reddit_screenshots.get_browser()
await reddit_screenshots.download_screenshots(browser)
await RedditScreenshot(reddit_object, number_of_comments).download()
bg_config = get_background_config()
download_background(bg_config)
chop_background_video(bg_config, length)

@ -1,30 +1,33 @@
from asyncio import as_completed
from pathlib import Path
from typing import Dict
from utils import settings
from pyppeteer import launch
from pyppeteer.page import Page as PageCls
from pyppeteer.browser import Browser as BrowserCls
from pyppeteer.element_handle import ElementHandle as ElementHandleCls
from pyppeteer.errors import TimeoutError as BrowserTimeoutError
from pathlib import Path
from typing import Dict
from utils import settings
from rich.progress import track
import translators as ts
from utils.console import print_step, print_substep
from attr import attrs, attrib
from attr.validators import instance_of, optional
from typing import TypeVar, Optional, Callable, Union
_function = TypeVar('_function', bound=Callable[..., object])
_exceptions = TypeVar('_exceptions', bound=Optional[Union[type, tuple, list]])
@attrs
class ExceptionDecorator:
"""
Factory for decorating functions
"""
__exception: Optional[_exceptions] = attrib(default=None)
__default_exception: _exceptions = attrib(default=BrowserTimeoutError)
@ -59,6 +62,15 @@ def catch_exception(
func: Optional[_function],
exception: Optional[_exceptions] = None,
) -> ExceptionDecorator | _function:
"""
Decorator for catching exceptions and writing logs
Args:
func: Function to be decorated
exception: Expected exception(s)
Returns:
Decorated function
"""
exceptor = ExceptionDecorator(exception)
if func:
exceptor = exceptor(func)
@ -67,22 +79,41 @@ def catch_exception(
@attrs
class Browser:
# default_Viewport: dict = attrib(validator=instance_of(dict), default=dict())
#
# def __attrs_post_init__(self):
# if self.default_Viewport.__len__() == 0:
# self.default_Viewport['isLandscape'] = True
"""
Args:
default_Viewport (dict):Pyppeteer Browser default_Viewport options
browser (BrowserCls): Pyppeteer Browser instance
"""
default_Viewport: dict = attrib(
validator=instance_of(dict),
default=dict(),
kw_only=True,
)
browser: Optional[BrowserCls] = attrib(
validator=optional(instance_of(BrowserCls)),
default=None,
kw_only=True,
)
def __attrs_post_init__(self):
if self.default_Viewport.__len__() == 0:
self.default_Viewport['isLandscape'] = True
async def get_browser(
self,
) -> BrowserCls:
return await launch()
) -> None:
"""
Creates Pyppeteer browser
"""
self.browser = await launch(self.default_Viewport)
async def close_browser(
self,
browser: BrowserCls
) -> None:
await browser.close()
"""
Closes Pyppeteer browser
"""
await self.browser.close()
class Wait:
@ -93,6 +124,27 @@ class Wait:
xpath: Optional[str] = None,
options: Optional[dict] = None,
) -> 'ElementHandleCls':
"""
Explicitly finds element on the page
Args:
page_instance: Pyppeteer page instance
xpath: xpath query
options: Pyppeteer waitForXPath parameters
Available options are:
* ``visible`` (bool): wait for element to be present in DOM and to be
visible, i.e. to not have ``display: none`` or ``visibility: hidden``
CSS properties. Defaults to ``False``.
* ``hidden`` (bool): wait for element to not be found in the DOM or to
be hidden, i.e. have ``display: none`` or ``visibility: hidden`` CSS
properties. Defaults to ``False``.
* ``timeout`` (int|float): maximum time to wait for in milliseconds.
Defaults to 30000 (30 seconds). Pass ``0`` to disable timeout.
Returns:
Pyppeteer element instance
"""
if options:
el = await page_instance.waitForXPath(xpath, options=options)
else:
@ -108,6 +160,16 @@ class Wait:
options: Optional[dict] = None,
el: Optional[ElementHandleCls] = None,
) -> None:
"""
Clicks on the element
Args:
page_instance: Pyppeteer page instance
xpath: xpath query
find_options: Pyppeteer waitForXPath parameters
options: Pyppeteer click parameters
el: Pyppeteer element instance
"""
if not el:
el = await self.find_xpath(page_instance, xpath, find_options)
if options:
@ -124,6 +186,16 @@ class Wait:
find_options: Optional[dict] = None,
el: Optional[ElementHandleCls] = None,
) -> None:
"""
Makes a screenshot of the element
Args:
page_instance: Pyppeteer page instance
xpath: xpath query
options: Pyppeteer screenshot parameters
find_options: Pyppeteer waitForXPath parameters
el: Pyppeteer element instance
"""
if not el:
el = await self.find_xpath(page_instance, xpath, find_options)
if options:
@ -133,7 +205,7 @@ class Wait:
@attrs(auto_attribs=True)
class Reddit(Browser, Wait):
class RedditScreenshot(Browser, Wait):
"""
Args:
reddit_object (Dict): Reddit object received from reddit/subreddit.py
@ -147,12 +219,15 @@ class Reddit(Browser, Wait):
if value <= 0:
raise ValueError('Check screenshot_num in config')
async def dark_theme(
async def __dark_theme(
self,
page_instance: PageCls,
) -> None:
"""
Enables dark theme in Reddit
Args:
page_instance: Pyppeteer page instance with reddit page opened
"""
await self.click(
@ -181,14 +256,45 @@ class Reddit(Browser, Wait):
{'timeout': 5000},
)
async def download_screenshots(
async def __collect_comment(
self,
browser: BrowserCls
comment_obj: dict,
filename_idx: int,
) -> None:
"""
Makes a screenshot of the comment
Args:
comment_obj: prew comment object
filename_idx: index for the filename
"""
comment_page = await self.browser.newPage()
await comment_page.goto(f'https://reddit.com{comment_obj["comment_url"]}')
# Translates submission' comment
if settings.config["reddit"]["thread"]["post_lang"]:
comment_tl = ts.google(
comment_obj["comment_body"],
to_language=settings.config["reddit"]["thread"]["post_lang"],
)
await comment_page.evaluate(
f'([tl_content, tl_id]) => document.querySelector(`#t1_{comment_obj["comment_id"]} > div:nth-child(2) '
f'> div > div[data-testid="comment"] > div`).textContent = {comment_tl}',
)
await self.screenshot(
comment_page,
f'//*[contains(@id, \'t1_{comment_obj["comment_id"]}\')]',
{'path': f'assets/temp/png/comment_{filename_idx}.png'},
)
async def download(
self,
):
"""
Downloads screenshots of reddit posts as seen on the web. Downloads to assets/temp/png
"""
await self.get_browser()
print_step('Downloading screenshots of reddit posts...')
# ! Make sure the reddit screenshots folder exists
@ -197,11 +303,11 @@ class Reddit(Browser, Wait):
print_substep('Launching Headless Browser...')
# Get the thread screenshot
reddit_main = await browser.newPage()
reddit_main = await self.browser.newPage()
await reddit_main.goto(self.reddit_object['thread_url'])
if settings.config['settings']['theme'] == 'dark':
await self.dark_theme(reddit_main)
await self.__dark_theme(reddit_main)
if self.reddit_object['is_nsfw']:
# This means the post is NSFW and requires to click the proceed button.
@ -219,8 +325,7 @@ class Reddit(Browser, Wait):
{'timeout': 5000},
)
# translate code
# Translates submission title
if settings.config['reddit']['thread']['post_lang']:
print_substep('Translating post...')
texts_in_tl = ts.google(
@ -242,33 +347,8 @@ class Reddit(Browser, Wait):
{'path': f'assets/temp/png/title.png'},
)
async def collect_comment(
comment_obj: dict,
filename_idx: int,
):
comment_page = await browser.newPage()
await comment_page.goto(f'https://reddit.com{comment_obj["comment_url"]}')
# translate code
if settings.config["reddit"]["thread"]["post_lang"]:
comment_tl = ts.google(
comment_obj["comment_body"],
to_language=settings.config["reddit"]["thread"]["post_lang"],
)
await comment_page.evaluate(
'([tl_content, tl_id]) => document.querySelector(`#t1_${tl_id} > div:nth-child(2) > div > div['
'data-testid="comment"] > div`).textContent = tl_content',
[comment_tl, comment_obj["comment_id"]],
)
await self.screenshot(
comment_page,
f'//*[contains(@id, \'t1_{comment_obj["comment_id"]}\')]',
{'path': f'assets/temp/png/comment_{filename_idx}.png'},
)
async_tasks_primary = [
collect_comment(comment, idx) for idx, comment in
self.__collect_comment(comment, idx) for idx, comment in
enumerate(self.reddit_object['comments'])
if idx < self.screenshot_num
]
@ -281,3 +361,4 @@ class Reddit(Browser, Wait):
await task
print_substep('Screenshots downloaded Successfully.', style='bold green')
await self.close_browser()

Loading…
Cancel
Save