Merge pull request #963 from Drugsosos/move-to-async-webdriver

async WebDriver
New-WebDriver
Jason 2 years ago committed by GitHub
commit ade0e16510
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,23 +1,27 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import random
from utils import settings from utils import settings
from gtts import gTTS from gtts import gTTS
max_chars = 0
class GTTS: class GTTS:
def __init__(self): max_chars = 0
self.max_chars = 0 # voices = []
self.voices = []
@staticmethod
def run(
text,
filepath
) -> None:
"""
Calls for TTS api
def run(self, text, filepath): Args:
text: text to be voiced over
filepath: name of the audio file
"""
tts = gTTS( tts = gTTS(
text=text, text=text,
lang=settings.config["reddit"]["thread"]["post_lang"] or "en", lang=settings.config["reddit"]["thread"]["post_lang"] or "en",
slow=False, slow=False,
) )
tts.save(filepath) tts.save(filepath)
def randomvoice(self):
return random.choice(self.voices)

@ -1,15 +1,17 @@
import base64
from utils import settings from utils import settings
import random
import requests import requests
from requests.adapters import HTTPAdapter, Retry from requests.adapters import HTTPAdapter, Retry
# from profanity_filter import ProfanityFilter from attr import attrs, attrib
# pf = ProfanityFilter() from attr.validators import instance_of
# Code by @JasonLovesDoggo
# https://twitter.com/scanlime/status/1512598559769702406
nonhuman = [ # DISNEY VOICES from TTS.common import BaseApiTTS, get_random_voice
# TTS examples: https://twitter.com/scanlime/status/1512598559769702406
voices = dict()
voices["nonhuman"] = [ # DISNEY VOICES
"en_us_ghostface", # Ghost Face "en_us_ghostface", # Ghost Face
"en_us_chewbacca", # Chewbacca "en_us_chewbacca", # Chewbacca
"en_us_c3po", # C3PO "en_us_c3po", # C3PO
@ -18,7 +20,7 @@ nonhuman = [ # DISNEY VOICES
"en_us_rocket", # Rocket "en_us_rocket", # Rocket
# ENGLISH VOICES # ENGLISH VOICES
] ]
human = [ voices["human"] = [
"en_au_001", # English AU - Female "en_au_001", # English AU - Female
"en_au_002", # English AU - Male "en_au_002", # English AU - Male
"en_uk_001", # English UK - Male 1 "en_uk_001", # English UK - Male 1
@ -30,9 +32,8 @@ human = [
"en_us_009", # English US - Male 3 "en_us_009", # English US - Male 3
"en_us_010", "en_us_010",
] ]
voices = nonhuman + human
noneng = [ voices["non_eng"] = [
"fr_001", # French - Male 1 "fr_001", # French - Male 1
"fr_002", # French - Male 2 "fr_002", # French - Male 2
"de_001", # German - Female "de_001", # German - Female
@ -56,32 +57,51 @@ noneng = [
] ]
# good_voices = {'good': ['en_us_002', 'en_us_006'], # good_voices: 'en_us_002', 'en_us_006'
# 'ok': ['en_au_002', 'en_uk_001']} # less en_us_stormtrooper more less en_us_rocket en_us_ghostface # ok: 'en_au_002', 'en_uk_001'
# less: en_us_stormtrooper
# more or less: en_us_rocket, en_us_ghostface
class TikTok: # TikTok Text-to-Speech Wrapper @attrs
def __init__(self): class TikTok(BaseApiTTS): # TikTok Text-to-Speech Wrapper
self.URI_BASE = ( random_voice: bool = attrib(
"https://api16-normal-useast5.us.tiktokv.com/media/api/text/speech/invoke/?text_speaker=" validator=instance_of(bool),
) default=False
self.max_chars = 300 )
self.voices = {"human": human, "nonhuman": nonhuman, "noneng": noneng} uri_base: str = "https://api16-normal-useast5.us.tiktokv.com/media/api/text/speech/invoke/"
max_chars: int = 300
decode_base64: bool = True
def run(self, text, filepath, random_voice: bool = False): def make_request(
# if censor: self,
# req_text = pf.censor(req_text) text: str,
# pass ):
"""
Makes a requests to remote TTS service
Args:
text: text to be voice over
Returns:
Request's response
"""
voice = ( voice = (
self.randomvoice() get_random_voice(voices, "human")
if random_voice if self.random_voice
else ( else str(settings.config["settings"]["tts"]["tiktok_voice"]).lower()
settings.config["settings"]["tts"]["tiktok_voice"] if str(settings.config["settings"]["tts"]["tiktok_voice"]).lower() in [
or random.choice(self.voices["human"]) voice.lower() for dict_title in voices for voice in voices[dict_title]]
) else get_random_voice(voices, "human")
) )
try: try:
r = requests.post(f"{self.URI_BASE}{voice}&req_text={text}&speaker_map_type=0") r = requests.post(
self.uri_base,
params={
"text_speaker": voice,
"req_text": text,
"speaker_map_type": 0,
})
except requests.exceptions.SSLError: except requests.exceptions.SSLError:
# https://stackoverflow.com/a/47475019/18516611 # https://stackoverflow.com/a/47475019/18516611
session = requests.Session() session = requests.Session()
@ -89,13 +109,6 @@ class TikTok: # TikTok Text-to-Speech Wrapper
adapter = HTTPAdapter(max_retries=retry) adapter = HTTPAdapter(max_retries=retry)
session.mount("http://", adapter) session.mount("http://", adapter)
session.mount("https://", adapter) session.mount("https://", adapter)
r = session.post(f"{self.URI_BASE}{voice}&req_text={text}&speaker_map_type=0") r = session.post(f"{self.uri_base}{voice}&req_text={text}&speaker_map_type=0")
# print(r.text) # print(r.text)
vstr = [r.json()["data"]["v_str"]][0] return r.json()["data"]["v_str"]
b64d = base64.b64decode(vstr)
with open(filepath, "wb") as out:
out.write(b64d)
def randomvoice(self):
return random.choice(self.voices["human"])

@ -1,9 +1,14 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
from boto3 import Session from boto3 import Session
from botocore.exceptions import BotoCoreError, ClientError, ProfileNotFound from botocore.exceptions import BotoCoreError, ClientError, ProfileNotFound
import sys import sys
from utils import settings from utils import settings
import random from attr import attrs, attrib
from attr.validators import instance_of
from TTS.common import get_random_voice
voices = [ voices = [
"Brian", "Brian",
@ -24,23 +29,37 @@ voices = [
] ]
@attrs
class AWSPolly: class AWSPolly:
def __init__(self): random_voice: bool = attrib(
self.max_chars = 0 validator=instance_of(bool),
self.voices = voices default=False
)
max_chars: int = 0
def run(
self,
text,
filepath,
) -> None:
"""
Calls for TTS api
def run(self, text, filepath, random_voice: bool = False): Args:
text: text to be voiced over
filepath: name of the audio file
"""
try: try:
session = Session(profile_name="polly") session = Session(profile_name="polly")
polly = session.client("polly") polly = session.client("polly")
if random_voice: voice = (
voice = self.randomvoice() get_random_voice(voices)
else: if self.random_voice
if not settings.config["settings"]["tts"]["aws_polly_voice"]: else str(settings.config["settings"]["tts"]["aws_polly_voice"]).capitalize()
raise ValueError( if str(settings.config["settings"]["tts"]["aws_polly_voice"]).lower() in [voice.lower() for voice in
f"Please set the TOML variable AWS_VOICE to a valid voice. options are: {voices}" voices]
) else get_random_voice(voices)
voice = str(settings.config["settings"]["tts"]["aws_polly_voice"]).capitalize() )
try: try:
# Request speech synthesis # Request speech synthesis
response = polly.synthesize_speech( response = polly.synthesize_speech(
@ -71,6 +90,3 @@ class AWSPolly:
""" """
) )
sys.exit(-1) sys.exit(-1)
def randomvoice(self):
return random.choice(self.voices)

@ -0,0 +1,141 @@
import base64
from random import choice
from typing import Union, Optional
class BaseApiTTS:
max_chars: int
decode_base64: bool = False
@staticmethod
def text_len_sanitize(
text: str,
max_length: int,
) -> list:
"""
Splits text if it's too long to be a query
Args:
text: text to be sanitized
max_length: maximum length of the query
Returns:
Split text as a list
"""
# Split by comma or dot (else you can lose intonations), if there is non, split by groups of 299 chars
split_text = list(
map(lambda x: x.strip() if x.strip()[-1] != "." else x.strip()[:-1],
filter(lambda x: True if x else False, text.split(".")))
)
if split_text and all([chunk.__len__() < max_length for chunk in split_text]):
return split_text
split_text = list(
map(lambda x: x.strip() if x.strip()[-1] != "," else x.strip()[:-1],
filter(lambda x: True if x else False, text.split(","))
)
)
if split_text and all([chunk.__len__() < max_length for chunk in split_text]):
return split_text
return list(
map(
lambda x: x.strip() if x.strip()[-1] != "." or x.strip()[-1] != "," else x.strip()[:-1],
filter(
lambda x: True if x else False,
[text[i:i + max_length] for i in range(0, len(text), max_length)]
)
)
)
def write_file(
self,
output_text: str,
filepath: str,
) -> None:
"""
Writes and decodes TTS responses in files
Args:
output_text: text to be written
filepath: path/name of the file
"""
decoded_text = base64.b64decode(output_text) if self.decode_base64 else output_text
with open(filepath, "wb") as out:
out.write(decoded_text)
def run(
self,
text: str,
filepath: str,
) -> None:
"""
Calls for TTS api and writes audio file
Args:
text: text to be voice over
filepath: path/name of the file
Returns:
"""
output_text = ""
if len(text) > self.max_chars:
for part in self.text_len_sanitize(text, self.max_chars):
if part:
output_text += self.make_request(part)
else:
output_text = self.make_request(text)
self.write_file(output_text, filepath)
def get_random_voice(
voices: Union[list, dict],
key: Optional[str] = None,
) -> str:
"""
Return random voice from list or dict
Args:
voices: list or dict of voices
key: key of a dict if you are using one
Returns:
random voice as a str
"""
if isinstance(voices, list):
return choice(voices)
else:
return choice(voices[key] if key else list(voices.values())[0])
def audio_length(
path: str,
) -> Union[float, int]:
"""
Gets the length of the audio file
Args:
path: audio file path
Returns:
length in seconds as an int
"""
from moviepy.editor import AudioFileClip
try:
# please use something else here in the future
audio_clip = AudioFileClip(path)
audio_duration = audio_clip.duration
audio_clip.close()
return audio_duration
except Exception as e:
import logging
logger = logging.getLogger("tts_logger")
logger.setLevel(logging.ERROR)
handler = logging.FileHandler(".tts.log", mode="a+", encoding="utf-8")
logger.addHandler(handler)
logger.error("Error occurred in audio_length:", e)
return 0

@ -1,145 +1,140 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
from pathlib import Path from pathlib import Path
from typing import Tuple from typing import Union
import re
# import sox
# from mutagen import MutagenError
# from mutagen.mp3 import MP3, HeaderNotFoundError
import translators as ts import translators as ts
from rich.progress import track from rich.progress import track
from moviepy.editor import AudioFileClip, CompositeAudioClip, concatenate_audioclips from attr import attrs, attrib
from utils.console import print_step, print_substep from utils.console import print_step, print_substep
from utils.voice import sanitize_text from utils.voice import sanitize_text
from utils import settings from utils import settings
from TTS.common import audio_length
DEFAULT_MAX_LENGTH: int = 50 # video length variable from TTS.GTTS import GTTS
from TTS.streamlabs_polly import StreamlabsPolly
from TTS.TikTok import TikTok
from TTS.aws_polly import AWSPolly
@attrs
class TTSEngine: class TTSEngine:
"""Calls the given TTS engine to reduce code duplication and allow multiple TTS engines. """Calls the given TTS engine to reduce code duplication and allow multiple TTS engines.
Args: Args:
tts_module : The TTS module. Your module should handle the TTS itself and saving to the given path under the run method. tts_module : The TTS module. Your module should handle the TTS itself and saving to the given path under the run method.
reddit_object : The reddit object that contains the posts to read. reddit_object : The reddit object that contains the posts to read.
path (Optional) : The unix style path to save the mp3 files to. This must not have leading or trailing slashes.
max_length (Optional) : The maximum length of the mp3 files in total.
Notes: Notes:
tts_module must take the arguments text and filepath. tts_module must take the arguments text and filepath.
""" """
tts_module: Union[GTTS, StreamlabsPolly, TikTok, AWSPolly] = attrib()
reddit_object: dict = attrib()
__path: str = "assets/temp/mp3"
__total_length: int = 0
def __attrs_post_init__(self):
# Calls an instance of the tts_module class
self.tts_module = self.tts_module()
# Loading settings from the config
self.max_length: int = settings.config["settings"]["video_length"]
self.time_before_tts: float = settings.config["settings"]["time_before_tts"]
self.time_between_pictures: float = settings.config["settings"]["time_between_pictures"]
self.__total_length = (
settings.config["settings"]["time_before_first_picture"] +
settings.config["settings"]["delay_before_end"]
)
def run(
self
) -> list:
"""
Voices over comments & title of the submission
Returns:
Indexes of comments to be used in the final video
"""
Path(self.__path).mkdir(parents=True, exist_ok=True)
def __init__( # This file needs to be removed in case this post does not use post text
self, # so that it won't appear in the final video
tts_module,
reddit_object: dict,
path: str = "assets/temp/mp3",
max_length: int = DEFAULT_MAX_LENGTH,
last_clip_length: int = 0,
):
self.tts_module = tts_module()
self.reddit_object = reddit_object
self.path = path
self.max_length = max_length
self.length = 0
self.last_clip_length = last_clip_length
def run(self) -> Tuple[int, int]:
Path(self.path).mkdir(parents=True, exist_ok=True)
# This file needs to be removed in case this post does not use post text, so that it won't appear in the final video
try: try:
Path(f"{self.path}/posttext.mp3").unlink() Path(f"{self.__path}/posttext.mp3").unlink()
except OSError: except OSError:
pass pass
print_step("Saving Text to MP3 files...") print_step("Saving Text to MP3 files...")
self.call_tts("title", self.reddit_object["thread_title"]) self.call_tts("title", self.reddit_object["thread_title"])
if (
self.reddit_object["thread_post"] != "" if self.reddit_object["thread_post"] and settings.config["settings"]["storymode"]:
and settings.config["settings"]["storymode"] == True
):
self.call_tts("posttext", self.reddit_object["thread_post"]) self.call_tts("posttext", self.reddit_object["thread_post"])
idx = None sync_tasks_primary = [
for idx, comment in track( self.call_tts(str(idx), comment["comment_body"])
enumerate(self.reddit_object["comments"]), "Saving..." for idx, comment in track(
): enumerate(self.reddit_object["comments"]),
# ! Stop creating mp3 files if the length is greater than max length. description="Saving...",
if self.length > self.max_length: total=self.reddit_object["comments"].__len__())
self.length -= self.last_clip_length # Crunch, there will be fix in async TTS api, maybe
idx -= 1 if self.__total_length + self.__total_length * 0.05 < self.max_length
break ]
if (
len(comment["comment_body"]) > self.tts_module.max_chars
): # Split the comment if it is too long
self.split_post(comment["comment_body"], idx) # Split the comment
else: # If the comment is not too long, just call the tts engine
self.call_tts(f"{idx}", comment["comment_body"])
print_substep("Saved Text to MP3 files successfully.", style="bold green") print_substep("Saved Text to MP3 files successfully.", style="bold green")
return self.length, idx return [
comments for comments, condition in
def split_post(self, text: str, idx: int): zip(range(self.reddit_object["comments"].__len__()), sync_tasks_primary)
split_files = [] if condition
split_text = [
x.group().strip()
for x in re.finditer(
r" *(((.|\n){0," + str(self.tts_module.max_chars) + "})(\.|.$))", text
)
] ]
offset = 0
for idy, text_cut in enumerate(split_text):
# print(f"{idx}-{idy}: {text_cut}\n")
if not text_cut or text_cut.isspace():
offset += 1
continue
self.call_tts(f"{idx}-{idy - offset}.part", text_cut)
split_files.append(
AudioFileClip(f"{self.path}/{idx}-{idy - offset}.part.mp3")
)
CompositeAudioClip([concatenate_audioclips(split_files)]).write_audiofile(
f"{self.path}/{idx}.mp3", fps=44100, verbose=False, logger=None
)
for i in split_files: def call_tts(
name = i.filename self,
i.close() filename: str,
Path(name).unlink() text: str
) -> bool:
"""
Calls for TTS api from the factory
# for i in range(0, idy + 1): Args:
# print(f"Cleaning up {self.path}/{idx}-{i}.part.mp3") filename: name of audio file w/o .mp3
text: text to be voiced over
# Path(f"{self.path}/{idx}-{i}.part.mp3").unlink() Returns:
True if audio files not exceeding the maximum length else false
"""
if not text:
return False
def call_tts(self, filename: str, text: str):
self.tts_module.run( self.tts_module.run(
text=process_text(text), filepath=f"{self.path}/{filename}.mp3" text=self.process_text(text),
filepath=f"{self.__path}/{filename}.mp3"
) )
# try:
# self.length += MP3(f"{self.path}/{filename}.mp3").info.length clip_length = audio_length(f"{self.__path}/{filename}.mp3")
# except (MutagenError, HeaderNotFoundError): clip_offset = self.time_between_pictures + self.time_before_tts * 2
# self.length += sox.file_info.duration(f"{self.path}/{filename}.mp3")
try: if clip_length and self.__total_length + clip_length + clip_offset <= self.max_length:
clip = AudioFileClip(f"{self.path}/{filename}.mp3") self.__total_length += clip_length + clip_offset
if clip.duration + self.length < self.max_length: return True
self.last_clip_length = clip.duration return False
self.length += clip.duration
clip.close() @staticmethod
except: def process_text(
self.length = 0 text: str,
) -> str:
"""
def process_text(text: str): Sanitizes text for illegal characters and translates text
lang = settings.config["reddit"]["thread"]["post_lang"]
new_text = sanitize_text(text) Args:
if lang: text: text to be sanitized & translated
print_substep("Translating Text...")
translated_text = ts.google(text, to_language=lang) Returns:
new_text = sanitize_text(translated_text) Processed text as a str
return new_text """
lang = settings.config["reddit"]["thread"]["post_lang"]
new_text = sanitize_text(text)
if lang:
print_substep("Translating Text...")
translated_text = ts.google(text, to_language=lang)
new_text = sanitize_text(translated_text)
return new_text

@ -1,7 +1,10 @@
import random
import requests import requests
from requests.exceptions import JSONDecodeError from requests.exceptions import JSONDecodeError
from utils import settings from utils import settings
from attr import attrs, attrib
from attr.validators import instance_of
from TTS.common import BaseApiTTS, get_random_voice
from utils.voice import check_ratelimit from utils.voice import check_ratelimit
voices = [ voices = [
@ -26,37 +29,52 @@ voices = [
# valid voices https://lazypy.ro/tts/ # valid voices https://lazypy.ro/tts/
class StreamlabsPolly: @attrs
def __init__(self): class StreamlabsPolly(BaseApiTTS):
self.url = "https://streamlabs.com/polly/speak" random_voice: bool = attrib(
self.max_chars = 550 validator=instance_of(bool),
self.voices = voices default=False
)
url: str = "https://streamlabs.com/polly/speak"
max_chars: int = 550
def run(self, text, filepath, random_voice: bool = False): def make_request(
if random_voice: self,
voice = self.randomvoice() text,
else: ):
if not settings.config["settings"]["tts"]["streamlabs_polly_voice"]: """
raise ValueError( Makes a requests to remote TTS service
f"Please set the config variable STREAMLABS_POLLY_VOICE to a valid voice. options are: {voices}"
)
voice = str(settings.config["settings"]["tts"]["streamlabs_polly_voice"]).capitalize()
body = {"voice": voice, "text": text, "service": "polly"}
response = requests.post(self.url, data=body)
if not check_ratelimit(response):
self.run(text, filepath, random_voice)
Args:
text: text to be voice over
Returns:
Request's response
"""
voice = (
get_random_voice(voices)
if self.random_voice
else str(settings.config["settings"]["tts"]["streamlabs_polly_voice"]).capitalize()
if str(settings.config["settings"]["tts"]["streamlabs_polly_voice"]).lower() in [
voice.lower() for voice in voices]
else get_random_voice(voices)
)
response = requests.post(
self.url,
data={
"voice": voice,
"text": text,
"service": "polly",
})
if not check_ratelimit(response):
return self.make_request(text)
else: else:
try: try:
voice_data = requests.get(response.json()["speak_url"]) results = requests.get(response.json()["speak_url"])
with open(filepath, "wb") as f: return results.content
f.write(voice_data.content)
except (KeyError, JSONDecodeError): except (KeyError, JSONDecodeError):
try: try:
if response.json()["error"] == "No text specified!": if response.json()["error"] == "No text specified!":
raise ValueError("Please specify a text to convert to speech.") raise ValueError("Please specify a text to convert to speech.")
except (KeyError, JSONDecodeError): except (KeyError, JSONDecodeError):
print("Error occurred calling Streamlabs Polly") print("Error occurred calling Streamlabs Polly")
def randomvoice(self):
return random.choice(self.voices)

@ -1,5 +1,5 @@
#!/usr/bin/env python #!/usr/bin/env python
import math from asyncio import run
from subprocess import Popen from subprocess import Popen
from os import name from os import name
@ -11,12 +11,10 @@ from utils.console import print_markdown, print_step
from utils import settings from utils import settings
from video_creation.background import ( from video_creation.background import (
download_background,
chop_background_video,
get_background_config, get_background_config,
) )
from video_creation.final_video import make_final_video from video_creation.final_video import FinalVideo
from video_creation.screenshot_downloader import download_screenshots_of_reddit_posts from webdriver.web_engine import screenshot_factory
from video_creation.voices import save_text_to_mp3 from video_creation.voices import save_text_to_mp3
__VERSION__ = "2.3.1" __VERSION__ = "2.3.1"
@ -39,24 +37,22 @@ print_markdown(
print_step(f"You are using v{__VERSION__} of the bot") print_step(f"You are using v{__VERSION__} of the bot")
def main(POST_ID=None): async def main(POST_ID=None):
cleanup() cleanup()
reddit_object = get_subreddit_threads(POST_ID) reddit_object = get_subreddit_threads(POST_ID)
length, number_of_comments = save_text_to_mp3(reddit_object) comments_created = save_text_to_mp3(reddit_object)
length = math.ceil(length) webdriver = screenshot_factory(config["settings"]["webdriver"])
download_screenshots_of_reddit_posts(reddit_object, number_of_comments) await webdriver(reddit_object, comments_created).download()
bg_config = get_background_config() bg_config = get_background_config()
download_background(bg_config) FinalVideo().make(comments_created, reddit_object, bg_config)
chop_background_video(bg_config, length)
make_final_video(number_of_comments, length, reddit_object, bg_config)
def run_many(times): async def run_many(times):
for x in range(1, times + 1): for x in range(1, times + 1):
print_step( print_step(
f'on the {x}{("th", "st", "nd", "rd", "th", "th", "th", "th", "th", "th")[x % 10]} iteration of {times}' f'on the {x}{("th", "st", "nd", "rd", "th", "th", "th", "th", "th", "th")[x % 10]} iteration of {times}'
) # correct 1st 2nd 3rd 4th 5th.... ) # correct 1st 2nd 3rd 4th 5th....
main() await main()
Popen("cls" if name == "nt" else "clear", shell=True).wait() Popen("cls" if name == "nt" else "clear", shell=True).wait()
@ -72,7 +68,9 @@ if __name__ == "__main__":
config is False and exit() config is False and exit()
try: try:
if config["settings"]["times_to_run"]: if config["settings"]["times_to_run"]:
run_many(config["settings"]["times_to_run"]) run(
run_many(config["settings"]["times_to_run"])
)
elif len(config["reddit"]["thread"]["post_id"].split("+")) > 1: elif len(config["reddit"]["thread"]["post_id"].split("+")) > 1:
for index, post_id in enumerate(config["reddit"]["thread"]["post_id"].split("+")): for index, post_id in enumerate(config["reddit"]["thread"]["post_id"].split("+")):
@ -80,11 +78,13 @@ if __name__ == "__main__":
print_step( print_step(
f'on the {index}{("st" if index % 10 == 1 else ("nd" if index % 10 == 2 else ("rd" if index % 10 == 3 else "th")))} post of {len(config["reddit"]["thread"]["post_id"].split("+"))}' f'on the {index}{("st" if index % 10 == 1 else ("nd" if index % 10 == 2 else ("rd" if index % 10 == 3 else "th")))} post of {len(config["reddit"]["thread"]["post_id"].split("+"))}'
) )
main(post_id) run(
main(post_id)
)
Popen("cls" if name == "nt" else "clear", shell=True).wait() Popen("cls" if name == "nt" else "clear", shell=True).wait()
else: else:
main() main()
except KeyboardInterrupt: except KeyboardInterrupt: # TODO won't work with async code
shutdown() shutdown()
except ResponseException: except ResponseException:
# error for invalid credentials # error for invalid credentials

@ -87,6 +87,7 @@ def get_subreddit_threads(POST_ID: str):
content["thread_title"] = submission.title content["thread_title"] = submission.title
content["thread_post"] = submission.selftext content["thread_post"] = submission.selftext
content["thread_id"] = submission.id content["thread_id"] = submission.id
content["is_nsfw"] = "nsfw" in submission.whitelist_status
content["comments"] = [] content["comments"] = []
for top_level_comment in submission.comments: for top_level_comment in submission.comments:

@ -9,5 +9,6 @@ requests==2.28.1
rich==12.5.1 rich==12.5.1
toml==0.10.2 toml==0.10.2
translators==5.3.1 translators==5.3.1
pyppeteer==1.0.2
attrs==21.4.0
Pillow~=9.1.1 Pillow~=9.1.1

@ -16,7 +16,8 @@ subreddit = { optional = false, regex = "[_0-9a-zA-Z]+$", nmin = 3, explanation
post_id = { optional = true, default = "", regex = "^((?!://|://)[+a-zA-Z])*$", explanation = "Used if you want to use a specific post.", example = "urdtfx" } post_id = { optional = true, default = "", regex = "^((?!://|://)[+a-zA-Z])*$", explanation = "Used if you want to use a specific post.", example = "urdtfx" }
max_comment_length = { default = 500, optional = false, nmin = 10, nmax = 10000, type = "int", explanation = "max number of characters a comment can have. default is 500", example = 500, oob_error = "the max comment length should be between 10 and 10000" } max_comment_length = { default = 500, optional = false, nmin = 10, nmax = 10000, type = "int", explanation = "max number of characters a comment can have. default is 500", example = 500, oob_error = "the max comment length should be between 10 and 10000" }
post_lang = { default = "", optional = true, explanation = "The language you would like to translate to.", example = "es-cr" } post_lang = { default = "", optional = true, explanation = "The language you would like to translate to.", example = "es-cr" }
min_comments = { default = 20, optional = false, nmin = 15, type = "int", explanation = "The minimum number of comments a post should have to be included. default is 20", example = 29, oob_error = "the minimum number of comments should be between 15 and 999999" } min_comments = { default = 20, optional = false, nmin = 15, type = "int", explanation = "The minimum number of comments a post should have to be included. default is 20", example = 29, oob_error = "the minimum number of comments must be at least 15" }
[settings] [settings]
allow_nsfw = { optional = false, type = "bool", default = false, example = false, options = [true, allow_nsfw = { optional = false, type = "bool", default = false, example = false, options = [true,
false, false,
@ -30,6 +31,14 @@ transition = { optional = true, default = 0.2, example = 0.2, explanation = "Set
storymode = { optional = true, type = "bool", default = false, example = false, options = [true, storymode = { optional = true, type = "bool", default = false, example = false, options = [true,
false, false,
], explanation = "not yet implemented" } ], explanation = "not yet implemented" }
video_length = { optional = false, default = 50, example = 60, explanation = "Approximated final video length", type = "int", nmin = 15, oob_error = "15 seconds is short enought" }
time_before_first_picture = { optional = false, default = 0.5, example = 1.0, explanation = "Deley before first screenshot apears", type = "float", nmin = 0, oob_error = "Choose at least 0 second" }
time_before_tts = { optional = false, default = 0.5, example = 1.0, explanation = "Deley between screenshot and TTS", type = "float", nmin = 0, oob_error = "Choose at least 0 second" }
time_between_pictures = { optional = false, default = 0.5, example = 1.0, explanation = "Time between every screenshot", type = "float", nmin = 0, oob_error = "Choose at least 0 second" }
delay_before_end = { optional = false, default = 0.5, example = 1.0, explanation = "Deley before video ends", type = "float", nmin = 0, oob_error = "Choose at least 0 second" }
video_width = { optional = true, default = 1080, example = 1080, explanation = "Final video width", type = "int", nmin = 600, oob_error = "Choose at least 600 pixels wide" }
video_height = { optional = true, default = 1920, example = 1920, explanation = "Final video height", type = "int", nmin = 600, oob_error = "Choose at least 600 pixels long" }
webdriver = { optional = true, default = "pyppeteer", example = "pyppeteer", options = ["pyppeteer", "playwright"], explanation = "Driver used to take screenshots (use pyppeteer if you have some problems with playwright)"}
[settings.background] [settings.background]
background_choice = { optional = true, default = "minecraft", example = "minecraft", options = ["minecraft", "gta", "rocket-league", "motor-gta", "csgo-surf", "cluster-truck", ""], explanation = "Sets the background for the video" } background_choice = { optional = true, default = "minecraft", example = "minecraft", options = ["minecraft", "gta", "rocket-league", "motor-gta", "csgo-surf", "cluster-truck", ""], explanation = "Sets the background for the video" }

@ -3,13 +3,13 @@ import toml
from rich.console import Console from rich.console import Console
import re import re
from typing import Tuple, Dict from typing import Dict, Optional, Union
from utils.console import handle_input from utils.console import handle_input
console = Console() console = Console()
config = dict # autocomplete config: Optional[dict] = None # autocomplete
def crawl(obj: dict, func=lambda x, y: print(x, y, end="\n"), path=None): def crawl(obj: dict, func=lambda x, y: print(x, y, end="\n"), path=None):
@ -108,7 +108,7 @@ def check_vars(path, checks):
crawl_and_check(config, path, checks) crawl_and_check(config, path, checks)
def check_toml(template_file, config_file) -> Tuple[bool, Dict]: def check_toml(template_file, config_file) -> Union[bool, Dict]:
global config global config
config = None config = None
try: try:

@ -9,6 +9,7 @@ def get_subreddit_undone(submissions: list, subreddit, times_checked=0):
"""_summary_ """_summary_
Args: Args:
times_checked: (int): For internal use, number of times function was called
submissions (list): List of posts that are going to potentially be generated into a video submissions (list): List of posts that are going to potentially be generated into a video
subreddit (praw.Reddit.SubredditHelper): Chosen subreddit subreddit (praw.Reddit.SubredditHelper): Chosen subreddit
@ -34,22 +35,23 @@ def get_subreddit_undone(submissions: list, subreddit, times_checked=0):
if submission.stickied: if submission.stickied:
print_substep("This post was pinned by moderators. Skipping...") print_substep("This post was pinned by moderators. Skipping...")
continue continue
if submission.num_comments <= int(settings.config["reddit"]["thread"]["min_comments"]): if submission.num_comments < int(settings.config["reddit"]["thread"]["min_comments"]):
print_substep( print_substep(
f'This post has under the specified minimum of comments ({settings.config["reddit"]["thread"]["min_comments"]}). Skipping...' "This post has under the specified minimum of comments"
f'({settings.config["reddit"]["thread"]["min_comments"]}). Skipping...'
) )
continue continue
return submission return submission
print("all submissions have been done going by top submission order") print("all submissions have been done going by top submission order")
VALID_TIME_FILTERS = [ VALID_TIME_FILTERS = [
"day",
"hour", "hour",
"day",
"month", "month",
"week", "week",
"year", "year",
"all", "all",
] # set doesn't have __getitem__ ] # set doesn't have __getitem__
index = times_checked + 1 index = times_checked + 1 if times_checked != 0 else times_checked
if index == len(VALID_TIME_FILTERS): if index == len(VALID_TIME_FILTERS):
print("all time filters have been checked you absolute madlad ") print("all time filters have been checked you absolute madlad ")

@ -48,7 +48,7 @@ class Video:
img_clip = img_clip.set_opacity(opacity).set_duration(duration) img_clip = img_clip.set_opacity(opacity).set_duration(duration)
img_clip = img_clip.set_position( img_clip = img_clip.set_position(
position, relative=True position, relative=True
) # todo get dara from utils/CONSTANTS.py and adapt position accordingly ) # todo get data from utils/CONSTANTS.py and adapt position accordingly
# Overlay the img clip on the first video clip # Overlay the img clip on the first video clip
self.video = CompositeVideoClip([self.video, img_clip]) self.video = CompositeVideoClip([self.video, img_clip])

@ -1,14 +0,0 @@
[
{
"name": "USER",
"value": "eyJwcmVmcyI6eyJ0b3BDb250ZW50RGlzbWlzc2FsVGltZSI6MCwiZ2xvYmFsVGhlbWUiOiJSRURESVQiLCJuaWdodG1vZGUiOnRydWUsImNvbGxhcHNlZFRyYXlTZWN0aW9ucyI6eyJmYXZvcml0ZXMiOmZhbHNlLCJtdWx0aXMiOmZhbHNlLCJtb2RlcmF0aW5nIjpmYWxzZSwic3Vic2NyaXB0aW9ucyI6ZmFsc2UsInByb2ZpbGVzIjpmYWxzZX0sInRvcENvbnRlbnRUaW1lc0Rpc21pc3NlZCI6MH19",
"domain": ".reddit.com",
"path": "/"
},
{
"name": "eu_cookie",
"value": "{%22opted%22:true%2C%22nonessential%22:false}",
"domain": ".reddit.com",
"path": "/"
}
]

@ -1,8 +0,0 @@
[
{
"name": "eu_cookie",
"value": "{%22opted%22:true%2C%22nonessential%22:false}",
"domain": ".reddit.com",
"path": "/"
}
]

@ -3,165 +3,286 @@ import multiprocessing
import os import os
import re import re
from os.path import exists from os.path import exists
from typing import Tuple, Any from typing import Tuple, Any, Union
from moviepy.audio.AudioClip import concatenate_audioclips, CompositeAudioClip
from moviepy.audio.io.AudioFileClip import AudioFileClip from moviepy.editor import (
from moviepy.video.VideoClip import ImageClip VideoFileClip,
from moviepy.video.compositing.CompositeVideoClip import CompositeVideoClip AudioFileClip,
from moviepy.video.compositing.concatenate import concatenate_videoclips ImageClip,
from moviepy.video.io.VideoFileClip import VideoFileClip CompositeAudioClip,
from moviepy.video.io.ffmpeg_tools import ffmpeg_extract_subclip CompositeVideoClip,
)
from rich.console import Console from rich.console import Console
from rich.progress import track
from attr import attrs
from utils.cleanup import cleanup from utils.cleanup import cleanup
from utils.console import print_step, print_substep from utils.console import print_step, print_substep
from utils.video import Video from utils.video import Video
from utils.videos import save_data from utils.videos import save_data
from utils import settings from utils import settings
from video_creation.background import download_background, chop_background_video
@attrs
class FinalVideo:
video_duration: int = 0
console = Console()
def __attrs_post_init__(self):
self.W: int = int(settings.config["settings"]["video_width"])
self.H: int = int(settings.config["settings"]["video_height"])
if not self.W or not self.H:
self.W, self.H = 1080, 1920
self.vertical_video: bool = self.W < self.H
self.max_length: int = int(settings.config["settings"]["video_length"])
self.time_before_first_picture: float = settings.config["settings"]["time_before_first_picture"]
self.time_before_tts: float = settings.config["settings"]["time_before_tts"]
self.time_between_pictures: float = settings.config["settings"]["time_between_pictures"]
self.delay_before_end: float = settings.config["settings"]["delay_before_end"]
self.opacity = settings.config["settings"]["opacity"]
self.opacity = 1 if self.opacity is None or self.opacity >= 1 else self.opacity
self.transition = settings.config["settings"]["transition"]
self.transition = 0 if self.transition is None or self.transition > 2 else self.transition
@staticmethod
def name_normalize(
name: str
) -> str:
name = re.sub(r'[?\\"%*:|<>]', "", name)
name = re.sub(r"( [w,W]\s?/\s?[oO0])", r" without", name)
name = re.sub(r"( [w,W]\s?/)", r" with", name)
name = re.sub(r"(\d+)\s?/\s?(\d+)", r"\1 of \2", name)
name = re.sub(r"(\w+)\s?/\s?(\w+)", r"\1 or \2", name)
name = re.sub(r"/", "", name)
lang = settings.config["reddit"]["thread"]["post_lang"]
translated_name = None
if lang:
import translators as ts
print_substep("Translating filename...")
translated_name = ts.google(name, to_language=lang)
return translated_name[:30] if translated_name else name[:30]
@staticmethod
def create_audio_clip(
clip_title: Union[str, int],
clip_start: float,
) -> AudioFileClip:
return (
AudioFileClip(f"assets/temp/mp3/{clip_title}.mp3")
.set_start(clip_start)
)
def create_image_clip(
self,
image_title: Union[str, int],
audio_start: float,
audio_duration: float,
clip_position: str,
) -> ImageClip:
return (
ImageClip(f"assets/temp/png/{image_title}.png")
.set_start(audio_start - self.time_before_tts)
.set_duration(self.time_before_tts * 2 + audio_duration)
.set_opacity(self.opacity)
.set_position(clip_position)
.resize(
width=self.W - self.W / 20 if self.vertical_video else None,
height=self.H - self.H / 5 if not self.vertical_video else None,
)
.crossfadein(self.transition)
.crossfadeout(self.transition)
)
def make(
self,
indexes_of_clips: list,
reddit_obj: dict,
background_config: Tuple[str, str, str, Any],
) -> None:
"""Gathers audio clips, gathers all screenshots, stitches them together and saves the final video to assets/temp
Args:
indexes_of_clips (list): Indexes of voiced comments
reddit_obj (dict): The reddit object that contains the posts to read.
background_config (Tuple[str, str, str, Any]): The background config to use.
"""
# try: # if it isn't found (i.e you just updated and copied over config.toml) it will throw an error
# VOLUME_MULTIPLIER = settings.config["settings"]['background']["background_audio_volume"]
# except (TypeError, KeyError):
# print('No background audio volume found in config.toml. Using default value of 1.')
# VOLUME_MULTIPLIER = 1
print_step("Creating the final video 🎥")
VideoFileClip.reW = lambda clip: clip.resize(width=self.W)
VideoFileClip.reH = lambda clip: clip.resize(width=self.H)
# Gather all audio clips
audio_clips = list()
correct_audio_offset = self.time_before_tts * 2 + self.time_between_pictures
audio_title = self.create_audio_clip(
"title",
self.time_before_first_picture + self.time_before_tts,
)
self.video_duration += audio_title.duration + self.time_before_first_picture + self.time_before_tts
audio_clips.append(audio_title)
indexes_for_videos = list()
for audio_title in track(
indexes_of_clips,
description="Gathering audio clips...",
total=indexes_of_clips.__len__()
):
temp_audio_clip = self.create_audio_clip(
audio_title,
correct_audio_offset + self.video_duration,
)
if self.video_duration + temp_audio_clip.duration + \
correct_audio_offset + self.delay_before_end <= self.max_length:
self.video_duration += temp_audio_clip.duration + correct_audio_offset
audio_clips.append(temp_audio_clip)
indexes_for_videos.append(audio_title)
self.video_duration += self.delay_before_end + self.time_before_tts
# Can't use concatenate_audioclips here, it resets clips' start point
audio_composite = CompositeAudioClip(audio_clips)
self.console.log("[bold green] Video Will Be: %.2f Seconds Long" % self.video_duration)
# Gather all images
image_clips = list()
# Accounting for title and other stuff if audio_clips
index_offset = 1
console = Console()
W, H = 1080, 1920
def name_normalize(name: str) -> str:
name = re.sub(r'[?\\"%*:|<>]', "", name)
name = re.sub(r"( [w,W]\s?\/\s?[o,O,0])", r" without", name)
name = re.sub(r"( [w,W]\s?\/)", r" with", name)
name = re.sub(r"(\d+)\s?\/\s?(\d+)", r"\1 of \2", name)
name = re.sub(r"(\w+)\s?\/\s?(\w+)", r"\1 or \2", name)
name = re.sub(r"\/", r"", name)
name[:30]
lang = settings.config["reddit"]["thread"]["post_lang"]
if lang:
import translators as ts
print_substep("Translating filename...")
translated_name = ts.google(name, to_language=lang)
return translated_name
else:
return name
def make_final_video(
number_of_clips: int,
length: int,
reddit_obj: dict,
background_config: Tuple[str, str, str, Any],
):
"""Gathers audio clips, gathers all screenshots, stitches them together and saves the final video to assets/temp
Args:
number_of_clips (int): Index to end at when going through the screenshots'
length (int): Length of the video
reddit_obj (dict): The reddit object that contains the posts to read.
background_config (Tuple[str, str, str, Any]): The background config to use.
"""
# try: # if it isn't found (i.e you just updated and copied over config.toml) it will throw an error
# VOLUME_MULTIPLIER = settings.config["settings"]['background']["background_audio_volume"]
# except (TypeError, KeyError):
# print('No background audio volume found in config.toml. Using default value of 1.')
# VOLUME_MULTIPLIER = 1
print_step("Creating the final video 🎥")
VideoFileClip.reW = lambda clip: clip.resize(width=W)
VideoFileClip.reH = lambda clip: clip.resize(width=H)
opacity = settings.config["settings"]["opacity"]
transition = settings.config["settings"]["transition"]
background_clip = (
VideoFileClip("assets/temp/background.mp4")
.without_audio()
.resize(height=H)
.crop(x1=1166.6, y1=0, x2=2246.6, y2=1920)
)
# Gather all audio clips
audio_clips = [AudioFileClip(f"assets/temp/mp3/{i}.mp3") for i in range(number_of_clips)]
audio_clips.insert(0, AudioFileClip("assets/temp/mp3/title.mp3"))
audio_concat = concatenate_audioclips(audio_clips)
audio_composite = CompositeAudioClip([audio_concat])
console.log(f"[bold green] Video Will Be: {length} Seconds Long")
# add title to video
image_clips = []
# Gather all images
new_opacity = 1 if opacity is None or float(opacity) >= 1 else float(opacity)
new_transition = 0 if transition is None or float(transition) > 2 else float(transition)
image_clips.insert(
0,
ImageClip("assets/temp/png/title.png")
.set_duration(audio_clips[0].duration)
.resize(width=W - 100)
.set_opacity(new_opacity)
.crossfadein(new_transition)
.crossfadeout(new_transition),
)
for i in range(0, number_of_clips):
image_clips.append( image_clips.append(
ImageClip(f"assets/temp/png/comment_{i}.png") self.create_image_clip(
.set_duration(audio_clips[i + 1].duration) "title",
.resize(width=W - 100) audio_clips[0].start,
.set_opacity(new_opacity) audio_clips[0].duration,
.crossfadein(new_transition) background_config[3],
.crossfadeout(new_transition) )
) )
# if os.path.exists("assets/mp3/posttext.mp3"): for idx, photo_idx in track(
# image_clips.insert( enumerate(
# 0, indexes_for_videos,
# ImageClip("assets/png/title.png") start=index_offset,
# .set_duration(audio_clips[0].duration + audio_clips[1].duration) ),
# .set_position("center") description="Gathering audio clips...",
# .resize(width=W - 100) total=indexes_for_videos.__len__()
# .set_opacity(float(opacity)), ):
# ) image_clips.append(
# else: story mode stuff self.create_image_clip(
img_clip_pos = background_config[3] f"comment_{photo_idx}",
image_concat = concatenate_videoclips(image_clips).set_position( audio_clips[idx].start,
img_clip_pos audio_clips[idx].duration,
) # note transition kwarg for delay in imgs background_config[3],
image_concat.audio = audio_composite )
final = CompositeVideoClip([background_clip, image_concat]) )
title = re.sub(r"[^\w\s-]", "", reddit_obj["thread_title"])
idx = re.sub(r"[^\w\s-]", "", reddit_obj["thread_id"]) # if os.path.exists("assets/mp3/posttext.mp3"):
# image_clips.insert(
filename = f"{name_normalize(title)}.mp4" # 0,
subreddit = settings.config["reddit"]["thread"]["subreddit"] # ImageClip("assets/png/title.png")
# .set_duration(audio_clips[0].duration + audio_clips[1].duration)
if not exists(f"./results/{subreddit}"): # .set_position("center")
print_substep("The results folder didn't exist so I made it") # .resize(width=W - 100)
os.makedirs(f"./results/{subreddit}") # .set_opacity(float(opacity)),
# )
# if settings.config["settings"]['background']["background_audio"] and exists(f"assets/backgrounds/background.mp3"): # else: story mode stuff
# audioclip = mpe.AudioFileClip(f"assets/backgrounds/background.mp3").set_duration(final.duration)
# audioclip = audioclip.fx( volumex, 0.2) # Can't use concatenate_videoclips here, it resets clips' start point
# final_audio = mpe.CompositeAudioClip([final.audio, audioclip])
# # lowered_audio = audio_background.multiply_volume( # todo get this to work download_background(background_config)
# # VOLUME_MULTIPLIER) # lower volume by background_audio_volume, use with fx chop_background_video(background_config, self.video_duration)
# final.set_audio(final_audio) background_clip = (
final = Video(final).add_watermark( VideoFileClip("assets/temp/background.mp4")
text=f"Background credit: {background_config[2]}", opacity=0.4 .set_start(0)
) .set_end(self.video_duration)
final.write_videofile( .without_audio()
"assets/temp/temp.mp4", .resize(height=self.H)
fps=30, )
audio_codec="aac",
audio_bitrate="192k", back_video_width, back_video_height = background_clip.size
verbose=False,
threads=multiprocessing.cpu_count(), # Fix for crop with vertical videos
) if back_video_width < self.H:
ffmpeg_extract_subclip( background_clip = (
"assets/temp/temp.mp4", background_clip
0, .resize(width=self.W)
length, )
targetname=f"results/{subreddit}/{filename}", back_video_width, back_video_height = background_clip.size
) background_clip = background_clip.crop(
save_data(subreddit, filename, title, idx, background_config[2]) x1=0,
print_step("Removing temporary files 🗑") x2=back_video_width,
cleanups = cleanup() y1=back_video_height / 2 - self.H / 2,
print_substep(f"Removed {cleanups} temporary files 🗑") y2=back_video_height / 2 + self.H / 2
print_substep("See result in the results folder!") )
else:
print_step( background_clip = background_clip.crop(
f'Reddit title: {reddit_obj["thread_title"]} \n Background Credit: {background_config[2]}' x1=back_video_width / 2 - self.W / 2,
) x2=back_video_width / 2 + self.W / 2,
y1=0,
y2=back_video_height
)
final = CompositeVideoClip([background_clip, *image_clips])
final.audio = audio_composite
title = re.sub(r"[^\w\s-]", "", reddit_obj["thread_title"])
idx = re.sub(r"[^\w\s-]", "", reddit_obj["thread_id"])
filename = f"{self.name_normalize(title)}.mp4"
subreddit = str(settings.config["reddit"]["thread"]["subreddit"])
if not exists(f"./results/{subreddit}"):
print_substep("The results folder didn't exist so I made it")
os.makedirs(f"./results/{subreddit}")
# if (
# settings.config["settings"]['background']["background_audio"] and
# exists(f"assets/backgrounds/background.mp3")
# ):
# audioclip = (
# AudioFileClip(f"assets/backgrounds/background.mp3")
# .set_duration(final.duration)
# .volumex(0.2)
# )
# final_audio = CompositeAudioClip([final.audio, audioclip])
# # lowered_audio = audio_background.multiply_volume( # TODO get this to work
# # VOLUME_MULTIPLIER) # lower volume by background_audio_volume, use with fx
# final.set_audio(final_audio)
final = Video(final).add_watermark(
text=f"Background credit: {background_config[2]}", opacity=0.4
)
final.write_videofile(
"assets/temp/temp.mp4",
fps=30,
audio_codec="aac",
audio_bitrate="192k",
verbose=False,
threads=multiprocessing.cpu_count(),
)
# Moves file in subreddit folder and renames it
os.rename(
"assets/temp/temp.mp4",
f"results/{subreddit}/{filename}",
)
save_data(subreddit, filename, title, idx, background_config[2])
print_step("Removing temporary files 🗑")
cleanups = cleanup()
print_substep(f"Removed {cleanups} temporary files 🗑")
print_substep("See result in the results folder!")
print_step(
f'Reddit title: {reddit_obj["thread_title"]} \n Background Credit: {background_config[2]}'
)

@ -1,114 +0,0 @@
import json
from pathlib import Path
from typing import Dict
from utils import settings
from playwright.async_api import async_playwright # pylint: disable=unused-import
# do not remove the above line
from playwright.sync_api import sync_playwright, ViewportSize
from rich.progress import track
import translators as ts
from utils.console import print_step, print_substep
storymode = False
def download_screenshots_of_reddit_posts(reddit_object: dict, screenshot_num: int):
"""Downloads screenshots of reddit posts as seen on the web. Downloads to assets/temp/png
Args:
reddit_object (Dict): Reddit object received from reddit/subreddit.py
screenshot_num (int): Number of screenshots to download
"""
print_step("Downloading screenshots of reddit posts...")
# ! Make sure the reddit screenshots folder exists
Path("assets/temp/png").mkdir(parents=True, exist_ok=True)
with sync_playwright() as p:
print_substep("Launching Headless Browser...")
browser = p.chromium.launch()
context = browser.new_context()
if settings.config["settings"]["theme"] == "dark":
cookie_file = open("./video_creation/data/cookie-dark-mode.json", encoding="utf-8")
else:
cookie_file = open("./video_creation/data/cookie-light-mode.json", encoding="utf-8")
cookies = json.load(cookie_file)
context.add_cookies(cookies) # load preference cookies
# Get the thread screenshot
page = context.new_page()
page.goto(reddit_object["thread_url"], timeout=0)
page.set_viewport_size(ViewportSize(width=1920, height=1080))
if page.locator('[data-testid="content-gate"]').is_visible():
# This means the post is NSFW and requires to click the proceed button.
print_substep("Post is NSFW. You are spicy...")
page.locator('[data-testid="content-gate"] button').click()
page.wait_for_load_state() # Wait for page to fully load
if page.locator('[data-click-id="text"] button').is_visible():
page.locator(
'[data-click-id="text"] button'
).click() # Remove "Click to see nsfw" Button in Screenshot
# translate code
if settings.config["reddit"]["thread"]["post_lang"]:
print_substep("Translating post...")
texts_in_tl = ts.google(
reddit_object["thread_title"],
to_language=settings.config["reddit"]["thread"]["post_lang"],
)
page.evaluate(
"tl_content => document.querySelector('[data-test-id=\"post-content\"] > div:nth-child(3) > div > div').textContent = tl_content",
texts_in_tl,
)
else:
print_substep("Skipping translation...")
page.locator('[data-test-id="post-content"]').screenshot(path="assets/temp/png/title.png")
if storymode:
page.locator('[data-click-id="text"]').screenshot(
path="assets/temp/png/story_content.png"
)
else:
for idx, comment in enumerate(
track(reddit_object["comments"], "Downloading screenshots...")
):
# Stop if we have reached the screenshot_num
if idx >= screenshot_num:
break
if page.locator('[data-testid="content-gate"]').is_visible():
page.locator('[data-testid="content-gate"] button').click()
page.goto(f'https://reddit.com{comment["comment_url"]}', timeout=0)
# translate code
if settings.config["reddit"]["thread"]["post_lang"]:
comment_tl = ts.google(
comment["comment_body"],
to_language=settings.config["reddit"]["thread"]["post_lang"],
)
page.evaluate(
'([tl_content, tl_id]) => document.querySelector(`#t1_${tl_id} > div:nth-child(2) > div > div[data-testid="comment"] > div`).textContent = tl_content',
[comment_tl, comment["comment_id"]],
)
try:
page.locator(f"#t1_{comment['comment_id']}").screenshot(
path=f"assets/temp/png/comment_{idx}.png"
)
except TimeoutError:
del reddit_object["comments"]
screenshot_num += 1
print("TimeoutError: Skipping screenshot...")
continue
print_substep("Screenshots downloaded Successfully.", style="bold green")

@ -1,9 +1,3 @@
#!/usr/bin/env python
from typing import Dict, Tuple
from rich.console import Console
from TTS.engine_wrapper import TTSEngine from TTS.engine_wrapper import TTSEngine
from TTS.GTTS import GTTS from TTS.GTTS import GTTS
from TTS.streamlabs_polly import StreamlabsPolly from TTS.streamlabs_polly import StreamlabsPolly
@ -13,8 +7,6 @@ from utils import settings
from utils.console import print_table, print_step from utils.console import print_table, print_step
console = Console()
TTSProviders = { TTSProviders = {
"GoogleTranslate": GTTS, "GoogleTranslate": GTTS,
"AWSPolly": AWSPolly, "AWSPolly": AWSPolly,
@ -23,29 +15,29 @@ TTSProviders = {
} }
def save_text_to_mp3(reddit_obj) -> Tuple[int, int]: def save_text_to_mp3(
reddit_obj: dict,
) -> list:
"""Saves text to MP3 files. """Saves text to MP3 files.
Args: Args:
reddit_obj (): Reddit object received from reddit API in reddit/subreddit.py reddit_obj (): Reddit object received from reddit API in reddit/subreddit.py
Returns: Returns:
tuple[int,int]: (total length of the audio, the number of comments audio was generated for) The number of comments audio was generated for
""" """
voice = settings.config["settings"]["tts"]["choice"] voice = settings.config["settings"]["tts"]["choice"]
if str(voice).casefold() in map(lambda _: _.casefold(), TTSProviders): if voice.casefold() not in map(lambda _: _.casefold(), TTSProviders):
text_to_mp3 = TTSEngine(get_case_insensitive_key_value(TTSProviders, voice), reddit_obj)
else:
while True: while True:
print_step("Please choose one of the following TTS providers: ") print_step("Please choose one of the following TTS providers: ")
print_table(TTSProviders) print_table(TTSProviders)
choice = input("\n") voice = input("\n")
if choice.casefold() in map(lambda _: _.casefold(), TTSProviders): if voice.casefold() in map(lambda _: _.casefold(), TTSProviders):
break break
print("Unknown Choice") print("Unknown Choice")
text_to_mp3 = TTSEngine(get_case_insensitive_key_value(TTSProviders, choice), reddit_obj) engine_instance = TTSEngine(get_case_insensitive_key_value(TTSProviders, voice), reddit_obj)
return text_to_mp3.run() return engine_instance.run()
def get_case_insensitive_key_value(input_dict, key): def get_case_insensitive_key_value(input_dict, key):

@ -0,0 +1,84 @@
from attr import attrs, attrib
from typing import TypeVar, Optional, Callable, Union
_function = TypeVar("_function", bound=Callable[..., object])
_exceptions = TypeVar("_exceptions", bound=Optional[Union[type, tuple, list]])
default_exception = None
@attrs
class ExceptionDecorator:
"""
Decorator for catching exceptions and writing logs
"""
exception: Optional[_exceptions] = attrib(default=None)
def __attrs_post_init__(self):
if not self.exception:
self.exception = default_exception
def __call__(
self,
func: _function,
):
async def wrapper(*args, **kwargs):
try:
obj_to_return = await func(*args, **kwargs)
return obj_to_return
except Exception as caughtException:
import logging
logger = logging.getLogger("webdriver_log")
logger.setLevel(logging.ERROR)
handler = logging.FileHandler(".webdriver.log", mode="a+", encoding="utf-8")
logger.addHandler(handler)
if isinstance(self.exception, type):
if not type(caughtException) == self.exception:
logger.error(f"unexpected error - {caughtException}")
else:
if not type(caughtException) in self.exception:
logger.error(f"unexpected error - {caughtException}")
return wrapper
def catch_exception(
func: Optional[_function],
exception: Optional[_exceptions] = None,
) -> Union[object, _function]:
"""
Decorator for catching exceptions and writing logs
Args:
func: Function to be decorated
exception: Expected exception(s)
Returns:
Decorated function
"""
exceptor = ExceptionDecorator(exception)
if func:
exceptor = exceptor(func)
return exceptor
# Lots of tabs - lots of memory
# chunk needed to minimize memory required
def chunks(
array: list,
size: int,
):
"""
Yield successive n-sized chunks from list.
Args:
array: List to be chunked
size: size of a chunk
Returns:
Generator with chunked list
"""
for i in range(0, len(array), size):
yield array[i:i + size]

@ -0,0 +1,332 @@
from asyncio import as_completed
from pathlib import Path
from typing import Dict, Optional
import translators as ts
from attr import attrs, attrib
from attr.validators import instance_of
from playwright.async_api import Browser, Playwright, Page, BrowserContext, ElementHandle
from playwright.async_api import async_playwright, TimeoutError
from rich.progress import track
from utils import settings
from utils.console import print_step, print_substep
import webdriver.common as common
common.default_exception = TimeoutError
@attrs
class Browser:
"""
Args:
default_Viewport (dict):Pyppeteer Browser default_Viewport options
browser (BrowserCls): Pyppeteer Browser instance
"""
default_Viewport: dict = attrib(
validator=instance_of(dict),
default={
# 9x21 to see long posts
"width": 500,
"height": 1200,
},
kw_only=True,
)
playwright: Playwright
browser: Browser
context: BrowserContext
async def get_browser(
self,
) -> None:
"""
Creates Playwright instance & browser
"""
self.playwright = await async_playwright().start()
self.browser = await self.playwright.chromium.launch()
self.context = await self.browser.new_context(viewport=self.default_Viewport)
async def close_browser(
self,
) -> None:
"""
Closes Playwright stuff
"""
await self.context.close()
await self.browser.close()
await self.playwright.stop()
class Flaky:
"""
All methods decorated with function catching default exceptions and writing logs
"""
@staticmethod
@common.catch_exception
async def find_element(
selector: str,
page_instance: Page,
options: Optional[dict] = None,
) -> ElementHandle:
return (
await page_instance.wait_for_selector(selector, **options)
if options
else await page_instance.wait_for_selector(selector)
)
@common.catch_exception
async def click(
self,
page_instance: Optional[Page] = None,
query: Optional[str] = None,
options: Optional[dict] = None,
*,
find_options: Optional[dict] = None,
element: Optional[ElementHandle] = None,
) -> None:
if element:
await element.click(**options) if options else await element.click()
else:
results = (
await self.find_element(query, page_instance, **find_options)
if find_options
else await self.find_element(query, page_instance)
)
await results.click(**options) if options else await results.click()
@common.catch_exception
async def screenshot(
self,
page_instance: Optional[Page] = None,
query: Optional[str] = None,
options: Optional[dict] = None,
*,
find_options: Optional[dict] = None,
element: Optional[ElementHandle] = None,
) -> None:
if element:
await element.screenshot(**options) if options else await element.screenshot()
else:
results = (
await self.find_element(query, page_instance, **find_options)
if find_options
else await self.find_element(query, page_instance)
)
await results.screenshot(**options) if options else await results.screenshot()
@attrs(auto_attribs=True)
class RedditScreenshot(Flaky, Browser):
"""
Args:
reddit_object (Dict): Reddit object received from reddit/subreddit.py
screenshot_idx (int): List with indexes of voiced comments
story_mode (bool): If submission is a story takes screenshot of the story
"""
reddit_object: dict
screenshot_idx: list
story_mode: Optional[bool] = attrib(
validator=instance_of(bool),
default=False,
kw_only=True
)
def __attrs_post_init__(
self
):
self.post_lang: Optional[bool] = settings.config["reddit"]["thread"]["post_lang"]
async def __dark_theme( # TODO isn't working
self,
page_instance: Page,
) -> None:
"""
Enables dark theme in Reddit
Args:
page_instance: Pyppeteer page instance with reddit page opened
"""
await self.click(
page_instance,
".header-user-dropdown",
)
# It's normal not to find it, sometimes there is none :shrug:
await self.click(
page_instance,
"button >> span:has-text('Settings')",
)
await self.click(
page_instance,
"button >> span:has-text('Dark Mode')",
)
# Closes settings
await self.click(
page_instance,
".header-user-dropdown"
)
async def __close_nsfw(
self,
page_instance: Page,
) -> None:
"""
Closes NSFW stuff
Args:
page_instance: Instance of main page
"""
print_substep("Post is NSFW. You are spicy...")
# Triggers indirectly reload
await self.click(
page_instance,
"button:has-text('Yes')",
{"timeout": 5000},
)
# Await indirect reload
await page_instance.wait_for_load_state()
await self.click(
page_instance,
"button:has-text('Click to see nsfw')",
{"timeout": 5000},
)
async def __collect_comment(
self,
comment_obj: dict,
filename_idx: int,
) -> None:
"""
Makes a screenshot of the comment
Args:
comment_obj: prew comment object
filename_idx: index for the filename
"""
comment_page = await self.context.new_page()
await comment_page.goto(f'https://reddit.com{comment_obj["comment_url"]}')
# Translates submission' comment
if self.post_lang:
comment_tl = ts.google(
comment_obj["comment_body"],
to_language=self.post_lang,
)
await comment_page.evaluate(
'([comment_id, comment_tl]) => document.querySelector(`#t1_${comment_id} > div:nth-child(2) > div > div[data-testid="comment"] > div`).textContent = comment_tl', # noqa
[comment_obj["comment_id"], comment_tl],
)
await self.screenshot(
comment_page,
f"id=t1_{comment_obj['comment_id']}",
{"path": f"assets/temp/png/comment_{filename_idx}.png"},
)
# WIP TODO test it
async def __collect_story(
self,
main_page: Page,
):
# Translates submission text
if self.post_lang:
story_tl = ts.google(
self.reddit_object["thread_post"],
to_language=self.post_lang,
)
split_story_tl = story_tl.split('\n')
await main_page.evaluate(
"(split_story_tl) => split_story_tl.map(function(element, i) { return [element, document.querySelectorAll('[data-test-id=\"post-content\"] > [data-click-id=\"text\"] > div > p')[i]]; }).forEach(mappedElement => mappedElement[1].textContent = mappedElement[0])", # noqa
split_story_tl,
)
await self.screenshot(
main_page,
"//div[@data-test-id='post-content']//div[@data-click-id='text']",
{"path": "assets/temp/png/story_content.png"},
)
async def download(
self,
):
"""
Downloads screenshots of reddit posts as seen on the web. Downloads to assets/temp/png
"""
print_step("Downloading screenshots of reddit posts...")
print_substep("Launching Headless Browser...")
await self.get_browser()
# ! Make sure the reddit screenshots folder exists
Path("assets/temp/png").mkdir(parents=True, exist_ok=True)
# Get the thread screenshot
reddit_main = await self.context.new_page()
await reddit_main.goto(self.reddit_object["thread_url"]) # noqa
if settings.config["settings"]["theme"] == "dark":
await self.__dark_theme(reddit_main)
if self.reddit_object["is_nsfw"]:
# This means the post is NSFW and requires to click the proceed button.
await self.__close_nsfw(reddit_main)
# Translates submission title
if self.post_lang:
print_substep("Translating post...")
texts_in_tl = ts.google(
self.reddit_object["thread_title"],
to_language=self.post_lang,
)
await reddit_main.evaluate(
f"(texts_in_tl) => document.querySelector('[data-test-id=\"post-content\"] > div:nth-child(3) > div > div').textContent = texts_in_tl", # noqa
texts_in_tl,
)
else:
print_substep("Skipping translation...")
# No sense to move it to common.py
async_tasks_primary = ( # noqa
[
self.__collect_comment(self.reddit_object["comments"][idx], idx) for idx in
self.screenshot_idx
]
if not self.story_mode
else [
self.__collect_story(reddit_main)
]
)
async_tasks_primary.append(
self.screenshot(
reddit_main,
f"id=t3_{self.reddit_object['thread_id']}",
{"path": "assets/temp/png/title.png"},
)
)
for idx, chunked_tasks in enumerate(
[chunk for chunk in common.chunks(async_tasks_primary, 10)],
start=1,
):
chunk_list = async_tasks_primary.__len__() // 10 + (1 if async_tasks_primary.__len__() % 10 != 0 else 0)
for task in track(
as_completed(chunked_tasks),
description=f"Downloading comments: Chunk {idx}/{chunk_list}",
total=chunked_tasks.__len__(),
):
await task
print_substep("Comments downloaded Successfully.", style="bold green")
await self.close_browser()

@ -0,0 +1,371 @@
from asyncio import as_completed
from pyppeteer import launch
from pyppeteer.page import Page as PageCls
from pyppeteer.browser import Browser as BrowserCls
from pyppeteer.element_handle import ElementHandle as ElementHandleCls
from pyppeteer.errors import TimeoutError as BrowserTimeoutError
from pathlib import Path
from utils import settings
from utils.console import print_step, print_substep
from rich.progress import track
import translators as ts
from attr import attrs, attrib
from attr.validators import instance_of
from typing import Optional
import webdriver.common as common
common.default_exception = BrowserTimeoutError
@attrs
class Browser:
"""
Args:
default_Viewport (dict):Pyppeteer Browser default_Viewport options
browser (BrowserCls): Pyppeteer Browser instance
"""
default_Viewport: dict = attrib(
validator=instance_of(dict),
default={
# 9x21 to see long posts
"defaultViewport": {
"width": 500,
"height": 1200,
},
},
kw_only=True,
)
browser: BrowserCls
async def get_browser(
self,
) -> None:
"""
Creates Pyppeteer browser
"""
self.browser = await launch(self.default_Viewport)
async def close_browser(
self,
) -> None:
"""
Closes Pyppeteer browser
"""
await self.browser.close()
class Wait:
@staticmethod
@common.catch_exception
async def find_xpath(
page_instance: PageCls,
xpath: Optional[str] = None,
options: Optional[dict] = None,
) -> 'ElementHandleCls':
"""
Explicitly finds element on the page
Args:
page_instance: Pyppeteer page instance
xpath: xpath query
options: Pyppeteer waitForXPath parameters
Available options are:
* ``visible`` (bool): wait for element to be present in DOM and to be
visible, i.e. to not have ``display: none`` or ``visibility: hidden``
CSS properties. Defaults to ``False``.
* ``hidden`` (bool): wait for element to not be found in the DOM or to
be hidden, i.e. have ``display: none`` or ``visibility: hidden`` CSS
properties. Defaults to ``False``.
* ``timeout`` (int|float): maximum time to wait for in milliseconds.
Defaults to 30000 (30 seconds). Pass ``0`` to disable timeout.
Returns:
Pyppeteer element instance
"""
if options:
el = await page_instance.waitForXPath(xpath, options=options)
else:
el = await page_instance.waitForXPath(xpath)
return el
@common.catch_exception
async def click(
self,
page_instance: Optional[PageCls] = None,
xpath: Optional[str] = None,
options: Optional[dict] = None,
*,
find_options: Optional[dict] = None,
el: Optional[ElementHandleCls] = None,
) -> None:
"""
Clicks on the element
Args:
page_instance: Pyppeteer page instance
xpath: xpath query
find_options: Pyppeteer waitForXPath parameters
options: Pyppeteer click parameters
el: Pyppeteer element instance
"""
if not el:
el = await self.find_xpath(page_instance, xpath, find_options)
if options:
await el.click(options)
else:
await el.click()
@common.catch_exception
async def screenshot(
self,
page_instance: Optional[PageCls] = None,
xpath: Optional[str] = None,
options: Optional[dict] = None,
*,
find_options: Optional[dict] = None,
el: Optional[ElementHandleCls] = None,
) -> None:
"""
Makes a screenshot of the element
Args:
page_instance: Pyppeteer page instance
xpath: xpath query
options: Pyppeteer screenshot parameters
find_options: Pyppeteer waitForXPath parameters
el: Pyppeteer element instance
"""
if not el:
el = await self.find_xpath(page_instance, xpath, find_options)
if options:
await el.screenshot(options)
else:
await el.screenshot()
@attrs(auto_attribs=True)
class RedditScreenshot(Browser, Wait):
"""
Args:
reddit_object (Dict): Reddit object received from reddit/subreddit.py
screenshot_idx (int): List with indexes of voiced comments
story_mode (bool): If submission is a story takes screenshot of the story
"""
reddit_object: dict
screenshot_idx: list
story_mode: Optional[bool] = attrib(
validator=instance_of(bool),
default=False,
kw_only=True
)
def __attrs_post_init__(
self,
):
self.post_lang: Optional[bool] = settings.config["reddit"]["thread"]["post_lang"]
async def __dark_theme(
self,
page_instance: PageCls,
) -> None:
"""
Enables dark theme in Reddit
Args:
page_instance: Pyppeteer page instance with reddit page opened
"""
await self.click(
page_instance,
"//div[@class='header-user-dropdown']",
find_options={"timeout": 5000},
)
# It's normal not to find it, sometimes there is none :shrug:
await self.click(
page_instance,
"//span[text()='Settings']/ancestor::button[1]",
find_options={"timeout": 5000},
)
await self.click(
page_instance,
"//span[text()='Dark Mode']/ancestor::button[1]",
find_options={"timeout": 5000},
)
# Closes settings
await self.click(
page_instance,
"//div[@class='header-user-dropdown']",
find_options={"timeout": 5000},
)
async def __close_nsfw(
self,
page_instance: PageCls,
) -> None:
"""
Closes NSFW stuff
Args:
page_instance: Instance of main page
"""
from asyncio import ensure_future
print_substep("Post is NSFW. You are spicy...")
# To await indirectly reload
navigation = ensure_future(page_instance.waitForNavigation())
# Triggers indirectly reload
await self.click(
page_instance,
'//button[text()="Yes"]',
find_options={"timeout": 5000},
)
# Await reload
await navigation
await self.click(
page_instance,
'//button[text()="Click to see nsfw"]',
find_options={"timeout": 5000},
)
async def __collect_comment(
self,
comment_obj: dict,
filename_idx: int,
) -> None:
"""
Makes a screenshot of the comment
Args:
comment_obj: prew comment object
filename_idx: index for the filename
"""
comment_page = await self.browser.newPage()
await comment_page.goto(f'https://reddit.com{comment_obj["comment_url"]}')
# Translates submission' comment
if self.post_lang:
comment_tl = ts.google(
comment_obj["comment_body"],
to_language=self.post_lang,
)
await comment_page.evaluate(
'([comment_id, comment_tl]) => document.querySelector(`#t1_${comment_id} > div:nth-child(2) > div > div[data-testid="comment"] > div`).textContent = comment_tl', # noqa
[comment_obj["comment_id"], comment_tl],
)
await self.screenshot(
comment_page,
f"//div[@id='t1_{comment_obj['comment_id']}']",
{"path": f"assets/temp/png/comment_{filename_idx}.png"},
)
# WIP TODO test it
async def __collect_story(
self,
main_page: PageCls,
):
# Translates submission text
if self.post_lang:
story_tl = ts.google(
self.reddit_object["thread_post"],
to_language=self.post_lang,
)
split_story_tl = story_tl.split('\n')
await main_page.evaluate(
"(split_story_tl) => split_story_tl.map(function(element, i) { return [element, document.querySelectorAll('[data-test-id=\"post-content\"] > [data-click-id=\"text\"] > div > p')[i]]; }).forEach(mappedElement => mappedElement[1].textContent = mappedElement[0])", # noqa
split_story_tl,
)
await self.screenshot(
main_page,
"//div[@data-test-id='post-content']//div[@data-click-id='text']",
{"path": "assets/temp/png/story_content.png"},
)
async def download(
self,
):
"""
Downloads screenshots of reddit posts as seen on the web. Downloads to assets/temp/png
"""
print_step("Downloading screenshots of reddit posts...")
print_substep("Launching Headless Browser...")
await self.get_browser()
# ! Make sure the reddit screenshots folder exists
Path("assets/temp/png").mkdir(parents=True, exist_ok=True)
# Get the thread screenshot
reddit_main = await self.browser.newPage()
await reddit_main.goto(self.reddit_object["thread_url"]) # noqa
if settings.config["settings"]["theme"] == "dark":
await self.__dark_theme(reddit_main)
if self.reddit_object["is_nsfw"]:
# This means the post is NSFW and requires to click the proceed button.
await self.__close_nsfw(reddit_main)
# Translates submission title
if self.post_lang:
print_substep("Translating post...")
texts_in_tl = ts.google(
self.reddit_object["thread_title"],
to_language=self.post_lang,
)
await reddit_main.evaluate(
f"(texts_in_tl) => document.querySelector('[data-test-id=\"post-content\"] > div:nth-child(3) > div > div').textContent = texts_in_tl", # noqa
texts_in_tl,
)
else:
print_substep("Skipping translation...")
# No sense to move it to common.py
async_tasks_primary = ( # noqa
[
self.__collect_comment(self.reddit_object["comments"][idx], idx) for idx in
self.screenshot_idx
]
if not self.story_mode
else [
self.__collect_story(reddit_main)
]
)
async_tasks_primary.append(
self.screenshot(
reddit_main,
f"//div[@data-testid='post-container']",
{"path": "assets/temp/png/title.png"},
)
)
for idx, chunked_tasks in enumerate(
[chunk for chunk in common.chunks(async_tasks_primary, 10)],
start=1,
):
chunk_list = async_tasks_primary.__len__() // 10 + (1 if async_tasks_primary.__len__() % 10 != 0 else 0)
for task in track(
as_completed(chunked_tasks),
description=f"Downloading comments: Chunk {idx}/{chunk_list}",
total=chunked_tasks.__len__(),
):
await task
print_substep("Comments downloaded Successfully.", style="bold green")
await self.close_browser()

@ -0,0 +1,23 @@
from typing import Union
from webdriver.pyppeteer import RedditScreenshot as Pyppeteer
from webdriver.playwright import RedditScreenshot as Playwright
def screenshot_factory(
driver: str,
) -> Union[type(Pyppeteer), type(Playwright)]:
"""
Factory for webdriver
Args:
driver: (str) Name of a driver
Returns:
Webdriver instance
"""
web_drivers = {
"pyppeteer": Pyppeteer,
"playwright": Playwright,
}
return web_drivers[driver]
Loading…
Cancel
Save