diff --git a/TTS/GTTS.py b/TTS/GTTS.py index 31e29df..51dd92f 100644 --- a/TTS/GTTS.py +++ b/TTS/GTTS.py @@ -1,23 +1,27 @@ #!/usr/bin/env python3 -import random from utils import settings from gtts import gTTS -max_chars = 0 - class GTTS: - def __init__(self): - self.max_chars = 0 - self.voices = [] + max_chars = 0 + # voices = [] + + @staticmethod + def run( + text, + filepath + ) -> None: + """ + Calls for TTS api - def run(self, text, filepath): + Args: + text: text to be voiced over + filepath: name of the audio file + """ tts = gTTS( text=text, lang=settings.config["reddit"]["thread"]["post_lang"] or "en", slow=False, ) tts.save(filepath) - - def randomvoice(self): - return random.choice(self.voices) diff --git a/TTS/TikTok.py b/TTS/TikTok.py index 743118c..c321b89 100644 --- a/TTS/TikTok.py +++ b/TTS/TikTok.py @@ -1,15 +1,17 @@ -import base64 from utils import settings -import random import requests from requests.adapters import HTTPAdapter, Retry -# from profanity_filter import ProfanityFilter -# pf = ProfanityFilter() -# Code by @JasonLovesDoggo -# https://twitter.com/scanlime/status/1512598559769702406 +from attr import attrs, attrib +from attr.validators import instance_of -nonhuman = [ # DISNEY VOICES +from TTS.common import BaseApiTTS, get_random_voice + +# TTS examples: https://twitter.com/scanlime/status/1512598559769702406 + +voices = dict() + +voices["nonhuman"] = [ # DISNEY VOICES "en_us_ghostface", # Ghost Face "en_us_chewbacca", # Chewbacca "en_us_c3po", # C3PO @@ -18,7 +20,7 @@ nonhuman = [ # DISNEY VOICES "en_us_rocket", # Rocket # ENGLISH VOICES ] -human = [ +voices["human"] = [ "en_au_001", # English AU - Female "en_au_002", # English AU - Male "en_uk_001", # English UK - Male 1 @@ -30,9 +32,8 @@ human = [ "en_us_009", # English US - Male 3 "en_us_010", ] -voices = nonhuman + human -noneng = [ +voices["non_eng"] = [ "fr_001", # French - Male 1 "fr_002", # French - Male 2 "de_001", # German - Female @@ -56,32 +57,51 @@ noneng = [ ] -# 
good_voices = {'good': ['en_us_002', 'en_us_006'], -# 'ok': ['en_au_002', 'en_uk_001']} # less en_us_stormtrooper more less en_us_rocket en_us_ghostface +# good_voices: 'en_us_002', 'en_us_006' +# ok: 'en_au_002', 'en_uk_001' +# less: en_us_stormtrooper +# more or less: en_us_rocket, en_us_ghostface -class TikTok: # TikTok Text-to-Speech Wrapper - def __init__(self): - self.URI_BASE = ( - "https://api16-normal-useast5.us.tiktokv.com/media/api/text/speech/invoke/?text_speaker=" - ) - self.max_chars = 300 - self.voices = {"human": human, "nonhuman": nonhuman, "noneng": noneng} +@attrs +class TikTok(BaseApiTTS): # TikTok Text-to-Speech Wrapper + random_voice: bool = attrib( + validator=instance_of(bool), + default=False + ) + uri_base: str = "https://api16-normal-useast5.us.tiktokv.com/media/api/text/speech/invoke/" + max_chars: int = 300 + decode_base64: bool = True - def run(self, text, filepath, random_voice: bool = False): - # if censor: - # req_text = pf.censor(req_text) - # pass + def make_request( + self, + text: str, + ): + """ + Makes a requests to remote TTS service + + Args: + text: text to be voice over + + Returns: + Request's response + """ voice = ( - self.randomvoice() - if random_voice - else ( - settings.config["settings"]["tts"]["tiktok_voice"] - or random.choice(self.voices["human"]) - ) + get_random_voice(voices, "human") + if self.random_voice + else str(settings.config["settings"]["tts"]["tiktok_voice"]).lower() + if str(settings.config["settings"]["tts"]["tiktok_voice"]).lower() in [ + voice.lower() for dict_title in voices for voice in voices[dict_title]] + else get_random_voice(voices, "human") ) try: - r = requests.post(f"{self.URI_BASE}{voice}&req_text={text}&speaker_map_type=0") + r = requests.post( + self.uri_base, + params={ + "text_speaker": voice, + "req_text": text, + "speaker_map_type": 0, + }) except requests.exceptions.SSLError: # https://stackoverflow.com/a/47475019/18516611 session = requests.Session() @@ -89,13 +109,6 @@ class 
TikTok: # TikTok Text-to-Speech Wrapper adapter = HTTPAdapter(max_retries=retry) session.mount("http://", adapter) session.mount("https://", adapter) - r = session.post(f"{self.URI_BASE}{voice}&req_text={text}&speaker_map_type=0") + r = session.post(f"{self.uri_base}{voice}&req_text={text}&speaker_map_type=0") # print(r.text) - vstr = [r.json()["data"]["v_str"]][0] - b64d = base64.b64decode(vstr) - - with open(filepath, "wb") as out: - out.write(b64d) - - def randomvoice(self): - return random.choice(self.voices["human"]) + return r.json()["data"]["v_str"] diff --git a/TTS/aws_polly.py b/TTS/aws_polly.py index efd762b..1a9c87b 100644 --- a/TTS/aws_polly.py +++ b/TTS/aws_polly.py @@ -1,9 +1,14 @@ #!/usr/bin/env python3 from boto3 import Session from botocore.exceptions import BotoCoreError, ClientError, ProfileNotFound + import sys from utils import settings -import random +from attr import attrs, attrib +from attr.validators import instance_of + +from TTS.common import get_random_voice + voices = [ "Brian", @@ -24,23 +29,37 @@ voices = [ ] +@attrs class AWSPolly: - def __init__(self): - self.max_chars = 0 - self.voices = voices + random_voice: bool = attrib( + validator=instance_of(bool), + default=False + ) + max_chars: int = 0 + + def run( + self, + text, + filepath, + ) -> None: + """ + Calls for TTS api - def run(self, text, filepath, random_voice: bool = False): + Args: + text: text to be voiced over + filepath: name of the audio file + """ try: session = Session(profile_name="polly") polly = session.client("polly") - if random_voice: - voice = self.randomvoice() - else: - if not settings.config["settings"]["tts"]["aws_polly_voice"]: - raise ValueError( - f"Please set the TOML variable AWS_VOICE to a valid voice. 
options are: {voices}" - ) - voice = str(settings.config["settings"]["tts"]["aws_polly_voice"]).capitalize() + voice = ( + get_random_voice(voices) + if self.random_voice + else str(settings.config["settings"]["tts"]["aws_polly_voice"]).capitalize() + if str(settings.config["settings"]["tts"]["aws_polly_voice"]).lower() in [voice.lower() for voice in + voices] + else get_random_voice(voices) + ) try: # Request speech synthesis response = polly.synthesize_speech( @@ -71,6 +90,3 @@ class AWSPolly: """ ) sys.exit(-1) - - def randomvoice(self): - return random.choice(self.voices) diff --git a/TTS/common.py b/TTS/common.py new file mode 100644 index 0000000..f355a89 --- /dev/null +++ b/TTS/common.py @@ -0,0 +1,141 @@ +import base64 +from random import choice +from typing import Union, Optional + + +class BaseApiTTS: + max_chars: int + decode_base64: bool = False + + @staticmethod + def text_len_sanitize( + text: str, + max_length: int, + ) -> list: + """ + Splits text if it's too long to be a query + + Args: + text: text to be sanitized + max_length: maximum length of the query + + Returns: + Split text as a list + """ + # Split by comma or dot (else you can lose intonations), if there is non, split by groups of 299 chars + split_text = list( + map(lambda x: x.strip() if x.strip()[-1] != "." else x.strip()[:-1], + filter(lambda x: True if x else False, text.split("."))) + ) + if split_text and all([chunk.__len__() < max_length for chunk in split_text]): + return split_text + + split_text = list( + map(lambda x: x.strip() if x.strip()[-1] != "," else x.strip()[:-1], + filter(lambda x: True if x else False, text.split(",")) + ) + ) + if split_text and all([chunk.__len__() < max_length for chunk in split_text]): + return split_text + + return list( + map( + lambda x: x.strip() if x.strip()[-1] != "." 
or x.strip()[-1] != "," else x.strip()[:-1], + filter( + lambda x: True if x else False, + [text[i:i + max_length] for i in range(0, len(text), max_length)] + ) + ) + ) + + def write_file( + self, + output_text: str, + filepath: str, + ) -> None: + """ + Writes and decodes TTS responses in files + + Args: + output_text: text to be written + filepath: path/name of the file + """ + decoded_text = base64.b64decode(output_text) if self.decode_base64 else output_text + + with open(filepath, "wb") as out: + out.write(decoded_text) + + def run( + self, + text: str, + filepath: str, + ) -> None: + """ + Calls for TTS api and writes audio file + + Args: + text: text to be voice over + filepath: path/name of the file + + Returns: + + """ + output_text = "" + if len(text) > self.max_chars: + for part in self.text_len_sanitize(text, self.max_chars): + if part: + output_text += self.make_request(part) + else: + output_text = self.make_request(text) + self.write_file(output_text, filepath) + + +def get_random_voice( + voices: Union[list, dict], + key: Optional[str] = None, +) -> str: + """ + Return random voice from list or dict + + Args: + voices: list or dict of voices + key: key of a dict if you are using one + + Returns: + random voice as a str + """ + if isinstance(voices, list): + return choice(voices) + else: + return choice(voices[key] if key else list(voices.values())[0]) + + +def audio_length( + path: str, +) -> Union[float, int]: + """ + Gets the length of the audio file + + Args: + path: audio file path + + Returns: + length in seconds as an int + """ + from moviepy.editor import AudioFileClip + + try: + # please use something else here in the future + audio_clip = AudioFileClip(path) + audio_duration = audio_clip.duration + audio_clip.close() + return audio_duration + except Exception as e: + import logging + + logger = logging.getLogger("tts_logger") + logger.setLevel(logging.ERROR) + handler = logging.FileHandler(".tts.log", mode="a+", encoding="utf-8") + 
logger.addHandler(handler) + logger.error("Error occurred in audio_length:", e) + return 0 diff --git a/TTS/engine_wrapper.py b/TTS/engine_wrapper.py index 267f47a..af45d38 100644 --- a/TTS/engine_wrapper.py +++ b/TTS/engine_wrapper.py @@ -1,145 +1,140 @@ #!/usr/bin/env python3 from pathlib import Path -from typing import Tuple -import re +from typing import Union -# import sox -# from mutagen import MutagenError -# from mutagen.mp3 import MP3, HeaderNotFoundError import translators as ts from rich.progress import track -from moviepy.editor import AudioFileClip, CompositeAudioClip, concatenate_audioclips +from attr import attrs, attrib + from utils.console import print_step, print_substep from utils.voice import sanitize_text from utils import settings +from TTS.common import audio_length -DEFAULT_MAX_LENGTH: int = 50 # video length variable +from TTS.GTTS import GTTS +from TTS.streamlabs_polly import StreamlabsPolly +from TTS.TikTok import TikTok +from TTS.aws_polly import AWSPolly +@attrs class TTSEngine: - """Calls the given TTS engine to reduce code duplication and allow multiple TTS engines. Args: tts_module : The TTS module. Your module should handle the TTS itself and saving to the given path under the run method. reddit_object : The reddit object that contains the posts to read. - path (Optional) : The unix style path to save the mp3 files to. This must not have leading or trailing slashes. - max_length (Optional) : The maximum length of the mp3 files in total. Notes: tts_module must take the arguments text and filepath. 
""" + tts_module: Union[GTTS, StreamlabsPolly, TikTok, AWSPolly] = attrib() + reddit_object: dict = attrib() + __path: str = "assets/temp/mp3" + __total_length: int = 0 + + def __attrs_post_init__(self): + # Calls an instance of the tts_module class + self.tts_module = self.tts_module() + # Loading settings from the config + self.max_length: int = settings.config["settings"]["video_length"] + self.time_before_tts: float = settings.config["settings"]["time_before_tts"] + self.time_between_pictures: float = settings.config["settings"]["time_between_pictures"] + self.__total_length = ( + settings.config["settings"]["time_before_first_picture"] + + settings.config["settings"]["delay_before_end"] + ) + + def run( + self + ) -> list: + """ + Voices over comments & title of the submission + + Returns: + Indexes of comments to be used in the final video + """ + Path(self.__path).mkdir(parents=True, exist_ok=True) - def __init__( - self, - tts_module, - reddit_object: dict, - path: str = "assets/temp/mp3", - max_length: int = DEFAULT_MAX_LENGTH, - last_clip_length: int = 0, - ): - self.tts_module = tts_module() - self.reddit_object = reddit_object - self.path = path - self.max_length = max_length - self.length = 0 - self.last_clip_length = last_clip_length - - def run(self) -> Tuple[int, int]: - - Path(self.path).mkdir(parents=True, exist_ok=True) - - # This file needs to be removed in case this post does not use post text, so that it won't appear in the final video + # This file needs to be removed in case this post does not use post text + # so that it won't appear in the final video try: - Path(f"{self.path}/posttext.mp3").unlink() + Path(f"{self.__path}/posttext.mp3").unlink() except OSError: pass print_step("Saving Text to MP3 files...") self.call_tts("title", self.reddit_object["thread_title"]) - if ( - self.reddit_object["thread_post"] != "" - and settings.config["settings"]["storymode"] == True - ): + + if self.reddit_object["thread_post"] and 
settings.config["settings"]["storymode"]: self.call_tts("posttext", self.reddit_object["thread_post"]) - idx = None - for idx, comment in track( - enumerate(self.reddit_object["comments"]), "Saving..." - ): - # ! Stop creating mp3 files if the length is greater than max length. - if self.length > self.max_length: - self.length -= self.last_clip_length - idx -= 1 - break - if ( - len(comment["comment_body"]) > self.tts_module.max_chars - ): # Split the comment if it is too long - self.split_post(comment["comment_body"], idx) # Split the comment - else: # If the comment is not too long, just call the tts engine - self.call_tts(f"{idx}", comment["comment_body"]) + sync_tasks_primary = [ + self.call_tts(str(idx), comment["comment_body"]) + for idx, comment in track( + enumerate(self.reddit_object["comments"]), + description="Saving...", + total=self.reddit_object["comments"].__len__()) + # Crunch, there will be fix in async TTS api, maybe + if self.__total_length + self.__total_length * 0.05 < self.max_length + ] print_substep("Saved Text to MP3 files successfully.", style="bold green") - return self.length, idx - - def split_post(self, text: str, idx: int): - split_files = [] - split_text = [ - x.group().strip() - for x in re.finditer( - r" *(((.|\n){0," + str(self.tts_module.max_chars) + "})(\.|.$))", text - ) + return [ + comments for comments, condition in + zip(range(self.reddit_object["comments"].__len__()), sync_tasks_primary) + if condition ] - offset = 0 - for idy, text_cut in enumerate(split_text): - # print(f"{idx}-{idy}: {text_cut}\n") - if not text_cut or text_cut.isspace(): - offset += 1 - continue - - self.call_tts(f"{idx}-{idy - offset}.part", text_cut) - split_files.append( - AudioFileClip(f"{self.path}/{idx}-{idy - offset}.part.mp3") - ) - - CompositeAudioClip([concatenate_audioclips(split_files)]).write_audiofile( - f"{self.path}/{idx}.mp3", fps=44100, verbose=False, logger=None - ) - for i in split_files: - name = i.filename - i.close() - 
Path(name).unlink() + def call_tts( + self, + filename: str, + text: str + ) -> bool: + """ + Calls for TTS api from the factory - # for i in range(0, idy + 1): - # print(f"Cleaning up {self.path}/{idx}-{i}.part.mp3") + Args: + filename: name of audio file w/o .mp3 + text: text to be voiced over - # Path(f"{self.path}/{idx}-{i}.part.mp3").unlink() + Returns: + True if audio files not exceeding the maximum length else false + """ + if not text: + return False - def call_tts(self, filename: str, text: str): self.tts_module.run( - text=process_text(text), filepath=f"{self.path}/{filename}.mp3" + text=self.process_text(text), + filepath=f"{self.__path}/{filename}.mp3" ) - # try: - # self.length += MP3(f"{self.path}/{filename}.mp3").info.length - # except (MutagenError, HeaderNotFoundError): - # self.length += sox.file_info.duration(f"{self.path}/{filename}.mp3") - try: - clip = AudioFileClip(f"{self.path}/{filename}.mp3") - if clip.duration + self.length < self.max_length: - self.last_clip_length = clip.duration - self.length += clip.duration - clip.close() - except: - self.length = 0 - - -def process_text(text: str): - lang = settings.config["reddit"]["thread"]["post_lang"] - new_text = sanitize_text(text) - if lang: - print_substep("Translating Text...") - translated_text = ts.google(text, to_language=lang) - new_text = sanitize_text(translated_text) - return new_text + + clip_length = audio_length(f"{self.__path}/{filename}.mp3") + clip_offset = self.time_between_pictures + self.time_before_tts * 2 + + if clip_length and self.__total_length + clip_length + clip_offset <= self.max_length: + self.__total_length += clip_length + clip_offset + return True + return False + + @staticmethod + def process_text( + text: str, + ) -> str: + """ + Sanitizes text for illegal characters and translates text + + Args: + text: text to be sanitized & translated + + Returns: + Processed text as a str + """ + lang = settings.config["reddit"]["thread"]["post_lang"] + new_text = 
sanitize_text(text) + if lang: + print_substep("Translating Text...") + translated_text = ts.google(text, to_language=lang) + new_text = sanitize_text(translated_text) + return new_text diff --git a/TTS/streamlabs_polly.py b/TTS/streamlabs_polly.py index 75c4f49..3cc8cb9 100644 --- a/TTS/streamlabs_polly.py +++ b/TTS/streamlabs_polly.py @@ -1,7 +1,10 @@ -import random import requests from requests.exceptions import JSONDecodeError from utils import settings +from attr import attrs, attrib +from attr.validators import instance_of + +from TTS.common import BaseApiTTS, get_random_voice from utils.voice import check_ratelimit voices = [ @@ -26,37 +29,52 @@ voices = [ # valid voices https://lazypy.ro/tts/ -class StreamlabsPolly: - def __init__(self): - self.url = "https://streamlabs.com/polly/speak" - self.max_chars = 550 - self.voices = voices +@attrs +class StreamlabsPolly(BaseApiTTS): + random_voice: bool = attrib( + validator=instance_of(bool), + default=False + ) + url: str = "https://streamlabs.com/polly/speak" + max_chars: int = 550 - def run(self, text, filepath, random_voice: bool = False): - if random_voice: - voice = self.randomvoice() - else: - if not settings.config["settings"]["tts"]["streamlabs_polly_voice"]: - raise ValueError( - f"Please set the config variable STREAMLABS_POLLY_VOICE to a valid voice. 
options are: {voices}" - ) - voice = str(settings.config["settings"]["tts"]["streamlabs_polly_voice"]).capitalize() - body = {"voice": voice, "text": text, "service": "polly"} - response = requests.post(self.url, data=body) - if not check_ratelimit(response): - self.run(text, filepath, random_voice) + def make_request( + self, + text, + ): + """ + Makes a requests to remote TTS service + Args: + text: text to be voice over + + Returns: + Request's response + """ + voice = ( + get_random_voice(voices) + if self.random_voice + else str(settings.config["settings"]["tts"]["streamlabs_polly_voice"]).capitalize() + if str(settings.config["settings"]["tts"]["streamlabs_polly_voice"]).lower() in [ + voice.lower() for voice in voices] + else get_random_voice(voices) + ) + response = requests.post( + self.url, + data={ + "voice": voice, + "text": text, + "service": "polly", + }) + if not check_ratelimit(response): + return self.make_request(text) else: try: - voice_data = requests.get(response.json()["speak_url"]) - with open(filepath, "wb") as f: - f.write(voice_data.content) + results = requests.get(response.json()["speak_url"]) + return results.content except (KeyError, JSONDecodeError): try: if response.json()["error"] == "No text specified!": raise ValueError("Please specify a text to convert to speech.") except (KeyError, JSONDecodeError): print("Error occurred calling Streamlabs Polly") - - def randomvoice(self): - return random.choice(self.voices) diff --git a/install.sh b/install.sh index f85c80b..8ea94ae 100644 --- a/install.sh +++ b/install.sh @@ -12,7 +12,7 @@ function Help(){ echo "Options:" echo " -h: Show this help message and exit" echo " -d: Install only dependencies" - echo " -p: Install only python dependencies (including playwright)" + echo " -p: Install only python dependencies (including playwright)" echo " -b: Install just the bot" echo " -l: Install the bot and the python dependencies" } @@ -112,20 +112,20 @@ function install_python_dep(){ # install 
playwright function function install_playwright(){ - # tell the user that the script is going to install playwright + # tell the user that the script is going to install playwright echo "Installing playwright" # cd into the directory where the script is downloaded cd RedditVideoMakerBot-master # run the install script - python3 -m playwright install - python3 -m playwright install-deps + python3 -m playwright install + python3 -m playwright install-deps # give a note printf "Note, if these gave any errors, playwright may not be officially supported on your OS, check this issues page for support\nhttps://github.com/microsoft/playwright/issues" if [ -x "$(command -v pacman)" ]; then printf "It seems you are on and Arch based distro.\nTry installing these from the AUR for playwright to run:\nenchant1.6\nicu66\nlibwebp052\n" fi cd .. -} +} # Install depndencies function install_deps(){ diff --git a/main.py b/main.py index 6725540..ac300c6 100755 --- a/main.py +++ b/main.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -import math +from asyncio import run from subprocess import Popen from os import name @@ -11,12 +11,10 @@ from utils.console import print_markdown, print_step from utils import settings from video_creation.background import ( - download_background, - chop_background_video, get_background_config, ) -from video_creation.final_video import make_final_video -from video_creation.screenshot_downloader import download_screenshots_of_reddit_posts +from video_creation.final_video import FinalVideo +from webdriver.web_engine import screenshot_factory from video_creation.voices import save_text_to_mp3 __VERSION__ = "2.3.1" @@ -39,24 +37,22 @@ print_markdown( print_step(f"You are using v{__VERSION__} of the bot") -def main(POST_ID=None): +async def main(POST_ID=None): cleanup() reddit_object = get_subreddit_threads(POST_ID) - length, number_of_comments = save_text_to_mp3(reddit_object) - length = math.ceil(length) - download_screenshots_of_reddit_posts(reddit_object, 
number_of_comments) + comments_created = save_text_to_mp3(reddit_object) + webdriver = screenshot_factory(config["settings"]["webdriver"]) + await webdriver(reddit_object, comments_created).download() bg_config = get_background_config() - download_background(bg_config) - chop_background_video(bg_config, length) - make_final_video(number_of_comments, length, reddit_object, bg_config) + FinalVideo().make(comments_created, reddit_object, bg_config) -def run_many(times): +async def run_many(times): for x in range(1, times + 1): print_step( f'on the {x}{("th", "st", "nd", "rd", "th", "th", "th", "th", "th", "th")[x % 10]} iteration of {times}' ) # correct 1st 2nd 3rd 4th 5th.... - main() + await main() Popen("cls" if name == "nt" else "clear", shell=True).wait() @@ -72,7 +68,9 @@ if __name__ == "__main__": config is False and exit() try: if config["settings"]["times_to_run"]: - run_many(config["settings"]["times_to_run"]) + run( + run_many(config["settings"]["times_to_run"]) + ) elif len(config["reddit"]["thread"]["post_id"].split("+")) > 1: for index, post_id in enumerate(config["reddit"]["thread"]["post_id"].split("+")): @@ -80,11 +78,13 @@ if __name__ == "__main__": print_step( f'on the {index}{("st" if index % 10 == 1 else ("nd" if index % 10 == 2 else ("rd" if index % 10 == 3 else "th")))} post of {len(config["reddit"]["thread"]["post_id"].split("+"))}' ) - main(post_id) + run( + main(post_id) + ) Popen("cls" if name == "nt" else "clear", shell=True).wait() else: main() - except KeyboardInterrupt: + except KeyboardInterrupt: # TODO won't work with async code shutdown() except ResponseException: # error for invalid credentials diff --git a/reddit/subreddit.py b/reddit/subreddit.py index 716a7fa..486447f 100644 --- a/reddit/subreddit.py +++ b/reddit/subreddit.py @@ -87,6 +87,7 @@ def get_subreddit_threads(POST_ID: str): content["thread_title"] = submission.title content["thread_post"] = submission.selftext content["thread_id"] = submission.id + content["is_nsfw"] = 
"nsfw" in submission.whitelist_status content["comments"] = [] for top_level_comment in submission.comments: diff --git a/requirements.txt b/requirements.txt index 7bccd0d..9684dc7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,5 +9,6 @@ requests==2.28.1 rich==12.5.1 toml==0.10.2 translators==5.3.1 - -Pillow~=9.1.1 +pyppeteer==1.0.2 +attrs==21.4.0 +Pillow~=9.1.1 \ No newline at end of file diff --git a/utils/.config.template.toml b/utils/.config.template.toml index 102a813..3e0a098 100644 --- a/utils/.config.template.toml +++ b/utils/.config.template.toml @@ -16,7 +16,8 @@ subreddit = { optional = false, regex = "[_0-9a-zA-Z]+$", nmin = 3, explanation post_id = { optional = true, default = "", regex = "^((?!://|://)[+a-zA-Z])*$", explanation = "Used if you want to use a specific post.", example = "urdtfx" } max_comment_length = { default = 500, optional = false, nmin = 10, nmax = 10000, type = "int", explanation = "max number of characters a comment can have. default is 500", example = 500, oob_error = "the max comment length should be between 10 and 10000" } post_lang = { default = "", optional = true, explanation = "The language you would like to translate to.", example = "es-cr" } -min_comments = { default = 20, optional = false, nmin = 15, type = "int", explanation = "The minimum number of comments a post should have to be included. default is 20", example = 29, oob_error = "the minimum number of comments should be between 15 and 999999" } +min_comments = { default = 20, optional = false, nmin = 15, type = "int", explanation = "The minimum number of comments a post should have to be included. 
default is 20", example = 29, oob_error = "the minimum number of comments must be at least 15" } + [settings] allow_nsfw = { optional = false, type = "bool", default = false, example = false, options = [true, false, @@ -30,6 +31,14 @@ transition = { optional = true, default = 0.2, example = 0.2, explanation = "Set storymode = { optional = true, type = "bool", default = false, example = false, options = [true, false, ], explanation = "not yet implemented" } +video_length = { optional = false, default = 50, example = 60, explanation = "Approximated final video length", type = "int", nmin = 15, oob_error = "15 seconds is short enough" } +time_before_first_picture = { optional = false, default = 0.5, example = 1.0, explanation = "Delay before first screenshot appears", type = "float", nmin = 0, oob_error = "Choose at least 0 second" } +time_before_tts = { optional = false, default = 0.5, example = 1.0, explanation = "Delay between screenshot and TTS", type = "float", nmin = 0, oob_error = "Choose at least 0 second" } +time_between_pictures = { optional = false, default = 0.5, example = 1.0, explanation = "Time between every screenshot", type = "float", nmin = 0, oob_error = "Choose at least 0 second" } +delay_before_end = { optional = false, default = 0.5, example = 1.0, explanation = "Delay before video ends", type = "float", nmin = 0, oob_error = "Choose at least 0 second" } +video_width = { optional = true, default = 1080, example = 1080, explanation = "Final video width", type = "int", nmin = 600, oob_error = "Choose at least 600 pixels wide" } +video_height = { optional = true, default = 1920, example = 1920, explanation = "Final video height", type = "int", nmin = 600, oob_error = "Choose at least 600 pixels long" } +webdriver = { optional = true, default = "pyppeteer", example = "pyppeteer", options = ["pyppeteer", "playwright"], explanation = "Driver used to take screenshots (use pyppeteer if you have some problems with playwright)"} [settings.background] 
background_choice = { optional = true, default = "minecraft", example = "minecraft", options = ["minecraft", "gta", "rocket-league", "motor-gta", "csgo-surf", "cluster-truck", ""], explanation = "Sets the background for the video" } diff --git a/utils/settings.py b/utils/settings.py index a9d7726..5f764e2 100755 --- a/utils/settings.py +++ b/utils/settings.py @@ -3,13 +3,13 @@ import toml from rich.console import Console import re -from typing import Tuple, Dict +from typing import Dict, Optional, Union from utils.console import handle_input console = Console() -config = dict # autocomplete +config: Optional[dict] = None # autocomplete def crawl(obj: dict, func=lambda x, y: print(x, y, end="\n"), path=None): @@ -108,7 +108,7 @@ def check_vars(path, checks): crawl_and_check(config, path, checks) -def check_toml(template_file, config_file) -> Tuple[bool, Dict]: +def check_toml(template_file, config_file) -> Union[bool, Dict]: global config config = None try: diff --git a/utils/subreddit.py b/utils/subreddit.py index c386868..3253099 100644 --- a/utils/subreddit.py +++ b/utils/subreddit.py @@ -9,6 +9,7 @@ def get_subreddit_undone(submissions: list, subreddit, times_checked=0): """_summary_ Args: + times_checked: (int): For internal use, number of times function was called submissions (list): List of posts that are going to potentially be generated into a video subreddit (praw.Reddit.SubredditHelper): Chosen subreddit @@ -34,22 +35,23 @@ def get_subreddit_undone(submissions: list, subreddit, times_checked=0): if submission.stickied: print_substep("This post was pinned by moderators. Skipping...") continue - if submission.num_comments <= int(settings.config["reddit"]["thread"]["min_comments"]): + if submission.num_comments < int(settings.config["reddit"]["thread"]["min_comments"]): print_substep( - f'This post has under the specified minimum of comments ({settings.config["reddit"]["thread"]["min_comments"]}). Skipping...' 
+ "This post has under the specified minimum of comments" + f'({settings.config["reddit"]["thread"]["min_comments"]}). Skipping...' ) continue return submission print("all submissions have been done going by top submission order") VALID_TIME_FILTERS = [ - "day", "hour", + "day", "month", "week", "year", "all", ] # set doesn't have __getitem__ - index = times_checked + 1 + index = times_checked + 1 if times_checked != 0 else times_checked if index == len(VALID_TIME_FILTERS): print("all time filters have been checked you absolute madlad ") diff --git a/utils/video.py b/utils/video.py index 63dc170..556693e 100644 --- a/utils/video.py +++ b/utils/video.py @@ -48,7 +48,7 @@ class Video: img_clip = img_clip.set_opacity(opacity).set_duration(duration) img_clip = img_clip.set_position( position, relative=True - ) # todo get dara from utils/CONSTANTS.py and adapt position accordingly + ) # todo get data from utils/CONSTANTS.py and adapt position accordingly # Overlay the img clip on the first video clip self.video = CompositeVideoClip([self.video, img_clip]) diff --git a/video_creation/data/cookie-dark-mode.json b/video_creation/data/cookie-dark-mode.json deleted file mode 100644 index 774f4cc..0000000 --- a/video_creation/data/cookie-dark-mode.json +++ /dev/null @@ -1,14 +0,0 @@ -[ - { - "name": "USER", - "value": "eyJwcmVmcyI6eyJ0b3BDb250ZW50RGlzbWlzc2FsVGltZSI6MCwiZ2xvYmFsVGhlbWUiOiJSRURESVQiLCJuaWdodG1vZGUiOnRydWUsImNvbGxhcHNlZFRyYXlTZWN0aW9ucyI6eyJmYXZvcml0ZXMiOmZhbHNlLCJtdWx0aXMiOmZhbHNlLCJtb2RlcmF0aW5nIjpmYWxzZSwic3Vic2NyaXB0aW9ucyI6ZmFsc2UsInByb2ZpbGVzIjpmYWxzZX0sInRvcENvbnRlbnRUaW1lc0Rpc21pc3NlZCI6MH19", - "domain": ".reddit.com", - "path": "/" - }, - { - "name": "eu_cookie", - "value": "{%22opted%22:true%2C%22nonessential%22:false}", - "domain": ".reddit.com", - "path": "/" - } -] diff --git a/video_creation/data/cookie-light-mode.json b/video_creation/data/cookie-light-mode.json deleted file mode 100644 index 048a3e3..0000000 --- 
a/video_creation/data/cookie-light-mode.json +++ /dev/null @@ -1,8 +0,0 @@ -[ - { - "name": "eu_cookie", - "value": "{%22opted%22:true%2C%22nonessential%22:false}", - "domain": ".reddit.com", - "path": "/" - } -] diff --git a/video_creation/final_video.py b/video_creation/final_video.py index fd12642..bbcea64 100755 --- a/video_creation/final_video.py +++ b/video_creation/final_video.py @@ -3,165 +3,286 @@ import multiprocessing import os import re from os.path import exists -from typing import Tuple, Any -from moviepy.audio.AudioClip import concatenate_audioclips, CompositeAudioClip -from moviepy.audio.io.AudioFileClip import AudioFileClip -from moviepy.video.VideoClip import ImageClip -from moviepy.video.compositing.CompositeVideoClip import CompositeVideoClip -from moviepy.video.compositing.concatenate import concatenate_videoclips -from moviepy.video.io.VideoFileClip import VideoFileClip -from moviepy.video.io.ffmpeg_tools import ffmpeg_extract_subclip +from typing import Tuple, Any, Union + +from moviepy.editor import ( + VideoFileClip, + AudioFileClip, + ImageClip, + CompositeAudioClip, + CompositeVideoClip, +) from rich.console import Console +from rich.progress import track +from attr import attrs from utils.cleanup import cleanup from utils.console import print_step, print_substep from utils.video import Video from utils.videos import save_data from utils import settings +from video_creation.background import download_background, chop_background_video + + +@attrs +class FinalVideo: + video_duration: int = 0 + console = Console() + + def __attrs_post_init__(self): + self.W: int = int(settings.config["settings"]["video_width"]) + self.H: int = int(settings.config["settings"]["video_height"]) + + if not self.W or not self.H: + self.W, self.H = 1080, 1920 + + self.vertical_video: bool = self.W < self.H + + self.max_length: int = int(settings.config["settings"]["video_length"]) + self.time_before_first_picture: float = 
settings.config["settings"]["time_before_first_picture"] + self.time_before_tts: float = settings.config["settings"]["time_before_tts"] + self.time_between_pictures: float = settings.config["settings"]["time_between_pictures"] + self.delay_before_end: float = settings.config["settings"]["delay_before_end"] + + self.opacity = settings.config["settings"]["opacity"] + self.opacity = 1 if self.opacity is None or self.opacity >= 1 else self.opacity + self.transition = settings.config["settings"]["transition"] + self.transition = 0 if self.transition is None or self.transition > 2 else self.transition + + @staticmethod + def name_normalize( + name: str + ) -> str: + name = re.sub(r'[?\\"%*:|<>]', "", name) + name = re.sub(r"( [w,W]\s?/\s?[oO0])", r" without", name) + name = re.sub(r"( [w,W]\s?/)", r" with", name) + name = re.sub(r"(\d+)\s?/\s?(\d+)", r"\1 of \2", name) + name = re.sub(r"(\w+)\s?/\s?(\w+)", r"\1 or \2", name) + name = re.sub(r"/", "", name) + + lang = settings.config["reddit"]["thread"]["post_lang"] + translated_name = None + if lang: + import translators as ts + + print_substep("Translating filename...") + translated_name = ts.google(name, to_language=lang) + return translated_name[:30] if translated_name else name[:30] + + @staticmethod + def create_audio_clip( + clip_title: Union[str, int], + clip_start: float, + ) -> AudioFileClip: + return ( + AudioFileClip(f"assets/temp/mp3/{clip_title}.mp3") + .set_start(clip_start) + ) + + def create_image_clip( + self, + image_title: Union[str, int], + audio_start: float, + audio_duration: float, + clip_position: str, + ) -> ImageClip: + return ( + ImageClip(f"assets/temp/png/{image_title}.png") + .set_start(audio_start - self.time_before_tts) + .set_duration(self.time_before_tts * 2 + audio_duration) + .set_opacity(self.opacity) + .set_position(clip_position) + .resize( + width=self.W - self.W / 20 if self.vertical_video else None, + height=self.H - self.H / 5 if not self.vertical_video else None, + ) + 
.crossfadein(self.transition) + .crossfadeout(self.transition) + ) + + def make( + self, + indexes_of_clips: list, + reddit_obj: dict, + background_config: Tuple[str, str, str, Any], + ) -> None: + """Gathers audio clips, gathers all screenshots, stitches them together and saves the final video to assets/temp + Args: + indexes_of_clips (list): Indexes of voiced comments + reddit_obj (dict): The reddit object that contains the posts to read. + background_config (Tuple[str, str, str, Any]): The background config to use. + """ + # try: # if it isn't found (i.e you just updated and copied over config.toml) it will throw an error + # VOLUME_MULTIPLIER = settings.config["settings"]['background']["background_audio_volume"] + # except (TypeError, KeyError): + # print('No background audio volume found in config.toml. Using default value of 1.') + # VOLUME_MULTIPLIER = 1 + print_step("Creating the final video 🎥") + VideoFileClip.reW = lambda clip: clip.resize(width=self.W) + VideoFileClip.reH = lambda clip: clip.resize(width=self.H) + + # Gather all audio clips + audio_clips = list() + correct_audio_offset = self.time_before_tts * 2 + self.time_between_pictures + + audio_title = self.create_audio_clip( + "title", + self.time_before_first_picture + self.time_before_tts, + ) + self.video_duration += audio_title.duration + self.time_before_first_picture + self.time_before_tts + audio_clips.append(audio_title) + indexes_for_videos = list() + + for audio_title in track( + indexes_of_clips, + description="Gathering audio clips...", + total=indexes_of_clips.__len__() + ): + temp_audio_clip = self.create_audio_clip( + audio_title, + correct_audio_offset + self.video_duration, + ) + if self.video_duration + temp_audio_clip.duration + \ + correct_audio_offset + self.delay_before_end <= self.max_length: + self.video_duration += temp_audio_clip.duration + correct_audio_offset + audio_clips.append(temp_audio_clip) + indexes_for_videos.append(audio_title) + + self.video_duration += 
self.delay_before_end + self.time_before_tts + + # Can't use concatenate_audioclips here, it resets clips' start point + audio_composite = CompositeAudioClip(audio_clips) + + self.console.log("[bold green] Video Will Be: %.2f Seconds Long" % self.video_duration) + + # Gather all images + image_clips = list() + + # Accounting for title and other stuff if audio_clips + index_offset = 1 -console = Console() -W, H = 1080, 1920 - - -def name_normalize(name: str) -> str: - name = re.sub(r'[?\\"%*:|<>]', "", name) - name = re.sub(r"( [w,W]\s?\/\s?[o,O,0])", r" without", name) - name = re.sub(r"( [w,W]\s?\/)", r" with", name) - name = re.sub(r"(\d+)\s?\/\s?(\d+)", r"\1 of \2", name) - name = re.sub(r"(\w+)\s?\/\s?(\w+)", r"\1 or \2", name) - name = re.sub(r"\/", r"", name) - name[:30] - - lang = settings.config["reddit"]["thread"]["post_lang"] - if lang: - import translators as ts - - print_substep("Translating filename...") - translated_name = ts.google(name, to_language=lang) - return translated_name - - else: - return name - - -def make_final_video( - number_of_clips: int, - length: int, - reddit_obj: dict, - background_config: Tuple[str, str, str, Any], -): - """Gathers audio clips, gathers all screenshots, stitches them together and saves the final video to assets/temp - Args: - number_of_clips (int): Index to end at when going through the screenshots' - length (int): Length of the video - reddit_obj (dict): The reddit object that contains the posts to read. - background_config (Tuple[str, str, str, Any]): The background config to use. - """ - # try: # if it isn't found (i.e you just updated and copied over config.toml) it will throw an error - # VOLUME_MULTIPLIER = settings.config["settings"]['background']["background_audio_volume"] - # except (TypeError, KeyError): - # print('No background audio volume found in config.toml. 
Using default value of 1.') - # VOLUME_MULTIPLIER = 1 - print_step("Creating the final video 🎥") - VideoFileClip.reW = lambda clip: clip.resize(width=W) - VideoFileClip.reH = lambda clip: clip.resize(width=H) - opacity = settings.config["settings"]["opacity"] - transition = settings.config["settings"]["transition"] - background_clip = ( - VideoFileClip("assets/temp/background.mp4") - .without_audio() - .resize(height=H) - .crop(x1=1166.6, y1=0, x2=2246.6, y2=1920) - ) - - # Gather all audio clips - audio_clips = [AudioFileClip(f"assets/temp/mp3/{i}.mp3") for i in range(number_of_clips)] - audio_clips.insert(0, AudioFileClip("assets/temp/mp3/title.mp3")) - audio_concat = concatenate_audioclips(audio_clips) - audio_composite = CompositeAudioClip([audio_concat]) - - console.log(f"[bold green] Video Will Be: {length} Seconds Long") - # add title to video - image_clips = [] - # Gather all images - new_opacity = 1 if opacity is None or float(opacity) >= 1 else float(opacity) - new_transition = 0 if transition is None or float(transition) > 2 else float(transition) - image_clips.insert( - 0, - ImageClip("assets/temp/png/title.png") - .set_duration(audio_clips[0].duration) - .resize(width=W - 100) - .set_opacity(new_opacity) - .crossfadein(new_transition) - .crossfadeout(new_transition), - ) - - for i in range(0, number_of_clips): image_clips.append( - ImageClip(f"assets/temp/png/comment_{i}.png") - .set_duration(audio_clips[i + 1].duration) - .resize(width=W - 100) - .set_opacity(new_opacity) - .crossfadein(new_transition) - .crossfadeout(new_transition) + self.create_image_clip( + "title", + audio_clips[0].start, + audio_clips[0].duration, + background_config[3], + ) ) - # if os.path.exists("assets/mp3/posttext.mp3"): - # image_clips.insert( - # 0, - # ImageClip("assets/png/title.png") - # .set_duration(audio_clips[0].duration + audio_clips[1].duration) - # .set_position("center") - # .resize(width=W - 100) - # .set_opacity(float(opacity)), - # ) - # else: story mode 
stuff - img_clip_pos = background_config[3] - image_concat = concatenate_videoclips(image_clips).set_position( - img_clip_pos - ) # note transition kwarg for delay in imgs - image_concat.audio = audio_composite - final = CompositeVideoClip([background_clip, image_concat]) - title = re.sub(r"[^\w\s-]", "", reddit_obj["thread_title"]) - idx = re.sub(r"[^\w\s-]", "", reddit_obj["thread_id"]) - - filename = f"{name_normalize(title)}.mp4" - subreddit = settings.config["reddit"]["thread"]["subreddit"] - - if not exists(f"./results/{subreddit}"): - print_substep("The results folder didn't exist so I made it") - os.makedirs(f"./results/{subreddit}") - - # if settings.config["settings"]['background']["background_audio"] and exists(f"assets/backgrounds/background.mp3"): - # audioclip = mpe.AudioFileClip(f"assets/backgrounds/background.mp3").set_duration(final.duration) - # audioclip = audioclip.fx( volumex, 0.2) - # final_audio = mpe.CompositeAudioClip([final.audio, audioclip]) - # # lowered_audio = audio_background.multiply_volume( # todo get this to work - # # VOLUME_MULTIPLIER) # lower volume by background_audio_volume, use with fx - # final.set_audio(final_audio) - final = Video(final).add_watermark( - text=f"Background credit: {background_config[2]}", opacity=0.4 - ) - final.write_videofile( - "assets/temp/temp.mp4", - fps=30, - audio_codec="aac", - audio_bitrate="192k", - verbose=False, - threads=multiprocessing.cpu_count(), - ) - ffmpeg_extract_subclip( - "assets/temp/temp.mp4", - 0, - length, - targetname=f"results/{subreddit}/{filename}", - ) - save_data(subreddit, filename, title, idx, background_config[2]) - print_step("Removing temporary files 🗑") - cleanups = cleanup() - print_substep(f"Removed {cleanups} temporary files 🗑") - print_substep("See result in the results folder!") - - print_step( - f'Reddit title: {reddit_obj["thread_title"]} \n Background Credit: {background_config[2]}' - ) + for idx, photo_idx in track( + enumerate( + indexes_for_videos, + 
start=index_offset, + ), + description="Gathering audio clips...", + total=indexes_for_videos.__len__() + ): + image_clips.append( + self.create_image_clip( + f"comment_{photo_idx}", + audio_clips[idx].start, + audio_clips[idx].duration, + background_config[3], + ) + ) + + # if os.path.exists("assets/mp3/posttext.mp3"): + # image_clips.insert( + # 0, + # ImageClip("assets/png/title.png") + # .set_duration(audio_clips[0].duration + audio_clips[1].duration) + # .set_position("center") + # .resize(width=W - 100) + # .set_opacity(float(opacity)), + # ) + # else: story mode stuff + + # Can't use concatenate_videoclips here, it resets clips' start point + + download_background(background_config) + chop_background_video(background_config, self.video_duration) + background_clip = ( + VideoFileClip("assets/temp/background.mp4") + .set_start(0) + .set_end(self.video_duration) + .without_audio() + .resize(height=self.H) + ) + + back_video_width, back_video_height = background_clip.size + + # Fix for crop with vertical videos + if back_video_width < self.H: + background_clip = ( + background_clip + .resize(width=self.W) + ) + back_video_width, back_video_height = background_clip.size + background_clip = background_clip.crop( + x1=0, + x2=back_video_width, + y1=back_video_height / 2 - self.H / 2, + y2=back_video_height / 2 + self.H / 2 + ) + else: + background_clip = background_clip.crop( + x1=back_video_width / 2 - self.W / 2, + x2=back_video_width / 2 + self.W / 2, + y1=0, + y2=back_video_height + ) + + final = CompositeVideoClip([background_clip, *image_clips]) + final.audio = audio_composite + + title = re.sub(r"[^\w\s-]", "", reddit_obj["thread_title"]) + idx = re.sub(r"[^\w\s-]", "", reddit_obj["thread_id"]) + + filename = f"{self.name_normalize(title)}.mp4" + subreddit = str(settings.config["reddit"]["thread"]["subreddit"]) + + if not exists(f"./results/{subreddit}"): + print_substep("The results folder didn't exist so I made it") + os.makedirs(f"./results/{subreddit}") 
+ + # if ( + # settings.config["settings"]['background']["background_audio"] and + # exists(f"assets/backgrounds/background.mp3") + # ): + # audioclip = ( + # AudioFileClip(f"assets/backgrounds/background.mp3") + # .set_duration(final.duration) + # .volumex(0.2) + # ) + # final_audio = CompositeAudioClip([final.audio, audioclip]) + # # lowered_audio = audio_background.multiply_volume( # TODO get this to work + # # VOLUME_MULTIPLIER) # lower volume by background_audio_volume, use with fx + # final.set_audio(final_audio) + + final = Video(final).add_watermark( + text=f"Background credit: {background_config[2]}", opacity=0.4 + ) + + final.write_videofile( + "assets/temp/temp.mp4", + fps=30, + audio_codec="aac", + audio_bitrate="192k", + verbose=False, + threads=multiprocessing.cpu_count(), + ) + # Moves file in subreddit folder and renames it + os.rename( + "assets/temp/temp.mp4", + f"results/{subreddit}/{filename}", + ) + save_data(subreddit, filename, title, idx, background_config[2]) + print_step("Removing temporary files 🗑") + cleanups = cleanup() + print_substep(f"Removed {cleanups} temporary files 🗑") + print_substep("See result in the results folder!") + + print_step( + f'Reddit title: {reddit_obj["thread_title"]} \n Background Credit: {background_config[2]}' + ) diff --git a/video_creation/screenshot_downloader.py b/video_creation/screenshot_downloader.py deleted file mode 100644 index 7898e8d..0000000 --- a/video_creation/screenshot_downloader.py +++ /dev/null @@ -1,114 +0,0 @@ -import json - -from pathlib import Path -from typing import Dict -from utils import settings -from playwright.async_api import async_playwright # pylint: disable=unused-import - -# do not remove the above line - -from playwright.sync_api import sync_playwright, ViewportSize -from rich.progress import track -import translators as ts - -from utils.console import print_step, print_substep - -storymode = False - - -def download_screenshots_of_reddit_posts(reddit_object: dict, 
screenshot_num: int): - """Downloads screenshots of reddit posts as seen on the web. Downloads to assets/temp/png - - Args: - reddit_object (Dict): Reddit object received from reddit/subreddit.py - screenshot_num (int): Number of screenshots to download - """ - print_step("Downloading screenshots of reddit posts...") - - # ! Make sure the reddit screenshots folder exists - Path("assets/temp/png").mkdir(parents=True, exist_ok=True) - - with sync_playwright() as p: - print_substep("Launching Headless Browser...") - - browser = p.chromium.launch() - context = browser.new_context() - - if settings.config["settings"]["theme"] == "dark": - cookie_file = open("./video_creation/data/cookie-dark-mode.json", encoding="utf-8") - else: - cookie_file = open("./video_creation/data/cookie-light-mode.json", encoding="utf-8") - cookies = json.load(cookie_file) - context.add_cookies(cookies) # load preference cookies - # Get the thread screenshot - page = context.new_page() - page.goto(reddit_object["thread_url"], timeout=0) - page.set_viewport_size(ViewportSize(width=1920, height=1080)) - if page.locator('[data-testid="content-gate"]').is_visible(): - # This means the post is NSFW and requires to click the proceed button. - - print_substep("Post is NSFW. 
You are spicy...") - page.locator('[data-testid="content-gate"] button').click() - page.wait_for_load_state() # Wait for page to fully load - - if page.locator('[data-click-id="text"] button').is_visible(): - page.locator( - '[data-click-id="text"] button' - ).click() # Remove "Click to see nsfw" Button in Screenshot - - # translate code - - if settings.config["reddit"]["thread"]["post_lang"]: - print_substep("Translating post...") - texts_in_tl = ts.google( - reddit_object["thread_title"], - to_language=settings.config["reddit"]["thread"]["post_lang"], - ) - - page.evaluate( - "tl_content => document.querySelector('[data-test-id=\"post-content\"] > div:nth-child(3) > div > div').textContent = tl_content", - texts_in_tl, - ) - else: - print_substep("Skipping translation...") - - page.locator('[data-test-id="post-content"]').screenshot(path="assets/temp/png/title.png") - - if storymode: - page.locator('[data-click-id="text"]').screenshot( - path="assets/temp/png/story_content.png" - ) - else: - for idx, comment in enumerate( - track(reddit_object["comments"], "Downloading screenshots...") - ): - # Stop if we have reached the screenshot_num - if idx >= screenshot_num: - break - - if page.locator('[data-testid="content-gate"]').is_visible(): - page.locator('[data-testid="content-gate"] button').click() - - page.goto(f'https://reddit.com{comment["comment_url"]}', timeout=0) - - # translate code - - if settings.config["reddit"]["thread"]["post_lang"]: - comment_tl = ts.google( - comment["comment_body"], - to_language=settings.config["reddit"]["thread"]["post_lang"], - ) - page.evaluate( - '([tl_content, tl_id]) => document.querySelector(`#t1_${tl_id} > div:nth-child(2) > div > div[data-testid="comment"] > div`).textContent = tl_content', - [comment_tl, comment["comment_id"]], - ) - try: - page.locator(f"#t1_{comment['comment_id']}").screenshot( - path=f"assets/temp/png/comment_{idx}.png" - ) - except TimeoutError: - del reddit_object["comments"] - screenshot_num += 1 - 
print("TimeoutError: Skipping screenshot...") - continue - print_substep("Screenshots downloaded Successfully.", style="bold green") diff --git a/video_creation/voices.py b/video_creation/voices.py index ac33dd7..1e5a1a5 100644 --- a/video_creation/voices.py +++ b/video_creation/voices.py @@ -1,9 +1,3 @@ -#!/usr/bin/env python - -from typing import Dict, Tuple - -from rich.console import Console - from TTS.engine_wrapper import TTSEngine from TTS.GTTS import GTTS from TTS.streamlabs_polly import StreamlabsPolly @@ -13,8 +7,6 @@ from utils import settings from utils.console import print_table, print_step -console = Console() - TTSProviders = { "GoogleTranslate": GTTS, "AWSPolly": AWSPolly, @@ -23,29 +15,29 @@ TTSProviders = { } -def save_text_to_mp3(reddit_obj) -> Tuple[int, int]: +def save_text_to_mp3( + reddit_obj: dict, +) -> list: """Saves text to MP3 files. Args: reddit_obj (): Reddit object received from reddit API in reddit/subreddit.py Returns: - tuple[int,int]: (total length of the audio, the number of comments audio was generated for) + The number of comments audio was generated for """ voice = settings.config["settings"]["tts"]["choice"] - if str(voice).casefold() in map(lambda _: _.casefold(), TTSProviders): - text_to_mp3 = TTSEngine(get_case_insensitive_key_value(TTSProviders, voice), reddit_obj) - else: + if voice.casefold() not in map(lambda _: _.casefold(), TTSProviders): while True: print_step("Please choose one of the following TTS providers: ") print_table(TTSProviders) - choice = input("\n") - if choice.casefold() in map(lambda _: _.casefold(), TTSProviders): + voice = input("\n") + if voice.casefold() in map(lambda _: _.casefold(), TTSProviders): break print("Unknown Choice") - text_to_mp3 = TTSEngine(get_case_insensitive_key_value(TTSProviders, choice), reddit_obj) - return text_to_mp3.run() + engine_instance = TTSEngine(get_case_insensitive_key_value(TTSProviders, voice), reddit_obj) + return engine_instance.run() def 
get_case_insensitive_key_value(input_dict, key): diff --git a/webdriver/__init__.py b/webdriver/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/webdriver/common.py b/webdriver/common.py new file mode 100644 index 0000000..0061b3b --- /dev/null +++ b/webdriver/common.py @@ -0,0 +1,84 @@ +from attr import attrs, attrib +from typing import TypeVar, Optional, Callable, Union + + +_function = TypeVar("_function", bound=Callable[..., object]) +_exceptions = TypeVar("_exceptions", bound=Optional[Union[type, tuple, list]]) + +default_exception = None + + +@attrs +class ExceptionDecorator: + """ + Decorator for catching exceptions and writing logs + """ + exception: Optional[_exceptions] = attrib(default=None) + + def __attrs_post_init__(self): + if not self.exception: + self.exception = default_exception + + def __call__( + self, + func: _function, + ): + async def wrapper(*args, **kwargs): + try: + obj_to_return = await func(*args, **kwargs) + return obj_to_return + except Exception as caughtException: + import logging + + logger = logging.getLogger("webdriver_log") + logger.setLevel(logging.ERROR) + handler = logging.FileHandler(".webdriver.log", mode="a+", encoding="utf-8") + logger.addHandler(handler) + + if isinstance(self.exception, type): + if not type(caughtException) == self.exception: + logger.error(f"unexpected error - {caughtException}") + else: + if not type(caughtException) in self.exception: + logger.error(f"unexpected error - {caughtException}") + + return wrapper + + +def catch_exception( + func: Optional[_function], + exception: Optional[_exceptions] = None, +) -> Union[object, _function]: + """ + Decorator for catching exceptions and writing logs + + Args: + func: Function to be decorated + exception: Expected exception(s) + Returns: + Decorated function + """ + exceptor = ExceptionDecorator(exception) + if func: + exceptor = exceptor(func) + return exceptor + + +# Lots of tabs - lots of memory +# chunk needed to minimize memory 
required +def chunks( + array: list, + size: int, +): + """ + Yield successive n-sized chunks from list. + + Args: + array: List to be chunked + size: size of a chunk + + Returns: + Generator with chunked list + """ + for i in range(0, len(array), size): + yield array[i:i + size] diff --git a/webdriver/playwright.py b/webdriver/playwright.py new file mode 100644 index 0000000..30ea32c --- /dev/null +++ b/webdriver/playwright.py @@ -0,0 +1,332 @@ +from asyncio import as_completed +from pathlib import Path +from typing import Dict, Optional + +import translators as ts +from attr import attrs, attrib +from attr.validators import instance_of +from playwright.async_api import Browser, Playwright, Page, BrowserContext, ElementHandle +from playwright.async_api import async_playwright, TimeoutError +from rich.progress import track + +from utils import settings +from utils.console import print_step, print_substep + +import webdriver.common as common + +common.default_exception = TimeoutError + + +@attrs +class Browser: + """ + Args: + default_Viewport (dict):Pyppeteer Browser default_Viewport options + browser (BrowserCls): Pyppeteer Browser instance + """ + default_Viewport: dict = attrib( + validator=instance_of(dict), + default={ + # 9x21 to see long posts + "width": 500, + "height": 1200, + }, + kw_only=True, + ) + playwright: Playwright + browser: Browser + context: BrowserContext + + async def get_browser( + self, + ) -> None: + """ + Creates Playwright instance & browser + """ + self.playwright = await async_playwright().start() + self.browser = await self.playwright.chromium.launch() + self.context = await self.browser.new_context(viewport=self.default_Viewport) + + async def close_browser( + self, + ) -> None: + """ + Closes Playwright stuff + """ + await self.context.close() + await self.browser.close() + await self.playwright.stop() + + +class Flaky: + """ + All methods decorated with function catching default exceptions and writing logs + """ + + @staticmethod + 
@common.catch_exception + async def find_element( + selector: str, + page_instance: Page, + options: Optional[dict] = None, + ) -> ElementHandle: + return ( + await page_instance.wait_for_selector(selector, **options) + if options + else await page_instance.wait_for_selector(selector) + ) + + @common.catch_exception + async def click( + self, + page_instance: Optional[Page] = None, + query: Optional[str] = None, + options: Optional[dict] = None, + *, + find_options: Optional[dict] = None, + element: Optional[ElementHandle] = None, + ) -> None: + if element: + await element.click(**options) if options else await element.click() + else: + results = ( + await self.find_element(query, page_instance, **find_options) + if find_options + else await self.find_element(query, page_instance) + ) + await results.click(**options) if options else await results.click() + + @common.catch_exception + async def screenshot( + self, + page_instance: Optional[Page] = None, + query: Optional[str] = None, + options: Optional[dict] = None, + *, + find_options: Optional[dict] = None, + element: Optional[ElementHandle] = None, + ) -> None: + if element: + await element.screenshot(**options) if options else await element.screenshot() + else: + results = ( + await self.find_element(query, page_instance, **find_options) + if find_options + else await self.find_element(query, page_instance) + ) + await results.screenshot(**options) if options else await results.screenshot() + + +@attrs(auto_attribs=True) +class RedditScreenshot(Flaky, Browser): + """ + Args: + reddit_object (Dict): Reddit object received from reddit/subreddit.py + screenshot_idx (int): List with indexes of voiced comments + story_mode (bool): If submission is a story takes screenshot of the story + """ + reddit_object: dict + screenshot_idx: list + story_mode: Optional[bool] = attrib( + validator=instance_of(bool), + default=False, + kw_only=True + ) + + def __attrs_post_init__( + self + ): + self.post_lang: Optional[bool] = 
settings.config["reddit"]["thread"]["post_lang"] + + async def __dark_theme( # TODO isn't working + self, + page_instance: Page, + ) -> None: + """ + Enables dark theme in Reddit + + Args: + page_instance: Pyppeteer page instance with reddit page opened + """ + + await self.click( + page_instance, + ".header-user-dropdown", + ) + + # It's normal not to find it, sometimes there is none :shrug: + await self.click( + page_instance, + "button >> span:has-text('Settings')", + ) + + await self.click( + page_instance, + "button >> span:has-text('Dark Mode')", + ) + + # Closes settings + await self.click( + page_instance, + ".header-user-dropdown" + ) + + async def __close_nsfw( + self, + page_instance: Page, + ) -> None: + """ + Closes NSFW stuff + + Args: + page_instance: Instance of main page + """ + + print_substep("Post is NSFW. You are spicy...") + + # Triggers indirectly reload + await self.click( + page_instance, + "button:has-text('Yes')", + {"timeout": 5000}, + ) + + # Await indirect reload + await page_instance.wait_for_load_state() + + await self.click( + page_instance, + "button:has-text('Click to see nsfw')", + {"timeout": 5000}, + ) + + async def __collect_comment( + self, + comment_obj: dict, + filename_idx: int, + ) -> None: + """ + Makes a screenshot of the comment + + Args: + comment_obj: prew comment object + filename_idx: index for the filename + """ + comment_page = await self.context.new_page() + await comment_page.goto(f'https://reddit.com{comment_obj["comment_url"]}') + + # Translates submission' comment + if self.post_lang: + comment_tl = ts.google( + comment_obj["comment_body"], + to_language=self.post_lang, + ) + await comment_page.evaluate( + '([comment_id, comment_tl]) => document.querySelector(`#t1_${comment_id} > div:nth-child(2) > div > div[data-testid="comment"] > div`).textContent = comment_tl', # noqa + [comment_obj["comment_id"], comment_tl], + ) + + await self.screenshot( + comment_page, + f"id=t1_{comment_obj['comment_id']}", + 
{"path": f"assets/temp/png/comment_{filename_idx}.png"}, + ) + + # WIP TODO test it + async def __collect_story( + self, + main_page: Page, + ): + # Translates submission text + if self.post_lang: + story_tl = ts.google( + self.reddit_object["thread_post"], + to_language=self.post_lang, + ) + split_story_tl = story_tl.split('\n') + + await main_page.evaluate( + "(split_story_tl) => split_story_tl.map(function(element, i) { return [element, document.querySelectorAll('[data-test-id=\"post-content\"] > [data-click-id=\"text\"] > div > p')[i]]; }).forEach(mappedElement => mappedElement[1].textContent = mappedElement[0])", # noqa + split_story_tl, + ) + + await self.screenshot( + main_page, + "//div[@data-test-id='post-content']//div[@data-click-id='text']", + {"path": "assets/temp/png/story_content.png"}, + ) + + async def download( + self, + ): + """ + Downloads screenshots of reddit posts as seen on the web. Downloads to assets/temp/png + """ + print_step("Downloading screenshots of reddit posts...") + + print_substep("Launching Headless Browser...") + await self.get_browser() + + # ! Make sure the reddit screenshots folder exists + Path("assets/temp/png").mkdir(parents=True, exist_ok=True) + + # Get the thread screenshot + reddit_main = await self.context.new_page() + await reddit_main.goto(self.reddit_object["thread_url"]) # noqa + + if settings.config["settings"]["theme"] == "dark": + await self.__dark_theme(reddit_main) + + if self.reddit_object["is_nsfw"]: + # This means the post is NSFW and requires to click the proceed button. 
+            await self.__close_nsfw(reddit_main)
+
+        # Translates submission title
+        if self.post_lang:
+            print_substep("Translating post...")
+            texts_in_tl = ts.google(
+                self.reddit_object["thread_title"],
+                to_language=self.post_lang,
+            )
+
+            await reddit_main.evaluate(
+                f"(texts_in_tl) => document.querySelector('[data-test-id=\"post-content\"] > div:nth-child(3) > div > div').textContent = texts_in_tl",  # noqa
+                texts_in_tl,
+            )
+        else:
+            print_substep("Skipping translation...")
+
+        # No sense to move it to common.py
+        async_tasks_primary = (  # noqa
+            [
+                self.__collect_comment(self.reddit_object["comments"][idx], idx) for idx in
+                self.screenshot_idx
+            ]
+            if not self.story_mode
+            else [
+                self.__collect_story(reddit_main)
+            ]
+        )
+
+        async_tasks_primary.append(
+            self.screenshot(
+                reddit_main,
+                f"id=t3_{self.reddit_object['thread_id']}",
+                {"path": "assets/temp/png/title.png"},
+            )
+        )
+
+        for idx, chunked_tasks in enumerate(
+                [chunk for chunk in common.chunks(async_tasks_primary, 10)],
+                start=1,
+        ):
+            chunk_list = async_tasks_primary.__len__() // 10 + (1 if async_tasks_primary.__len__() % 10 != 0 else 0)
+            for task in track(
+                    as_completed(chunked_tasks),
+                    description=f"Downloading comments: Chunk {idx}/{chunk_list}",
+                    total=chunked_tasks.__len__(),
+            ):
+                await task
+
+        print_substep("Comments downloaded Successfully.", style="bold green")
+        await self.close_browser()
diff --git a/webdriver/pyppeteer.py b/webdriver/pyppeteer.py
new file mode 100644
index 0000000..d38c3a3
--- /dev/null
+++ b/webdriver/pyppeteer.py
@@ -0,0 +1,391 @@
+from asyncio import as_completed
+
+from pyppeteer import launch
+from pyppeteer.page import Page as PageCls
+from pyppeteer.browser import Browser as BrowserCls
+from pyppeteer.element_handle import ElementHandle as ElementHandleCls
+from pyppeteer.errors import TimeoutError as BrowserTimeoutError
+
+from pathlib import Path
+from utils import settings
+from utils.console import print_step, print_substep
+from rich.progress import track
+import translators as ts
+
+from attr import attrs, attrib
+from attr.validators import instance_of
+from typing import Optional
+
+import webdriver.common as common
+
+common.default_exception = BrowserTimeoutError
+
+
+@attrs
+class Browser:
+    """
+    Args:
+        default_Viewport (dict): Options passed to Pyppeteer launch (only the viewport size by default)
+        browser (BrowserCls): Pyppeteer Browser instance
+    """
+    default_Viewport: dict = attrib(
+        validator=instance_of(dict),
+        # factory gives every instance its own dict instead of one shared mutable default
+        factory=lambda: {
+            # 9x21 to see long posts
+            "defaultViewport": {
+                "width": 500,
+                "height": 1200,
+            },
+        },
+        kw_only=True,
+    )
+    browser: BrowserCls
+
+    async def get_browser(
+            self,
+    ) -> None:
+        """
+        Creates Pyppeteer browser
+        """
+        self.browser = await launch(self.default_Viewport)
+
+    async def close_browser(
+            self,
+    ) -> None:
+        """
+        Closes Pyppeteer browser
+        """
+        await self.browser.close()
+
+
+class Wait:
+    """
+    Waits for an element located by xpath, then clicks or screenshots it
+    """
+    @staticmethod
+    @common.catch_exception
+    async def find_xpath(
+            page_instance: PageCls,
+            xpath: Optional[str] = None,
+            options: Optional[dict] = None,
+    ) -> 'ElementHandleCls':
+        """
+        Explicitly finds element on the page
+
+        Args:
+            page_instance: Pyppeteer page instance
+            xpath: xpath query
+            options: Pyppeteer waitForXPath parameters
+
+        Available options are:
+
+        * ``visible`` (bool): wait for element to be present in DOM and to be
+          visible, i.e. to not have ``display: none`` or ``visibility: hidden``
+          CSS properties. Defaults to ``False``.
+        * ``hidden`` (bool): wait for element to not be found in the DOM or to
+          be hidden, i.e. have ``display: none`` or ``visibility: hidden`` CSS
+          properties. Defaults to ``False``.
+        * ``timeout`` (int|float): maximum time to wait for in milliseconds.
+          Defaults to 30000 (30 seconds). Pass ``0`` to disable timeout.
+        Returns:
+            Pyppeteer element instance
+        """
+        if options:
+            el = await page_instance.waitForXPath(xpath, options=options)
+        else:
+            el = await page_instance.waitForXPath(xpath)
+        return el
+
+    @common.catch_exception
+    async def click(
+            self,
+            page_instance: Optional[PageCls] = None,
+            xpath: Optional[str] = None,
+            options: Optional[dict] = None,
+            *,
+            find_options: Optional[dict] = None,
+            el: Optional[ElementHandleCls] = None,
+    ) -> None:
+        """
+        Clicks on the element; the element is looked up by xpath unless el is given
+
+        Args:
+            page_instance: Pyppeteer page instance
+            xpath: xpath query
+            find_options: Pyppeteer waitForXPath parameters
+            options: Pyppeteer click parameters
+            el: Pyppeteer element instance
+        """
+        if not el:
+            el = await self.find_xpath(page_instance, xpath, find_options)
+        if options:
+            await el.click(options)
+        else:
+            await el.click()
+
+    @common.catch_exception
+    async def screenshot(
+            self,
+            page_instance: Optional[PageCls] = None,
+            xpath: Optional[str] = None,
+            options: Optional[dict] = None,
+            *,
+            find_options: Optional[dict] = None,
+            el: Optional[ElementHandleCls] = None,
+    ) -> None:
+        """
+        Makes a screenshot of the element; the element is looked up by xpath unless el is given
+
+        Args:
+            page_instance: Pyppeteer page instance
+            xpath: xpath query
+            options: Pyppeteer screenshot parameters
+            find_options: Pyppeteer waitForXPath parameters
+            el: Pyppeteer element instance
+        """
+        if not el:
+            el = await self.find_xpath(page_instance, xpath, find_options)
+        if options:
+            await el.screenshot(options)
+        else:
+            await el.screenshot()
+
+
+@attrs(auto_attribs=True)
+class RedditScreenshot(Browser, Wait):
+    """
+    Args:
+        reddit_object (Dict): Reddit object received from reddit/subreddit.py
+        screenshot_idx (list): List with indexes of voiced comments
+        story_mode (bool): If submission is a story takes screenshot of the story
+    """
+    reddit_object: dict
+    screenshot_idx: list
+    story_mode: Optional[bool] = attrib(
+        validator=instance_of(bool),
+        default=False,
+        kw_only=True
+    )
+
+    def __attrs_post_init__(
+            self,
+    ):
+        """
+        Reads the translation target language from the settings
+        """
+        self.post_lang: Optional[str] = settings.config["reddit"]["thread"]["post_lang"]
+
+    async def __dark_theme(
+            self,
+            page_instance: PageCls,
+    ) -> None:
+        """
+        Enables dark theme in Reddit
+
+        Args:
+            page_instance: Pyppeteer page instance with reddit page opened
+        """
+
+        await self.click(
+            page_instance,
+            "//div[@class='header-user-dropdown']",
+            find_options={"timeout": 5000},
+        )
+
+        # It's normal not to find it, sometimes there is none :shrug:
+        await self.click(
+            page_instance,
+            "//span[text()='Settings']/ancestor::button[1]",
+            find_options={"timeout": 5000},
+        )
+
+        await self.click(
+            page_instance,
+            "//span[text()='Dark Mode']/ancestor::button[1]",
+            find_options={"timeout": 5000},
+        )
+
+        # Closes settings
+        await self.click(
+            page_instance,
+            "//div[@class='header-user-dropdown']",
+            find_options={"timeout": 5000},
+        )
+
+    async def __close_nsfw(
+            self,
+            page_instance: PageCls,
+    ) -> None:
+        """
+        Closes NSFW stuff
+
+        Args:
+            page_instance: Instance of main page
+        """
+
+        from asyncio import ensure_future
+
+        print_substep("Post is NSFW. You are spicy...")
+        # Start waiting for the navigation before the click that triggers it
+        navigation = ensure_future(page_instance.waitForNavigation())
+
+        # Triggers indirectly reload
+        await self.click(
+            page_instance,
+            '//button[text()="Yes"]',
+            find_options={"timeout": 5000},
+        )
+
+        # Await reload
+        await navigation
+
+        await self.click(
+            page_instance,
+            '//button[text()="Click to see nsfw"]',
+            find_options={"timeout": 5000},
+        )
+
+    async def __collect_comment(
+            self,
+            comment_obj: dict,
+            filename_idx: int,
+    ) -> None:
+        """
+        Makes a screenshot of the comment
+
+        Args:
+            comment_obj: comment dict with comment_url, comment_body and comment_id keys
+            filename_idx: index for the filename
+        """
+        comment_page = await self.browser.newPage()
+        await comment_page.goto(f'https://reddit.com{comment_obj["comment_url"]}')
+
+        # Translates the submission's comment
+        if self.post_lang:
+            comment_tl = ts.google(
+                comment_obj["comment_body"],
+                to_language=self.post_lang,
+            )
+            await comment_page.evaluate(
+                '([comment_id, comment_tl]) => document.querySelector(`#t1_${comment_id} > div:nth-child(2) > div > div[data-testid="comment"] > div`).textContent = comment_tl',  # noqa
+                [comment_obj["comment_id"], comment_tl],
+            )
+
+        await self.screenshot(
+            comment_page,
+            f"//div[@id='t1_{comment_obj['comment_id']}']",
+            {"path": f"assets/temp/png/comment_{filename_idx}.png"},
+        )
+
+        # The tab is not needed once its screenshot is taken
+        await comment_page.close()
+
+    # WIP  TODO test it
+    async def __collect_story(
+            self,
+            main_page: PageCls,
+    ):
+        """
+        Makes a screenshot of the submission text, translating it first if post_lang is set
+
+        Args:
+            main_page: Pyppeteer page instance with the submission opened
+        """
+        # Translates submission text
+        if self.post_lang:
+            story_tl = ts.google(
+                self.reddit_object["thread_post"],
+                to_language=self.post_lang,
+            )
+            split_story_tl = [line for line in story_tl.split('\n') if line.strip()]  # blank lines have no <p> of their own
+
+            await main_page.evaluate(
+                "(split_story_tl) => split_story_tl.map(function(element, i) { return [element, document.querySelectorAll('[data-test-id=\"post-content\"] > [data-click-id=\"text\"] > div > p')[i]]; }).filter(mappedElement => mappedElement[1]).forEach(mappedElement => mappedElement[1].textContent = mappedElement[0])",  # noqa
+                split_story_tl,
+            )
+
+        await self.screenshot(
+            main_page,
+            "//div[@data-test-id='post-content']//div[@data-click-id='text']",
+            {"path": "assets/temp/png/story_content.png"},
+        )
+
+    async def download(
+            self,
+    ):
+        """
+        Downloads screenshots of reddit posts as seen on the web. Downloads to assets/temp/png
+        """
+        print_step("Downloading screenshots of reddit posts...")
+
+        print_substep("Launching Headless Browser...")
+        await self.get_browser()
+
+        # Always close the browser, even if one of the steps below raises
+        try:
+            # ! Make sure the reddit screenshots folder exists
+            Path("assets/temp/png").mkdir(parents=True, exist_ok=True)
+
+            # Get the thread screenshot
+            reddit_main = await self.browser.newPage()
+            await reddit_main.goto(self.reddit_object["thread_url"])  # noqa
+
+            if settings.config["settings"]["theme"] == "dark":
+                await self.__dark_theme(reddit_main)
+
+            if self.reddit_object["is_nsfw"]:
+                # This means the post is NSFW and requires to click the proceed button.
+                await self.__close_nsfw(reddit_main)
+
+            # Translates submission title
+            if self.post_lang:
+                print_substep("Translating post...")
+                texts_in_tl = ts.google(
+                    self.reddit_object["thread_title"],
+                    to_language=self.post_lang,
+                )
+
+                await reddit_main.evaluate(
+                    "(texts_in_tl) => document.querySelector('[data-test-id=\"post-content\"] > div:nth-child(3) > div > div').textContent = texts_in_tl",  # noqa
+                    texts_in_tl,
+                )
+            else:
+                print_substep("Skipping translation...")
+
+            # No sense to move it to common.py
+            async_tasks_primary = (  # noqa
+                [
+                    self.__collect_comment(self.reddit_object["comments"][idx], idx) for idx in
+                    self.screenshot_idx
+                ]
+                if not self.story_mode
+                else [
+                    self.__collect_story(reddit_main)
+                ]
+            )
+
+            async_tasks_primary.append(
+                self.screenshot(
+                    reddit_main,
+                    "//div[@data-testid='post-container']",
+                    {"path": "assets/temp/png/title.png"},
+                )
+            )
+
+            # Ceiling division: number of chunks of up to 10 tasks each
+            total_chunks = (len(async_tasks_primary) + 9) // 10
+            for idx, chunked_tasks in enumerate(
+                    common.chunks(async_tasks_primary, 10),
+                    start=1,
+            ):
+                for task in track(
+                        as_completed(chunked_tasks),
+                        description=f"Downloading comments: Chunk {idx}/{total_chunks}",
+                        total=len(chunked_tasks),
+                ):
+                    await task
+
+            print_substep("Comments downloaded Successfully.", style="bold green")
+        finally:
+            await self.close_browser()
diff --git a/webdriver/web_engine.py b/webdriver/web_engine.py
new file mode 100644
index 0000000..42d5853
--- /dev/null
+++ b/webdriver/web_engine.py
@@ -0,0 +1,27 @@
+from typing import Type, Union
+
+from webdriver.pyppeteer import RedditScreenshot as Pyppeteer
+from webdriver.playwright import RedditScreenshot as Playwright
+
+
+def screenshot_factory(
+        driver: str,
+) -> Type[Union[Pyppeteer, Playwright]]:
+    """
+    Factory for webdriver
+
+    Args:
+        driver: (str) Name of a driver
+
+    Returns:
+        Webdriver class (not an instance) registered under that name
+
+    Raises:
+        KeyError: If the driver name is not registered
+    """
+    web_drivers = {
+        "pyppeteer": Pyppeteer,
+        "playwright": Playwright,
+    }
+
+    return web_drivers[driver]