diff --git a/tts/engine_wrapper.py b/tts/engine_wrapper.py index 046713d..0465d72 100644 --- a/tts/engine_wrapper.py +++ b/tts/engine_wrapper.py @@ -6,6 +6,7 @@ from rich.progress import track from moviepy.editor import AudioFileClip, CompositeAudioClip, concatenate_audioclips from utils.console import print_step, print_substep + class TTSEngine: """Calls the given TTS engine to reduce code duplication and allow multiple TTS engines. @@ -20,7 +21,13 @@ class TTSEngine: tts_module must take the arguments text and filepath. """ - def __init__(self, tts_module, reddit_object: dict, path: str = "assets/mp3", max_length: int = 50): + def __init__( + self, + tts_module, + reddit_object: dict, + path: str = "assets/mp3", + max_length: int = 50, + ): self.tts_module = tts_module self.reddit_object = reddit_object self.path = path @@ -44,26 +51,38 @@ class TTSEngine: if self.reddit_object["thread_post"] != "": self.call_tts("posttext", self.reddit_object["thread_post"]) - for idx, comment in track(enumerate(self.reddit_object["comments"]), "Saving..."): + idx = None + for idx, comment in track( + enumerate(self.reddit_object["comments"]), "Saving..." + ): # ! Stop creating mp3 files if the length is greater than max length. if self.length > self.max_length: break if not self.tts_module.max_chars: - self.call_tts(f"{idx}",comment["comment_body"]) + self.call_tts(f"{idx}", comment["comment_body"]) else: self.split_post(comment["comment_body"], idx) print_substep("Saved Text to MP3 files successfully.", style="bold green") return self.length, idx - def split_post(self, text: str, idx:int) -> str: + def split_post(self, text: str, idx: int) -> str: split_files = [] - split_text = [x.group().strip() for x in re.finditer(fr' *((.{{0,{self.tts_module.max_chars}}})(\.|.$))', text)] + split_text = [ + x.group().strip() + for x in re.finditer( + rf" *((.{{0,{self.tts_module.max_chars}}})(\.|.$))", text + ) + ] + + idy = None for idy, text_cut in enumerate(split_text): print(f"{idx}-{idy}: {text_cut}\n") self.call_tts(f"{idx}-{idy}.part", text_cut) split_files.append(AudioFileClip(f"{self.path}/{idx}-{idy}.part.mp3")) - CompositeAudioClip([concatenate_audioclips(split_files)]).write_audiofile(f"{self.path}/{idx}.mp3", fps=44100, verbose=False, logger=None) + CompositeAudioClip([concatenate_audioclips(split_files)]).write_audiofile( + f"{self.path}/{idx}.mp3", fps=44100, verbose=False, logger=None + ) for i in range(0, idy + 1): print(f"Cleaning up {self.path}/{idx}-{i}.part.mp3") diff --git a/tts/example_tts.py b/tts/example_tts.py index 8fd3b0e..3b8e785 100644 --- a/tts/example_tts.py +++ b/tts/example_tts.py @@ -11,4 +11,4 @@ # ... # # any extra functions you need -################################ \ No newline at end of file +################################ diff --git a/tts/google_translate_tts.py b/tts/google_translate_tts.py index 58c8719..a28cbf4 100644 --- a/tts/google_translate_tts.py +++ b/tts/google_translate_tts.py @@ -2,6 +2,7 @@ from gtts import gTTS max_chars = 0 + def run(text, filepath): tts = gTTS(text=text, lang="en", slow=False) tts.save(filepath) diff --git a/video_creation/voices.py b/video_creation/voices.py index 667b2bd..4fa4412 100644 --- a/video_creation/voices.py +++ b/video_creation/voices.py @@ -5,9 +5,8 @@ from tts.engine_wrapper import TTSEngine import tts.google_translate_tts ## Add your provider here on a new line -TTSProviders ={ - "GoogleTranslate": tts.google_translate_tts -} +TTSProviders = {"GoogleTranslate": tts.google_translate_tts} + def save_text_to_mp3(reddit_obj): """Saves Text to MP3 files. @@ -15,8 +14,9 @@ def save_text_to_mp3(reddit_obj): Args: reddit_obj : The reddit object you received from the reddit API in the askreddit.py file. """ - env = os.getenv("TTS_PROVIDER","") - if env in TTSProviders: text_to_mp3 = TTSEngine(env, reddit_obj) + env = os.getenv("TTS_PROVIDER", "") + if env in TTSProviders: + text_to_mp3 = TTSEngine(env, reddit_obj) else: chosen = False choice = "" @@ -29,10 +29,19 @@ def save_text_to_mp3(reddit_obj): print("Unknown Choice") else: chosen = True - text_to_mp3 = TTSEngine(get_case_insensitive_key_value(TTSProviders, choice), reddit_obj) + text_to_mp3 = TTSEngine( + get_case_insensitive_key_value(TTSProviders, choice), reddit_obj + ) return text_to_mp3.run() def get_case_insensitive_key_value(input_dict, key): - return next((value for dict_key, value in input_dict.items() if dict_key.lower() == key.lower()), None) \ No newline at end of file + return next( + ( + value + for dict_key, value in input_dict.items() + if dict_key.lower() == key.lower() + ), + None, + )