diff --git a/tts/engine_wrapper.py b/tts/engine_wrapper.py index 3ad917a..e60a093 100644 --- a/tts/engine_wrapper.py +++ b/tts/engine_wrapper.py @@ -44,8 +44,7 @@ class TTSEngine: if self.reddit_object["thread_post"] != "": self.call_tts("posttext", self.reddit_object["thread_post"]) - #for idx, comment in track(enumerate(self.reddit_object["comments"]), "Saving..."): - for idx, comment in enumerate(self.reddit_object["comments"]): + for idx, comment in track(enumerate(self.reddit_object["comments"]), "Saving..."): # ! Stop creating mp3 files if the length is greater than max length. if self.length > self.max_length: break diff --git a/video_creation/voices.py b/video_creation/voices.py index d57642d..deebb56 100644 --- a/video_creation/voices.py +++ b/video_creation/voices.py @@ -15,7 +15,22 @@ def save_text_to_mp3(reddit_obj): Args: reddit_obj : The reddit object you received from the reddit API in the askreddit.py file. """ - text_to_mp3 = TTSEngine(tts.google_translate_tts, reddit_obj) + env = os.getenv("TTS_PROVIDER","") + if env in TTSProviders.keys(): text_to_mp3 = TTSEngine(env, reddit_obj) + else: + chosen = False + choice = "" + while not chosen: + print("Please choose one of the following TTS providers: ") + for i in TTSProviders.keys(): + print(i) + choice = input("\n") + if choice.casefold() not in map(lambda _: _.casefold(), TTSProviders.keys()): + print("Unknown Choice") + else: + chosen = True + text_to_mp3 = TTSEngine(get_case_insensitive_key_value(TTSProviders, choice), reddit_obj) + return text_to_mp3.run()