From c58fa10f53c06a1702f169105080f5e79f77eb61 Mon Sep 17 00:00:00 2001 From: Callum Leslie Date: Sun, 19 Jun 2022 21:25:30 +0100 Subject: [PATCH] Reduced code duplication in TTS engines --- .pylintrc | 2 +- TTS/GTTS.py | 23 +++++---- TTS/POLLY.py | 106 -------------------------------------- TTS/TikTok.py | 107 ++++++++++++--------------------------- TTS/aws_polly.py | 66 ++++++++++++++++++++++++ TTS/engine_wrapper.py | 99 ++++++++++++++++++++++++++++++++++++ TTS/streamlabs_polly.py | 53 +++++++++++++++++++ TTS/swapper.py | 24 --------- utils/console.py | 7 +++ utils/voice.py | 2 +- video_creation/voices.py | 102 +++++++++++++++---------------------- 11 files changed, 316 insertions(+), 275 deletions(-) delete mode 100644 TTS/POLLY.py create mode 100644 TTS/aws_polly.py create mode 100644 TTS/engine_wrapper.py create mode 100644 TTS/streamlabs_polly.py delete mode 100644 TTS/swapper.py diff --git a/.pylintrc b/.pylintrc index e3fead7..b03c808 100644 --- a/.pylintrc +++ b/.pylintrc @@ -149,7 +149,7 @@ disable=raw-checker-failed, suppressed-message, useless-suppression, deprecated-pragma, - use-symbolic-message-instead + use-symbolic-message-instead, attribute-defined-outside-init, invalid-name, missing-docstring, diff --git a/TTS/GTTS.py b/TTS/GTTS.py index fcbcb9b..a0df172 100644 --- a/TTS/GTTS.py +++ b/TTS/GTTS.py @@ -1,13 +1,18 @@ +#!/usr/bin/env python3 +import random from gtts import gTTS +max_chars = 0 + class GTTS: - def tts( - self, - req_text: str = "Google Text To Speech", - filename: str = "title.mp3", - random_speaker=False, - censor=False, - ): - tts = gTTS(text=req_text, lang="en", slow=False) - tts.save(f"{filename}") + def __init__(self): + self.max_chars = 0 + self.voices = [] + + def run(self, text, filepath): + tts = gTTS(text=text, lang="en", slow=False) + tts.save(filepath) + + def randomvoice(self): + return random.choice(self.voices) diff --git a/TTS/POLLY.py b/TTS/POLLY.py deleted file mode 100644 index da1fae0..0000000 --- a/TTS/POLLY.py +++ /dev/null @@ -1,106 +0,0 @@ -import os -import random -import re - -import requests -import sox -from moviepy.audio.AudioClip import concatenate_audioclips, CompositeAudioClip -from moviepy.audio.io.AudioFileClip import AudioFileClip -from requests.exceptions import JSONDecodeError - -voices = [ - "Brian", - "Emma", - "Russell", - "Joey", - "Matthew", - "Joanna", - "Kimberly", - "Amy", - "Geraint", - "Nicole", - "Justin", - "Ivy", - "Kendra", - "Salli", - "Raveena", -] - - -# valid voices https://lazypy.ro/tts/ - - -class POLLY: - def __init__(self): - self.url = "https://streamlabs.com/polly/speak" - - def tts( - self, - req_text: str = "Amazon Text To Speech", - filename: str = "title.mp3", - random_speaker=False, - censor=False, - ): - if random_speaker: - voice = self.randomvoice() - else: - if not os.getenv("VOICE"): - return ValueError( - "Please set the environment variable VOICE to a valid voice. options are: {}".format( - voices - ) - ) - voice = str(os.getenv("VOICE")).capitalize() - body = {"voice": voice, "text": req_text, "service": "polly"} - response = requests.post(self.url, data=body) - try: - voice_data = requests.get(response.json()["speak_url"]) - with open(filename, "wb") as f: - f.write(voice_data.content) - except (KeyError, JSONDecodeError): - if response.json()["error"] == "Text length is too long!": - chunks = [m.group().strip() for m in re.finditer(r" *((.{0,499})(\.|.$))", req_text)] - - audio_clips = [] - cbn = sox.Combiner() - - chunkId = 0 - for chunk in chunks: - body = {"voice": voice, "text": chunk, "service": "polly"} - resp = requests.post(self.url, data=body) - voice_data = requests.get(resp.json()["speak_url"]) - with open(filename.replace(".mp3", f"-{chunkId}.mp3"), "wb") as out: - out.write(voice_data.content) - - audio_clips.append(filename.replace(".mp3", f"-{chunkId}.mp3")) - - chunkId = chunkId + 1 - try: - if len(audio_clips) > 1: - cbn.convert(samplerate=44100, n_channels=2) - cbn.build(audio_clips, filename, "concatenate") - else: - os.rename(audio_clips[0], filename) - except ( - sox.core.SoxError, - FileNotFoundError, - ): # https://github.com/JasonLovesDoggo/RedditVideoMakerBot/issues/67#issuecomment-1150466339 - for clip in audio_clips: - i = audio_clips.index(clip) # get the index of the clip - audio_clips = ( - audio_clips[:i] + [AudioFileClip(clip)] + audio_clips[i + 1 :] - ) # replace the clip with an AudioFileClip - audio_concat = concatenate_audioclips(audio_clips) - audio_composite = CompositeAudioClip([audio_concat]) - audio_composite.write_audiofile(filename, 44100, 2, 2000, None) - - def make_readable(self, text): - """ - Amazon Polly fails to read some symbols properly such as '& (and)'. - So we normalize input text before passing it to the service - """ - text = text.replace("&", "and") - return text - - def randomvoice(self): - return random.choice(voices) diff --git a/TTS/TikTok.py b/TTS/TikTok.py index 662e498..ccec427 100644 --- a/TTS/TikTok.py +++ b/TTS/TikTok.py @@ -1,12 +1,7 @@ import base64 import os import random -import re - import requests -import sox -from moviepy.audio.AudioClip import concatenate_audioclips, CompositeAudioClip -from moviepy.audio.io.AudioFileClip import AudioFileClip from requests.adapters import HTTPAdapter, Retry # from profanity_filter import ProfanityFilter @@ -67,75 +62,39 @@ noneng = [ class TikTok: # TikTok Text-to-Speech Wrapper def __init__(self): - self.URI_BASE = ( - "https://api16-normal-useast5.us.tiktokv.com/media/api/text/speech/invoke/?text_speaker=" - ) - - def tts( - self, - req_text: str = "TikTok Text To Speech", - filename: str = "title.mp3", - random_speaker: bool = False, - censor=False, - ): - req_text = req_text.replace("+", "plus").replace(" ", "+").replace("&", "and") - if censor: - # req_text = pf.censor(req_text) - pass + self.URI_BASE = "https://api16-normal-useast5.us.tiktokv.com/media/api/text/speech/invoke/?text_speaker=" + self.max_chars = 330 + self.voices = {"human": human, "nonhuman": nonhuman, "noneng": noneng} + + def run(self, text, filepath, random_voice: bool = False): + # if censor: + # req_text = pf.censor(req_text) + # pass voice = ( - self.randomvoice() if random_speaker else (os.getenv("VOICE") or random.choice(human)) + self.randomvoice() + if random_voice + else (os.getenv("VOICE") or random.choice(self.voices["human"])) ) - - chunks = [m.group().strip() for m in re.finditer(r" *((.{0,299})(\.|.$))", req_text)] - - audio_clips = [] - cbn = sox.Combiner() - # cbn.set_input_format(file_type=["mp3" for _ in chunks]) - - chunkId = 0 - for chunk in chunks: - try: - r = requests.post(f"{self.URI_BASE}{voice}&req_text={chunk}&speaker_map_type=0") - except requests.exceptions.SSLError: - # https://stackoverflow.com/a/47475019/18516611 - session = requests.Session() - retry = Retry(connect=3, backoff_factor=0.5) - adapter = HTTPAdapter(max_retries=retry) - session.mount("http://", adapter) - session.mount("https://", adapter) - r = session.post(f"{self.URI_BASE}{voice}&req_text={chunk}&speaker_map_type=0") - print(r.text) - vstr = [r.json()["data"]["v_str"]][0] - b64d = base64.b64decode(vstr) - - with open(filename.replace(".mp3", f"-{chunkId}.mp3"), "wb") as out: - out.write(b64d) - - audio_clips.append(filename.replace(".mp3", f"-{chunkId}.mp3")) - - chunkId = chunkId + 1 try: - if len(audio_clips) > 1: - cbn.convert(samplerate=44100, n_channels=2) - cbn.build(audio_clips, filename, "concatenate") - else: - os.rename(audio_clips[0], filename) - except ( - sox.core.SoxError, - FileNotFoundError, - ): # https://github.com/JasonLovesDoggo/RedditVideoMakerBot/issues/67#issuecomment-1150466339 - for clip in audio_clips: - i = audio_clips.index(clip) # get the index of the clip - audio_clips = ( - audio_clips[:i] + [AudioFileClip(clip)] + audio_clips[i + 1 :] - ) # replace the clip with an AudioFileClip - audio_concat = concatenate_audioclips(audio_clips) - audio_composite = CompositeAudioClip([audio_concat]) - audio_composite.write_audiofile(filename, 44100, 2, 2000, None) - - @staticmethod - def randomvoice(): - ok_or_good = random.randrange(1, 10) - if ok_or_good == 1: # 1/10 chance of ok voice - return random.choice(voices) - return random.choice(human) # 9/10 chance of good voice + r = requests.post( + f"{self.URI_BASE}{voice}&req_text={text}&speaker_map_type=0" + ) + except requests.exceptions.SSLError: + # https://stackoverflow.com/a/47475019/18516611 + session = requests.Session() + retry = Retry(connect=3, backoff_factor=0.5) + adapter = HTTPAdapter(max_retries=retry) + session.mount("http://", adapter) + session.mount("https://", adapter) + r = session.post( + f"{self.URI_BASE}{voice}&req_text={text}&speaker_map_type=0" + ) + print(r.text) + vstr = [r.json()["data"]["v_str"]][0] + b64d = base64.b64decode(vstr) + + with open(filepath, "wb") as out: + out.write(b64d) + + def randomvoice(self): + return random.choice(self.voices["human"]) diff --git a/TTS/aws_polly.py b/TTS/aws_polly.py new file mode 100644 index 0000000..3bf7090 --- /dev/null +++ b/TTS/aws_polly.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python3 +from boto3 import Session +from botocore.exceptions import BotoCoreError, ClientError +import sys +import os +import random + +voices = [ + "Brian", + "Emma", + "Russell", + "Joey", + "Matthew", + "Joanna", + "Kimberly", + "Amy", + "Geraint", + "Nicole", + "Justin", + "Ivy", + "Kendra", + "Salli", + "Raveena", +] + + +class AWSPolly: + def __init__(self): + self.max_chars = 0 + self.voices = voices + + def run(self, text, filepath, random_voice: bool = False): + session = Session(profile_name="polly") + polly = session.client("polly") + if random_voice: + voice = self.randomvoice() + else: + if not os.getenv("VOICE"): + return ValueError( + f"Please set the environment variable VOICE to a valid voice. options are: {voices}" + ) + voice = str(os.getenv("VOICE")).capitalize() + try: + # Request speech synthesis + response = polly.synthesize_speech( + Text=text, OutputFormat="mp3", VoiceId=voice, Engine="neural" + ) + except (BotoCoreError, ClientError) as error: + # The service returned an error, exit gracefully + print(error) + sys.exit(-1) + + # Access the audio stream from the response + if "AudioStream" in response: + file = open(filepath, "wb") + file.write(response["AudioStream"].read()) + file.close() + # print_substep(f"Saved Text {idx} to MP3 files successfully.", style="bold green") + + else: + # The response didn't contain audio data, exit gracefully + print("Could not stream audio") + sys.exit(-1) + + def randomvoice(self): + return random.choice(self.voices) diff --git a/TTS/engine_wrapper.py b/TTS/engine_wrapper.py new file mode 100644 index 0000000..4c1c2c8 --- /dev/null +++ b/TTS/engine_wrapper.py @@ -0,0 +1,99 @@ +#!/usr/bin/env python3 +from pathlib import Path +from typing import Tuple +import re +from os import getenv +from mutagen.mp3 import MP3 +from rich.progress import track +from moviepy.editor import AudioFileClip, CompositeAudioClip, concatenate_audioclips +from utils.console import print_step, print_substep +from utils.voice import sanitize_text + + +class TTSEngine: + + """Calls the given TTS engine to reduce code duplication and allow multiple TTS engines. + + Args: + tts_module : The TTS module. Your module should handle the TTS itself and saving to the given path under the run method. + reddit_object : The reddit object that contains the posts to read. + path (Optional) : The unix style path to save the mp3 files to. This must not have leading or trailing slashes. + max_length (Optional) : The maximum length of the mp3 files in total. + + Notes: + tts_module must take the arguments text and filepath. + """ + + def __init__( + self, + tts_module, + reddit_object: dict, + path: str = "assets/mp3", + max_length: int = 50, + ): + self.tts_module = tts_module() + self.reddit_object = reddit_object + self.path = path + self.max_length = max_length + self.length = 0 + + def run(self) -> Tuple[int, int]: + + Path(self.path).mkdir(parents=True, exist_ok=True) + + # This file needs to be removed in case this post does not use post text, so that it wont appear in the final video + try: + Path(f"{self.path}/posttext.mp3").unlink() + except OSError: + pass + + print_step("Saving Text to MP3 files...") + + self.call_tts("title", self.reddit_object["thread_title"]) + if ( + self.reddit_object["thread_post"] != "" + and getenv("STORYMODE", "").casefold() == "true" + ): + + self.call_tts("posttext", sanitize_text(self.reddit_object["thread_post"])) + + idx = None + for idx, comment in track( + enumerate(self.reddit_object["comments"]), "Saving..." + ): + # ! Stop creating mp3 files if the length is greater than max length. + if self.length > self.max_length: + break + if not self.tts_module.max_chars: + self.call_tts(f"{idx}", sanitize_text(comment["comment_body"])) + else: + self.split_post(sanitize_text(comment["comment_body"]), idx) + + print_substep("Saved Text to MP3 files successfully.", style="bold green") + return self.length, idx + + def split_post(self, text: str, idx: int) -> str: + split_files = [] + split_text = [ + x.group().strip() + for x in re.finditer( + rf" *((.{{0,{self.tts_module.max_chars}}})(\.|.$))", text + ) + ] + + idy = None + for idy, text_cut in enumerate(split_text): + print(f"{idx}-{idy}: {text_cut}\n") + self.call_tts(f"{idx}-{idy}.part", text_cut) + split_files.append(AudioFileClip(f"{self.path}/{idx}-{idy}.part.mp3")) + CompositeAudioClip([concatenate_audioclips(split_files)]).write_audiofile( + f"{self.path}/{idx}.mp3", fps=44100, verbose=False, logger=None + ) + + for i in range(0, idy + 1): + print(f"Cleaning up {self.path}/{idx}-{i}.part.mp3") + Path(f"{self.path}/{idx}-{i}.part.mp3").unlink() + + def call_tts(self, filename: str, text: str): + self.tts_module.run(text=text, filepath=f"{self.path}/{filename}.mp3") + self.length += MP3(f"{self.path}/{filename}.mp3").info.length diff --git a/TTS/streamlabs_polly.py b/TTS/streamlabs_polly.py new file mode 100644 index 0000000..07f2c17 --- /dev/null +++ b/TTS/streamlabs_polly.py @@ -0,0 +1,53 @@ +import random +import os +import requests +from requests.exceptions import JSONDecodeError + +voices = [ + "Brian", + "Emma", + "Russell", + "Joey", + "Matthew", + "Joanna", + "Kimberly", + "Amy", + "Geraint", + "Nicole", + "Justin", + "Ivy", + "Kendra", + "Salli", + "Raveena", +] + + +# valid voices https://lazypy.ro/tts/ + + +class StreamlabsPolly: + def __init__(self): + self.url = "https://streamlabs.com/polly/speak" + self.max_chars = 550 + self.voices = voices + + def run(self, text, filepath, random_voice: bool = False): + if random_voice: + voice = self.randomvoice() + else: + if not os.getenv("VOICE"): + return ValueError( + f"Please set the environment variable VOICE to a valid voice. options are: {voices}" + ) + voice = str(os.getenv("VOICE")).capitalize() + body = {"voice": voice, "text": text, "service": "polly"} + response = requests.post(self.url, data=body) + try: + voice_data = requests.get(response.json()["speak_url"]) + with open(filepath, "wb") as f: + f.write(voice_data.content) + except (KeyError, JSONDecodeError): + print("Error occured calling Streamlabs Polly") + + def randomvoice(self): + return random.choice(self.voices) diff --git a/TTS/swapper.py b/TTS/swapper.py deleted file mode 100644 index c5f6776..0000000 --- a/TTS/swapper.py +++ /dev/null @@ -1,24 +0,0 @@ -from os import getenv - -from dotenv import load_dotenv - -from TTS.GTTS import GTTS -from TTS.POLLY import POLLY -from TTS.TikTok import TikTok -from utils.console import print_substep - -CHOICE_DIR = {"tiktok": TikTok, "gtts": GTTS, "polly": POLLY} - - -class TTS: - def __new__(cls): - load_dotenv() - try: - CHOICE = getenv("TTsChoice").casefold() - except AttributeError: - print_substep("None defined. Defaulting to 'polly.'") - CHOICE = "polly" - valid_keys = [key.lower() for key in CHOICE_DIR.keys()] - if CHOICE not in valid_keys: - raise ValueError(f"{CHOICE} is not valid. Please use one of these {valid_keys} options") - return CHOICE_DIR.get(CHOICE)() diff --git a/utils/console.py b/utils/console.py index 11ee429..842c60a 100644 --- a/utils/console.py +++ b/utils/console.py @@ -4,6 +4,7 @@ from rich.markdown import Markdown from rich.padding import Padding from rich.panel import Panel from rich.text import Text +from rich.columns import Columns console = Console() @@ -25,3 +26,9 @@ def print_step(text): def print_substep(text, style=""): """Prints a rich info message without the panelling.""" console.print(text, style=style) + + +def print_table(items): + """Prints items in a table.""" + + console.print(Columns([Panel(f"[yellow]{item}", expand=True) for item in items])) diff --git a/utils/voice.py b/utils/voice.py index 120ee60..7aed2ad 100644 --- a/utils/voice.py +++ b/utils/voice.py @@ -17,6 +17,6 @@ def sanitize_text(text): # note: not removing apostrophes regex_expr = r"\s['|’]|['|’]\s|[\^_~@!&;#:\-%“”‘\"%\*/{}\[\]\(\)\\|<>=+]" result = re.sub(regex_expr, " ", result) - + result = result.replace("+", "plus").replace(" ", "+").replace("&", "and") # remove extra whitespace return " ".join(result.split()) diff --git a/video_creation/voices.py b/video_creation/voices.py index be7da96..9407d59 100644 --- a/video_creation/voices.py +++ b/video_creation/voices.py @@ -1,20 +1,25 @@ #!/usr/bin/env python3 -from os import getenv -from pathlib import Path +import os -import sox -from mutagen import MutagenError -from mutagen.mp3 import MP3, HeaderNotFoundError from rich.console import Console -from rich.progress import track -from TTS.swapper import TTS +from TTS.engine_wrapper import TTSEngine +from TTS.GTTS import GTTS +from TTS.streamlabs_polly import StreamlabsPolly +from TTS.aws_polly import AWSPolly +from TTS.TikTok import TikTok + +from utils.console import print_table, print_step -from utils.console import print_step, print_substep -from utils.voice import sanitize_text console = Console() +TTSProviders = { + "GoogleTranslate": GTTS, + "AWSPolly": AWSPolly, + "StreamlabsPolly": StreamlabsPolly, + "TikTok": TikTok, +} VIDEO_LENGTH: int = 40 # secs @@ -22,58 +27,35 @@ VIDEO_LENGTH: int = 40 # secs def save_text_to_mp3(reddit_obj): """Saves Text to MP3 files. Args: - reddit_obj : The reddit object you received from the reddit API in the askreddit.py file. + reddit_obj : The reddit object you received from the reddit API in the askreddit.py file. """ - print_step("Saving Text to MP3 files...") - length = 0 - - # Create a folder for the mp3 files. - Path("assets/temp/mp3").mkdir(parents=True, exist_ok=True) - TextToSpeech = TTS() - TextToSpeech.tts( - sanitize_text(reddit_obj["thread_title"]), - filename="assets/temp/mp3/title.mp3", - random_speaker=False, - ) - try: - length += MP3("assets/temp/mp3/title.mp3").info.length - except HeaderNotFoundError: # note to self AudioFileClip - length += sox.file_info.duration("assets/temp/mp3/title.mp3") - if getenv("STORYMODE").casefold() == "true": - TextToSpeech.tts( - sanitize_text(reddit_obj["thread_content"]), - filename="assets/temp/mp3/story_content.mp3", - random_speaker=False, + env = os.getenv("TTS_PROVIDER", "") + if env in TTSProviders: + text_to_mp3 = TTSEngine(env, reddit_obj) + else: + chosen = False + choice = "" + while not chosen: + print_step("Please choose one of the following TTS providers: ") + print_table(TTSProviders) + choice = input("\n") + if choice.casefold() not in map(lambda _: _.casefold(), TTSProviders): + print("Unknown Choice") + else: + chosen = True + text_to_mp3 = TTSEngine( + get_case_insensitive_key_value(TTSProviders, choice), reddit_obj ) - # 'story_content' - com = 0 - for comment in track((reddit_obj["comments"]), "Saving..."): - # ! Stop creating mp3 files if the length is greater than VIDEO_LENGTH seconds. This can be longer - # but this is just a good_voices starting point - if length > VIDEO_LENGTH: - break - TextToSpeech.tts( - sanitize_text(comment["comment_body"]), - filename=f"assets/temp/mp3/{com}.mp3", - random_speaker=False, - ) - try: - length += MP3(f"assets/temp/mp3/{com}.mp3").info.length - com += 1 - except (HeaderNotFoundError, MutagenError, Exception): - try: - length += sox.file_info.duration(f"assets/temp/mp3/{com}.mp3") - com += 1 - except (OSError, IOError): - print( - "would have removed" - f"assets/temp/mp3/{com}.mp3" - f"assets/temp/png/comment_{com}.png" - ) - # remove(f"assets/temp/mp3/{com}.mp3") - # remove(f"assets/temp/png/comment_{com}.png")# todo might cause odd un-syncing + return text_to_mp3.run() + - print_substep("Saved Text to MP3 files Successfully.", style="bold green") - # ! Return the index, so we know how many screenshots of comments we need to make. - return length, com +def get_case_insensitive_key_value(input_dict, key): + return next( + ( + value + for dict_key, value in input_dict.items() + if dict_key.lower() == key.lower() + ), + None, + )