Use correct formatting

pull/432/head
Callum Leslie 3 years ago
parent 1ee299d6f0
commit 48eac4cd94

@ -6,6 +6,7 @@ from rich.progress import track
from moviepy.editor import AudioFileClip, CompositeAudioClip, concatenate_audioclips
from utils.console import print_step, print_substep
class TTSEngine:
"""Calls the given TTS engine to reduce code duplication and allow multiple TTS engines.
@ -20,7 +21,13 @@ class TTSEngine:
tts_module must take the arguments text and filepath.
"""
def __init__(self, tts_module, reddit_object: dict, path: str = "assets/mp3", max_length: int = 50):
def __init__(
self,
tts_module,
reddit_object: dict,
path: str = "assets/mp3",
max_length: int = 50,
):
self.tts_module = tts_module
self.reddit_object = reddit_object
self.path = path
@ -44,7 +51,10 @@ class TTSEngine:
if self.reddit_object["thread_post"] != "":
self.call_tts("posttext", self.reddit_object["thread_post"])
for idx, comment in track(enumerate(self.reddit_object["comments"]), "Saving..."):
idx = None
for idx, comment in track(
enumerate(self.reddit_object["comments"]), "Saving..."
):
# ! Stop creating mp3 files if the length is greater than max length.
if self.length > self.max_length:
break
@ -58,12 +68,21 @@ class TTSEngine:
def split_post(self, text: str, idx: int) -> str:
split_files = []
split_text = [x.group().strip() for x in re.finditer(fr' *((.{{0,{self.tts_module.max_chars}}})(\.|.$))', text)]
split_text = [
x.group().strip()
for x in re.finditer(
rf" *((.{{0,{self.tts_module.max_chars}}})(\.|.$))", text
)
]
idy = None
for idy, text_cut in enumerate(split_text):
print(f"{idx}-{idy}: {text_cut}\n")
self.call_tts(f"{idx}-{idy}.part", text_cut)
split_files.append(AudioFileClip(f"{self.path}/{idx}-{idy}.part.mp3"))
CompositeAudioClip([concatenate_audioclips(split_files)]).write_audiofile(f"{self.path}/{idx}.mp3", fps=44100, verbose=False, logger=None)
CompositeAudioClip([concatenate_audioclips(split_files)]).write_audiofile(
f"{self.path}/{idx}.mp3", fps=44100, verbose=False, logger=None
)
for i in range(0, idy + 1):
print(f"Cleaning up {self.path}/{idx}-{i}.part.mp3")

@ -2,6 +2,7 @@ from gtts import gTTS
max_chars = 0
def run(text, filepath):
tts = gTTS(text=text, lang="en", slow=False)
tts.save(filepath)

@ -5,9 +5,8 @@ from tts.engine_wrapper import TTSEngine
import tts.google_translate_tts
## Add your provider here on a new line
TTSProviders ={
"GoogleTranslate": tts.google_translate_tts
}
TTSProviders = {"GoogleTranslate": tts.google_translate_tts}
def save_text_to_mp3(reddit_obj):
"""Saves Text to MP3 files.
@ -16,7 +15,8 @@ def save_text_to_mp3(reddit_obj):
reddit_obj : The reddit object you received from the reddit API in the askreddit.py file.
"""
env = os.getenv("TTS_PROVIDER", "")
if env in TTSProviders: text_to_mp3 = TTSEngine(env, reddit_obj)
if env in TTSProviders:
text_to_mp3 = TTSEngine(env, reddit_obj)
else:
chosen = False
choice = ""
@ -29,10 +29,19 @@ def save_text_to_mp3(reddit_obj):
print("Unknown Choice")
else:
chosen = True
text_to_mp3 = TTSEngine(get_case_insensitive_key_value(TTSProviders, choice), reddit_obj)
text_to_mp3 = TTSEngine(
get_case_insensitive_key_value(TTSProviders, choice), reddit_obj
)
return text_to_mp3.run()
def get_case_insensitive_key_value(input_dict, key):
return next((value for dict_key, value in input_dict.items() if dict_key.lower() == key.lower()), None)
return next(
(
value
for dict_key, value in input_dict.items()
if dict_key.lower() == key.lower()
),
None,
)

Loading…
Cancel
Save