# Patch summary (descriptive preamble; the diff itself is unchanged below).
#
# Files touched: tts/engine_wrapper.py, tts/example_tts.py,
# tts/google_translate_tts.py, video_creation/voices.py.
#
# What the patch does:
#   * TTSEngine now receives a TTS *module* (`tts_module`) instead of a bare
#     function (`tts_function`). Synthesis goes through
#     `tts_module.run(text=..., filepath=...)` inside `call_tts`.
#   * TTSEngine.run() reads `tts_module.max_chars` for every comment. A falsy
#     value sends the whole comment body to `call_tts`; otherwise the comment
#     is handed to the new `split_post` method.
#   * split_post() cuts the text into chunks with a regex built from
#     `max_chars`, synthesizes each chunk to `{path}/{idx}-{idy}.part.mp3`,
#     joins the parts with moviepy into `{path}/{idx}.mp3`, then unlinks the
#     part files.
#   * The `track(...)` progress wrapper around the comment loop is replaced by
#     a plain `enumerate`; the old loop header is left behind as a comment.
#   * run() gains a `Tuple[int, int]` return annotation; call_tts() gains
#     `str` parameter annotations.
#   * tts/example_tts.py documents the new `max_chars` module attribute
#     (0 means no limit); tts/google_translate_tts.py sets `max_chars = 0`.
#   * video_creation/voices.py passes the module `tts.google_translate_tts`
#     rather than its `run` function.
#
# NOTE(review) - points to confirm before merging; the hunks elide part of
# run() and the original indentation is not visible here:
#   * split_post is annotated `-> str`, but no return statement appears in
#     the added lines.
#   * `idy` is only bound by the chunk loop. If the regex yields no chunks
#     (e.g. an empty comment body), the cleanup `range(0, idy + 1)` would
#     presumably raise, and the audio concatenation would receive an empty
#     list - verify.
#   * The chunk pattern is ` *((.{0,N})(\.|.$))` with N = max_chars: a chunk
#     is up to N characters plus one terminating character, so it can be
#     N + 1 long. Text with no '.' within reach that is not at the end of the
#     string looks like it is matched by no chunk (silently dropped), and `.`
#     does not match newlines under Python's default regex flags - verify
#     against multi-line comments.
#   * The AudioFileClip objects opened for the parts are never closed before
#     the part files are unlinked - presumably a problem on platforms that
#     lock open files; confirm.
#   * split_post uses bare `print` calls while the rest of the class uses
#     `print_substep`; these look like leftover debug output.
#   * run() returns `self.length, idx`; `idx` is only bound by the comment
#     loop, so an empty comment list would presumably fail there. The
#     `Tuple[int, int]` annotation also assumes an integer length, while
#     `self.length` accumulates `MP3(...).info.length` (likely a float) -
#     TODO confirm.
#   * `Callable` is still imported, although the only annotation shown using
#     it is removed by this patch.
#   * The class docstring note still says "tts_module must take the arguments
#     text and filepath"; per call_tts it is `tts_module.run` that takes them.
#   * The tts/example_tts.py context line contains the typo "to the a file";
#     fix it in the source file, since patch context must match it verbatim.
#
diff --git a/tts/engine_wrapper.py b/tts/engine_wrapper.py index bd36f0c..61c0596 100644 --- a/tts/engine_wrapper.py +++ b/tts/engine_wrapper.py @@ -1,29 +1,31 @@ from pathlib import Path -from typing import Callable +from typing import Callable, Tuple from mutagen.mp3 import MP3 from utils.console import print_step, print_substep from rich.progress import track +from moviepy.editor import AudioFileClip, CompositeAudioClip, concatenate_audioclips +import re class TTSEngine: """Calls the given TTS engine to reduce code duplication and allow multiple TTS engines. Args: - tts_function : The function that will be called. Your function should handle the TTS itself and saving to the given path. + tts_module : The TTS module. Your module should handle the TTS itself and saving to the given path under the run method. reddit_object : The reddit object that contains the posts to read. path (Optional) : The unix style path to save the mp3 files to. This must not have leading or trailing slashes. max_length (Optional) : The maximum length of the mp3 files in total. Notes: - tts_function must take the arguments text and filepath. + tts_module must take the arguments text and filepath. 
""" - def __init__(self, tts_function: Callable[[str, str], None], reddit_object: dict, path: str = "assets/mp3", max_length: int = 50): - self.tts_function = tts_function + def __init__(self, tts_module, reddit_object: dict, path: str = "assets/mp3", max_length: int = 50): + self.tts_module = tts_module self.reddit_object = reddit_object self.path = path self.max_length = max_length self.length = 0 - def run(self): + def run(self) -> Tuple[int, int]: Path(self.path).mkdir(parents=True, exist_ok=True) @@ -40,16 +42,32 @@ class TTSEngine: if self.reddit_object["thread_post"] != "": self.call_tts("posttext", self.reddit_object["thread_post"]) - for idx, comment in track(enumerate(self.reddit_object["comments"]), "Saving..."): + #for idx, comment in track(enumerate(self.reddit_object["comments"]), "Saving..."): + for idx, comment in enumerate(self.reddit_object["comments"]): # ! Stop creating mp3 files if the length is greater than max length. if self.length > self.max_length: break - - self.call_tts(f"{idx}",comment["comment_body"]) + if not self.tts_module.max_chars: + self.call_tts(f"{idx}",comment["comment_body"]) + else: + self.split_post(comment["comment_body"], idx) print_substep("Saved Text to MP3 files successfully.", style="bold green") return self.length, idx - def call_tts(self, filename, text): - self.tts_function(text=text, filepath=f"{self.path}/{filename}.mp3") + def split_post(self, text: str, idx:int) -> str: + split_files = [] + split_text = [x.group().strip() for x in re.finditer(fr' *((.{{0,{self.tts_module.max_chars}}})(\.|.$))', text)] + for idy, text_cut in enumerate(split_text): + print(f"{idx}-{idy}: {text_cut}\n") + self.call_tts(f"{idx}-{idy}.part", text_cut) + split_files.append(AudioFileClip(f"{self.path}/{idx}-{idy}.part.mp3")) + CompositeAudioClip([concatenate_audioclips(split_files)]).write_audiofile(f"{self.path}/{idx}.mp3", fps=44100, verbose=False, logger=None) + + for i in range(0, idy + 1): + print(f"Cleaning up 
{self.path}/{idx}-{i}.part.mp3") + Path(f"{self.path}/{idx}-{i}.part.mp3").unlink() + + def call_tts(self, filename: str, text: str): + self.tts_module.run(text=text, filepath=f"{self.path}/{filename}.mp3") self.length += MP3(f"{self.path}/{filename}.mp3").info.length diff --git a/tts/example_tts.py b/tts/example_tts.py index d297a59..8fd3b0e 100644 --- a/tts/example_tts.py +++ b/tts/example_tts.py @@ -1,9 +1,14 @@ -## Your TTS imports, etc. - +################################ +# Your TTS imports, etc. +# ... +# max_chars = - this is the maximum number of characters your TTS supports +# set to 0 for no limit. +# # def run(text, filepath): # Call your TTS on the text variable # ... # Save your TTS to the a file using the filepath variable. The engine assumes it will be an mp3 file. # ... # -# any extra functions you need \ No newline at end of file +# any extra functions you need +################################ \ No newline at end of file diff --git a/tts/google_translate_tts.py b/tts/google_translate_tts.py index df1b998..58c8719 100644 --- a/tts/google_translate_tts.py +++ b/tts/google_translate_tts.py @@ -1,5 +1,7 @@ from gtts import gTTS +max_chars = 0 + def run(text, filepath): tts = gTTS(text=text, lang="en", slow=False) tts.save(filepath) diff --git a/video_creation/voices.py b/video_creation/voices.py index ea8e501..5c16088 100644 --- a/video_creation/voices.py +++ b/video_creation/voices.py @@ -9,5 +9,5 @@ def save_text_to_mp3(reddit_obj): Args: reddit_obj : The reddit object you received from the reddit API in the askreddit.py file. """ - text_to_mp3 = TTSEngine(tts.google_translate_tts, reddit_obj) + text_to_mp3 = TTSEngine(tts.google_translate_tts, reddit_obj) return text_to_mp3.run()