From acf679bfb6838733dc25a26204454aad4abe4fd6 Mon Sep 17 00:00:00 2001 From: Drugsosos <44712637+Drugsosos@users.noreply.github.com> Date: Fri, 15 Jul 2022 00:01:37 +0300 Subject: [PATCH] self review: improved typing & logging, removed unused imports, fixes in README --- README.md | 1 - TTS/GTTS.py | 8 ++++ TTS/TikTok.py | 27 +++++++----- TTS/aws_polly.py | 9 +++- TTS/common.py | 57 ++++++++++++++++++++++++- TTS/engine_wrapper.py | 26 +++++++++++ TTS/streamlabs_polly.py | 9 ++++ main.py | 1 - utils/settings.py | 2 +- utils/voice.py | 2 +- video_creation/data/videos.json | 2 +- video_creation/final_video.py | 4 +- video_creation/screenshot_downloader.py | 12 +++--- video_creation/voices.py | 3 +- 14 files changed, 136 insertions(+), 27 deletions(-) diff --git a/README.md b/README.md index cb82738..8aaf3d1 100644 --- a/README.md +++ b/README.md @@ -33,7 +33,6 @@ The only original thing being done is the editing and gathering of all materials ## Requirements - Python 3.9+ -- Playwright (this should install automatically in installation) ## Installation 👩‍💻 diff --git a/TTS/GTTS.py b/TTS/GTTS.py index c8d6ae8..8974ddc 100644 --- a/TTS/GTTS.py +++ b/TTS/GTTS.py @@ -5,12 +5,20 @@ from gtts import gTTS class GTTS: max_chars = 0 + # voices = [] @staticmethod async def run( text, filepath ) -> None: + """ + Calls for TTS api + + Args: + text: text to be voiced over + filepath: name of the audio file + """ tts = gTTS( text=text, lang=settings.config["reddit"]["thread"]["post_lang"] or "en", diff --git a/TTS/TikTok.py b/TTS/TikTok.py index 6a23bb8..83521b3 100644 --- a/TTS/TikTok.py +++ b/TTS/TikTok.py @@ -1,10 +1,8 @@ -import base64 from utils import settings import requests from requests.adapters import HTTPAdapter, Retry from attr import attrs, attrib -from attr.validators import instance_of from TTS.common import BaseApiTTS, get_random_voice @@ -74,8 +72,20 @@ class TikTok(BaseApiTTS): # TikTok Text-to-Speech Wrapper max_chars = 300 decode_base64 = True - def __attrs_post_init__(self): - self.voice = ( + def make_request( + self, + text: str, + ): + """ + Makes a requests to remote TTS service + + Args: + text: text to be voice over + + Returns: + Request's response + """ + voice = ( get_random_voice(voices, 'human') if self.random_voice else str(settings.config['settings']['tts']['tiktok_voice']).lower() @@ -83,16 +93,11 @@ class TikTok(BaseApiTTS): # TikTok Text-to-Speech Wrapper voice.lower() for dict_title in voices for voice in voices[dict_title]] else get_random_voice(voices, 'human') ) - - def make_request( - self, - text: str, - ): try: r = requests.post( self.uri_base, params={ - 'text_speaker': self.voice, + 'text_speaker': voice, 'req_text': text, 'speaker_map_type': 0, }) @@ -103,6 +108,6 @@ class TikTok(BaseApiTTS): # TikTok Text-to-Speech Wrapper adapter = HTTPAdapter(max_retries=retry) session.mount('http://', adapter) session.mount('https://', adapter) - r = session.post(f'{self.uri_base}{self.voice}&req_text={text}&speaker_map_type=0') + r = session.post(f'{self.uri_base}{voice}&req_text={text}&speaker_map_type=0') # print(r.text) return r.json()['data']['v_str'] diff --git a/TTS/aws_polly.py b/TTS/aws_polly.py index 9d52f6f..f8c28cd 100644 --- a/TTS/aws_polly.py +++ b/TTS/aws_polly.py @@ -37,7 +37,14 @@ class AWSPolly: self, text, filepath, - ): + ) -> None: + """ + Calls for TTS api + + Args: + text: text to be voiced over + filepath: name of the audio file + """ try: session = Session(profile_name='polly') polly = session.client('polly') diff --git a/TTS/common.py b/TTS/common.py index 73884f4..d4d0200 100644 --- a/TTS/common.py +++ b/TTS/common.py @@ -12,6 +12,16 @@ class BaseApiTTS: text: str, max_length: int, ) -> list: + """ + Splits text if it's too long to be a query + + Args: + text: text to be sanitized + max_length: maximum length of the query + + Returns: + Split text as a list + """ # Split by comma or dot (else you can lose intonations), if there is non, split by groups of 299 chars if '.' in text and all([split_text.__len__() < max_length for split_text in text.split('.')]): return text.split('.') @@ -26,6 +36,13 @@ class BaseApiTTS: output_text: str, filepath: str, ) -> None: + """ + Writes and decodes TTS responses in files + + Args: + output_text: text to be written + filepath: path/name of the file + """ decoded_text = base64.b64decode(output_text) if self.decode_base64 else output_text with open(filepath, 'wb') as out: @@ -36,6 +53,16 @@ class BaseApiTTS: text: str, filepath: str, ) -> None: + """ + Calls for TTS api and writes audio file + + Args: + text: text to be voice over + filepath: path/name of the file + + Returns: + + """ output_text = '' if len(text) > self.max_chars: for part in self.text_len_sanitize(text, self.max_chars): @@ -50,19 +77,45 @@ def get_random_voice( voices: Union[list, dict], key: Optional[str] = None, ) -> str: + """ + Return random voice from list or dict + + Args: + voices: list or dict of voices + key: key of a dict if you are using one + + Returns: + random voice as a str + """ if isinstance(voices, list): return choice(voices) else: - return choice(voices[key]) + return choice(voices[key] if key else list(voices.values())[0]) def audio_length( path: str, ) -> float | int: + """ + Gets the length of the audio file + + Args: + path: audio file path + + Returns: + length in seconds as an int + """ from mutagen.mp3 import MP3 try: audio = MP3(path) return audio.info.length - except Exception as e: # TODO add logging + except Exception as e: + import logging + + logger = logging.getLogger('spam_application') + logger.setLevel(logging.DEBUG) + handler = logging.FileHandler('tts_log', mode='a+', encoding='utf-8') + logger.addHandler(handler) + logger.error('Error occurred in audio_length:', e) return 0 diff --git a/TTS/engine_wrapper.py b/TTS/engine_wrapper.py index 64e439a..0733198 100644 --- a/TTS/engine_wrapper.py +++ b/TTS/engine_wrapper.py @@ -39,7 +39,9 @@ class TTSEngine: ) def __attrs_post_init__(self): + # Calls an instance of the tts_module class self.tts_module = self.tts_module() + # Loading settings from the config self.max_length: int = settings.config['settings']['video_length'] self.time_before_tts: float = settings.config['settings']['time_before_tts'] self.time_between_pictures: float = settings.config['settings']['time_between_pictures'] @@ -51,7 +53,12 @@ class TTSEngine: def run( self ) -> list: + """ + Voices over comments & title of the submission + Returns: + Indexes of comments to be used in the final video + """ Path(self.path).mkdir(parents=True, exist_ok=True) # This file needs to be removed in case this post does not use post text @@ -87,6 +94,16 @@ class TTSEngine: filename: str, text: str ) -> bool: + """ + Calls for TTS api from the factory + + Args: + filename: name of audio file w/o .mp3 + text: text to be voiced over + + Returns: + True if audio files not exceeding the maximum length else false + """ if not text: return False @@ -107,6 +124,15 @@ class TTSEngine: def process_text( text: str, ) -> str: + """ + Sanitizes text for illegal characters and translates text + + Args: + text: text to be sanitized & translated + + Returns: + Processed text as a str + """ lang = settings.config['reddit']['thread']['post_lang'] new_text = sanitize_text(text) if lang: diff --git a/TTS/streamlabs_polly.py b/TTS/streamlabs_polly.py index d2b765a..ca6102b 100644 --- a/TTS/streamlabs_polly.py +++ b/TTS/streamlabs_polly.py @@ -42,6 +42,15 @@ class StreamlabsPolly(BaseApiTTS): self, text, ): + """ + Makes a requests to remote TTS service + + Args: + text: text to be voice over + + Returns: + Request's response + """ voice = ( get_random_voice(voices) if self.random_voice diff --git a/main.py b/main.py index a72246c..6d6f04a 100755 --- a/main.py +++ b/main.py @@ -7,7 +7,6 @@ from utils.cleanup import cleanup from utils.console import print_markdown, print_step from utils import settings -# from utils.checker import envUpdate from video_creation.background import ( get_background_config, ) diff --git a/utils/settings.py b/utils/settings.py index 1c77eba..8acae2a 100755 --- a/utils/settings.py +++ b/utils/settings.py @@ -9,7 +9,7 @@ from utils.console import handle_input console = Console() -config = dict() # calling instance of a dict to calm lint down +config = dict() # calling instance of a dict to calm lint down (dict[any] will work as well) def crawl(obj: dict, func=lambda x, y: print(x, y, end="\n"), path=None): diff --git a/utils/voice.py b/utils/voice.py index 7d20b1b..3113227 100644 --- a/utils/voice.py +++ b/utils/voice.py @@ -11,7 +11,7 @@ if sys.version_info[0] >= 3: def check_ratelimit( - response: Response + response: Response, ): """ Checks if the response is a ratelimit response. diff --git a/video_creation/data/videos.json b/video_creation/data/videos.json index 0637a08..fe51488 100644 --- a/video_creation/data/videos.json +++ b/video_creation/data/videos.json @@ -1 +1 @@ -[] \ No newline at end of file +[] diff --git a/video_creation/final_video.py b/video_creation/final_video.py index 454b675..578eaab 100755 --- a/video_creation/final_video.py +++ b/video_creation/final_video.py @@ -13,7 +13,7 @@ from moviepy.editor import ( CompositeAudioClip, CompositeVideoClip, ) -from moviepy.video.io.ffmpeg_tools import ffmpeg_merge_video_audio, ffmpeg_extract_subclip +from moviepy.video.io.ffmpeg_tools import ffmpeg_extract_subclip from rich.console import Console from rich.progress import track @@ -35,7 +35,7 @@ def name_normalize( name = re.sub(r'(\d+)\s?\/\s?(\d+)', r'\1 of \2', name) name = re.sub(r'(\w+)\s?\/\s?(\w+)', r'\1 or \2', name) name = re.sub(r'\/', '', name) - name[:30] + name[:30] # the hell this little guy does? lang = settings.config['reddit']['thread']['post_lang'] if lang: diff --git a/video_creation/screenshot_downloader.py b/video_creation/screenshot_downloader.py index 60aaa4f..a779f6f 100644 --- a/video_creation/screenshot_downloader.py +++ b/video_creation/screenshot_downloader.py @@ -25,7 +25,7 @@ _exceptions = TypeVar('_exceptions', bound=Optional[Union[type, tuple, list]]) @attrs class ExceptionDecorator: """ - Factory for decorating functions + Decorator factory for catching exceptions and writing logs """ exception: Optional[_exceptions] = attrib(default=None) __default_exception: _exceptions = attrib(default=BrowserTimeoutError) @@ -45,15 +45,17 @@ class ExceptionDecorator: except Exception as caughtException: import logging - logging.basicConfig(filename='.webdriver.log', filemode='a+', - encoding='utf-8', level=logging.ERROR) + logger = logging.getLogger('webdriver_log') + logger.setLevel(logging.DEBUG) + handler = logging.FileHandler('.webdriver.log', mode='a+', encoding='utf-8') + logger.addHandler(handler) if isinstance(self.exception, type): if not type(caughtException) == self.exception: - logging.error(f'unexpected error - {caughtException}') + logger.error(f'unexpected error - {caughtException}') else: if not type(caughtException) in self.exception: - logging.error(f'unexpected error - {caughtException}') + logger.error(f'unexpected error - {caughtException}') return wrapper diff --git a/video_creation/voices.py b/video_creation/voices.py index 95f0b2b..7d78e5f 100644 --- a/video_creation/voices.py +++ b/video_creation/voices.py @@ -29,7 +29,7 @@ def save_text_to_mp3( """ voice = settings.config['settings']['tts']['choice'] - if str(voice).casefold() not in map(lambda _: _.casefold(), TTSProviders): + if voice.casefold() not in map(lambda _: _.casefold(), TTSProviders): while True: print_step('Please choose one of the following TTS providers: ') print_table(TTSProviders) @@ -45,6 +45,7 @@ def get_case_insensitive_key_value( input_dict, key, ) -> object: + # TODO add a factory later return next( (value for dict_key, value in input_dict.items() if dict_key.lower() == key.lower()), None,