Use correct formatting

pull/432/head
Callum Leslie 3 years ago
parent 1ee299d6f0
commit 48eac4cd94

@ -6,6 +6,7 @@ from rich.progress import track
from moviepy.editor import AudioFileClip, CompositeAudioClip, concatenate_audioclips from moviepy.editor import AudioFileClip, CompositeAudioClip, concatenate_audioclips
from utils.console import print_step, print_substep from utils.console import print_step, print_substep
class TTSEngine: class TTSEngine:
"""Calls the given TTS engine to reduce code duplication and allow multiple TTS engines. """Calls the given TTS engine to reduce code duplication and allow multiple TTS engines.
@ -20,7 +21,13 @@ class TTSEngine:
tts_module must take the arguments text and filepath. tts_module must take the arguments text and filepath.
""" """
def __init__(self, tts_module, reddit_object: dict, path: str = "assets/mp3", max_length: int = 50): def __init__(
self,
tts_module,
reddit_object: dict,
path: str = "assets/mp3",
max_length: int = 50,
):
self.tts_module = tts_module self.tts_module = tts_module
self.reddit_object = reddit_object self.reddit_object = reddit_object
self.path = path self.path = path
@ -44,26 +51,38 @@ class TTSEngine:
if self.reddit_object["thread_post"] != "": if self.reddit_object["thread_post"] != "":
self.call_tts("posttext", self.reddit_object["thread_post"]) self.call_tts("posttext", self.reddit_object["thread_post"])
for idx, comment in track(enumerate(self.reddit_object["comments"]), "Saving..."): idx = None
for idx, comment in track(
enumerate(self.reddit_object["comments"]), "Saving..."
):
# ! Stop creating mp3 files if the length is greater than max length. # ! Stop creating mp3 files if the length is greater than max length.
if self.length > self.max_length: if self.length > self.max_length:
break break
if not self.tts_module.max_chars: if not self.tts_module.max_chars:
self.call_tts(f"{idx}",comment["comment_body"]) self.call_tts(f"{idx}", comment["comment_body"])
else: else:
self.split_post(comment["comment_body"], idx) self.split_post(comment["comment_body"], idx)
print_substep("Saved Text to MP3 files successfully.", style="bold green") print_substep("Saved Text to MP3 files successfully.", style="bold green")
return self.length, idx return self.length, idx
def split_post(self, text: str, idx:int) -> str: def split_post(self, text: str, idx: int) -> str:
split_files = [] split_files = []
split_text = [x.group().strip() for x in re.finditer(fr' *((.{{0,{self.tts_module.max_chars}}})(\.|.$))', text)] split_text = [
x.group().strip()
for x in re.finditer(
rf" *((.{{0,{self.tts_module.max_chars}}})(\.|.$))", text
)
]
idy = None
for idy, text_cut in enumerate(split_text): for idy, text_cut in enumerate(split_text):
print(f"{idx}-{idy}: {text_cut}\n") print(f"{idx}-{idy}: {text_cut}\n")
self.call_tts(f"{idx}-{idy}.part", text_cut) self.call_tts(f"{idx}-{idy}.part", text_cut)
split_files.append(AudioFileClip(f"{self.path}/{idx}-{idy}.part.mp3")) split_files.append(AudioFileClip(f"{self.path}/{idx}-{idy}.part.mp3"))
CompositeAudioClip([concatenate_audioclips(split_files)]).write_audiofile(f"{self.path}/{idx}.mp3", fps=44100, verbose=False, logger=None) CompositeAudioClip([concatenate_audioclips(split_files)]).write_audiofile(
f"{self.path}/{idx}.mp3", fps=44100, verbose=False, logger=None
)
for i in range(0, idy + 1): for i in range(0, idy + 1):
print(f"Cleaning up {self.path}/{idx}-{i}.part.mp3") print(f"Cleaning up {self.path}/{idx}-{i}.part.mp3")

@ -11,4 +11,4 @@
# ... # ...
# #
# any extra functions you need # any extra functions you need
################################ ################################

@ -2,6 +2,7 @@ from gtts import gTTS
max_chars = 0 max_chars = 0
def run(text, filepath): def run(text, filepath):
tts = gTTS(text=text, lang="en", slow=False) tts = gTTS(text=text, lang="en", slow=False)
tts.save(filepath) tts.save(filepath)

@ -5,9 +5,8 @@ from tts.engine_wrapper import TTSEngine
import tts.google_translate_tts import tts.google_translate_tts
## Add your provider here on a new line ## Add your provider here on a new line
TTSProviders ={ TTSProviders = {"GoogleTranslate": tts.google_translate_tts}
"GoogleTranslate": tts.google_translate_tts
}
def save_text_to_mp3(reddit_obj): def save_text_to_mp3(reddit_obj):
"""Saves Text to MP3 files. """Saves Text to MP3 files.
@ -15,8 +14,9 @@ def save_text_to_mp3(reddit_obj):
Args: Args:
reddit_obj : The reddit object you received from the reddit API in the askreddit.py file. reddit_obj : The reddit object you received from the reddit API in the askreddit.py file.
""" """
env = os.getenv("TTS_PROVIDER","") env = os.getenv("TTS_PROVIDER", "")
if env in TTSProviders: text_to_mp3 = TTSEngine(env, reddit_obj) if env in TTSProviders:
text_to_mp3 = TTSEngine(env, reddit_obj)
else: else:
chosen = False chosen = False
choice = "" choice = ""
@ -29,10 +29,19 @@ def save_text_to_mp3(reddit_obj):
print("Unknown Choice") print("Unknown Choice")
else: else:
chosen = True chosen = True
text_to_mp3 = TTSEngine(get_case_insensitive_key_value(TTSProviders, choice), reddit_obj) text_to_mp3 = TTSEngine(
get_case_insensitive_key_value(TTSProviders, choice), reddit_obj
)
return text_to_mp3.run() return text_to_mp3.run()
def get_case_insensitive_key_value(input_dict, key): def get_case_insensitive_key_value(input_dict, key):
return next((value for dict_key, value in input_dict.items() if dict_key.lower() == key.lower()), None) return next(
(
value
for dict_key, value in input_dict.items()
if dict_key.lower() == key.lower()
),
None,
)

Loading…
Cancel
Save