Reduced code duplication in TTS engines

pull/653/head
Callum Leslie 3 years ago
parent fc14049dba
commit c58fa10f53

@ -149,7 +149,7 @@ disable=raw-checker-failed,
suppressed-message, suppressed-message,
useless-suppression, useless-suppression,
deprecated-pragma, deprecated-pragma,
use-symbolic-message-instead use-symbolic-message-instead,
attribute-defined-outside-init, attribute-defined-outside-init,
invalid-name, invalid-name,
missing-docstring, missing-docstring,

@ -1,13 +1,18 @@
#!/usr/bin/env python3
import random
from gtts import gTTS from gtts import gTTS
max_chars = 0
class GTTS: class GTTS:
def tts( def __init__(self):
self, self.max_chars = 0
req_text: str = "Google Text To Speech", self.voices = []
filename: str = "title.mp3",
random_speaker=False, def run(self, text, filepath):
censor=False, tts = gTTS(text=text, lang="en", slow=False)
): tts.save(filepath)
tts = gTTS(text=req_text, lang="en", slow=False)
tts.save(f"{filename}") def randomvoice(self):
return random.choice(self.voices)

@ -1,106 +0,0 @@
import os
import random
import re
import requests
import sox
from moviepy.audio.AudioClip import concatenate_audioclips, CompositeAudioClip
from moviepy.audio.io.AudioFileClip import AudioFileClip
from requests.exceptions import JSONDecodeError
voices = [
"Brian",
"Emma",
"Russell",
"Joey",
"Matthew",
"Joanna",
"Kimberly",
"Amy",
"Geraint",
"Nicole",
"Justin",
"Ivy",
"Kendra",
"Salli",
"Raveena",
]
# valid voices https://lazypy.ro/tts/
class POLLY:
def __init__(self):
self.url = "https://streamlabs.com/polly/speak"
def tts(
self,
req_text: str = "Amazon Text To Speech",
filename: str = "title.mp3",
random_speaker=False,
censor=False,
):
if random_speaker:
voice = self.randomvoice()
else:
if not os.getenv("VOICE"):
return ValueError(
"Please set the environment variable VOICE to a valid voice. options are: {}".format(
voices
)
)
voice = str(os.getenv("VOICE")).capitalize()
body = {"voice": voice, "text": req_text, "service": "polly"}
response = requests.post(self.url, data=body)
try:
voice_data = requests.get(response.json()["speak_url"])
with open(filename, "wb") as f:
f.write(voice_data.content)
except (KeyError, JSONDecodeError):
if response.json()["error"] == "Text length is too long!":
chunks = [m.group().strip() for m in re.finditer(r" *((.{0,499})(\.|.$))", req_text)]
audio_clips = []
cbn = sox.Combiner()
chunkId = 0
for chunk in chunks:
body = {"voice": voice, "text": chunk, "service": "polly"}
resp = requests.post(self.url, data=body)
voice_data = requests.get(resp.json()["speak_url"])
with open(filename.replace(".mp3", f"-{chunkId}.mp3"), "wb") as out:
out.write(voice_data.content)
audio_clips.append(filename.replace(".mp3", f"-{chunkId}.mp3"))
chunkId = chunkId + 1
try:
if len(audio_clips) > 1:
cbn.convert(samplerate=44100, n_channels=2)
cbn.build(audio_clips, filename, "concatenate")
else:
os.rename(audio_clips[0], filename)
except (
sox.core.SoxError,
FileNotFoundError,
): # https://github.com/JasonLovesDoggo/RedditVideoMakerBot/issues/67#issuecomment-1150466339
for clip in audio_clips:
i = audio_clips.index(clip) # get the index of the clip
audio_clips = (
audio_clips[:i] + [AudioFileClip(clip)] + audio_clips[i + 1 :]
) # replace the clip with an AudioFileClip
audio_concat = concatenate_audioclips(audio_clips)
audio_composite = CompositeAudioClip([audio_concat])
audio_composite.write_audiofile(filename, 44100, 2, 2000, None)
def make_readable(self, text):
"""
Amazon Polly fails to read some symbols properly such as '& (and)'.
So we normalize input text before passing it to the service
"""
text = text.replace("&", "and")
return text
def randomvoice(self):
return random.choice(voices)

@ -1,12 +1,7 @@
import base64 import base64
import os import os
import random import random
import re
import requests import requests
import sox
from moviepy.audio.AudioClip import concatenate_audioclips, CompositeAudioClip
from moviepy.audio.io.AudioFileClip import AudioFileClip
from requests.adapters import HTTPAdapter, Retry from requests.adapters import HTTPAdapter, Retry
# from profanity_filter import ProfanityFilter # from profanity_filter import ProfanityFilter
@ -67,75 +62,39 @@ noneng = [
class TikTok: # TikTok Text-to-Speech Wrapper class TikTok: # TikTok Text-to-Speech Wrapper
def __init__(self): def __init__(self):
self.URI_BASE = ( self.URI_BASE = "https://api16-normal-useast5.us.tiktokv.com/media/api/text/speech/invoke/?text_speaker="
"https://api16-normal-useast5.us.tiktokv.com/media/api/text/speech/invoke/?text_speaker=" self.max_chars = 330
) self.voices = {"human": human, "nonhuman": nonhuman, "noneng": noneng}
def tts( def run(self, text, filepath, random_voice: bool = False):
self, # if censor:
req_text: str = "TikTok Text To Speech", # req_text = pf.censor(req_text)
filename: str = "title.mp3", # pass
random_speaker: bool = False,
censor=False,
):
req_text = req_text.replace("+", "plus").replace(" ", "+").replace("&", "and")
if censor:
# req_text = pf.censor(req_text)
pass
voice = ( voice = (
self.randomvoice() if random_speaker else (os.getenv("VOICE") or random.choice(human)) self.randomvoice()
if random_voice
else (os.getenv("VOICE") or random.choice(self.voices["human"]))
) )
chunks = [m.group().strip() for m in re.finditer(r" *((.{0,299})(\.|.$))", req_text)]
audio_clips = []
cbn = sox.Combiner()
# cbn.set_input_format(file_type=["mp3" for _ in chunks])
chunkId = 0
for chunk in chunks:
try:
r = requests.post(f"{self.URI_BASE}{voice}&req_text={chunk}&speaker_map_type=0")
except requests.exceptions.SSLError:
# https://stackoverflow.com/a/47475019/18516611
session = requests.Session()
retry = Retry(connect=3, backoff_factor=0.5)
adapter = HTTPAdapter(max_retries=retry)
session.mount("http://", adapter)
session.mount("https://", adapter)
r = session.post(f"{self.URI_BASE}{voice}&req_text={chunk}&speaker_map_type=0")
print(r.text)
vstr = [r.json()["data"]["v_str"]][0]
b64d = base64.b64decode(vstr)
with open(filename.replace(".mp3", f"-{chunkId}.mp3"), "wb") as out:
out.write(b64d)
audio_clips.append(filename.replace(".mp3", f"-{chunkId}.mp3"))
chunkId = chunkId + 1
try: try:
if len(audio_clips) > 1: r = requests.post(
cbn.convert(samplerate=44100, n_channels=2) f"{self.URI_BASE}{voice}&req_text={text}&speaker_map_type=0"
cbn.build(audio_clips, filename, "concatenate") )
else: except requests.exceptions.SSLError:
os.rename(audio_clips[0], filename) # https://stackoverflow.com/a/47475019/18516611
except ( session = requests.Session()
sox.core.SoxError, retry = Retry(connect=3, backoff_factor=0.5)
FileNotFoundError, adapter = HTTPAdapter(max_retries=retry)
): # https://github.com/JasonLovesDoggo/RedditVideoMakerBot/issues/67#issuecomment-1150466339 session.mount("http://", adapter)
for clip in audio_clips: session.mount("https://", adapter)
i = audio_clips.index(clip) # get the index of the clip r = session.post(
audio_clips = ( f"{self.URI_BASE}{voice}&req_text={text}&speaker_map_type=0"
audio_clips[:i] + [AudioFileClip(clip)] + audio_clips[i + 1 :] )
) # replace the clip with an AudioFileClip print(r.text)
audio_concat = concatenate_audioclips(audio_clips) vstr = [r.json()["data"]["v_str"]][0]
audio_composite = CompositeAudioClip([audio_concat]) b64d = base64.b64decode(vstr)
audio_composite.write_audiofile(filename, 44100, 2, 2000, None)
with open(filepath, "wb") as out:
@staticmethod out.write(b64d)
def randomvoice():
ok_or_good = random.randrange(1, 10) def randomvoice(self):
if ok_or_good == 1: # 1/10 chance of ok voice return random.choice(self.voices["human"])
return random.choice(voices)
return random.choice(human) # 9/10 chance of good voice

@ -0,0 +1,66 @@
#!/usr/bin/env python3
from boto3 import Session
from botocore.exceptions import BotoCoreError, ClientError
import sys
import os
import random
voices = [
"Brian",
"Emma",
"Russell",
"Joey",
"Matthew",
"Joanna",
"Kimberly",
"Amy",
"Geraint",
"Nicole",
"Justin",
"Ivy",
"Kendra",
"Salli",
"Raveena",
]
class AWSPolly:
def __init__(self):
self.max_chars = 0
self.voices = voices
def run(self, text, filepath, random_voice: bool = False):
session = Session(profile_name="polly")
polly = session.client("polly")
if random_voice:
voice = self.randomvoice()
else:
if not os.getenv("VOICE"):
return ValueError(
f"Please set the environment variable VOICE to a valid voice. options are: {voices}"
)
voice = str(os.getenv("VOICE")).capitalize()
try:
# Request speech synthesis
response = polly.synthesize_speech(
Text=text, OutputFormat="mp3", VoiceId=voice, Engine="neural"
)
except (BotoCoreError, ClientError) as error:
# The service returned an error, exit gracefully
print(error)
sys.exit(-1)
# Access the audio stream from the response
if "AudioStream" in response:
file = open(filepath, "wb")
file.write(response["AudioStream"].read())
file.close()
# print_substep(f"Saved Text {idx} to MP3 files successfully.", style="bold green")
else:
# The response didn't contain audio data, exit gracefully
print("Could not stream audio")
sys.exit(-1)
def randomvoice(self):
return random.choice(self.voices)

@ -0,0 +1,99 @@
#!/usr/bin/env python3
from pathlib import Path
from typing import Tuple
import re
from os import getenv
from mutagen.mp3 import MP3
from rich.progress import track
from moviepy.editor import AudioFileClip, CompositeAudioClip, concatenate_audioclips
from utils.console import print_step, print_substep
from utils.voice import sanitize_text
class TTSEngine:
"""Calls the given TTS engine to reduce code duplication and allow multiple TTS engines.
Args:
tts_module : The TTS module. Your module should handle the TTS itself and saving to the given path under the run method.
reddit_object : The reddit object that contains the posts to read.
path (Optional) : The unix style path to save the mp3 files to. This must not have leading or trailing slashes.
max_length (Optional) : The maximum length of the mp3 files in total.
Notes:
tts_module must take the arguments text and filepath.
"""
def __init__(
self,
tts_module,
reddit_object: dict,
path: str = "assets/mp3",
max_length: int = 50,
):
self.tts_module = tts_module()
self.reddit_object = reddit_object
self.path = path
self.max_length = max_length
self.length = 0
def run(self) -> Tuple[int, int]:
Path(self.path).mkdir(parents=True, exist_ok=True)
# This file needs to be removed in case this post does not use post text, so that it wont appear in the final video
try:
Path(f"{self.path}/posttext.mp3").unlink()
except OSError:
pass
print_step("Saving Text to MP3 files...")
self.call_tts("title", self.reddit_object["thread_title"])
if (
self.reddit_object["thread_post"] != ""
and getenv("STORYMODE", "").casefold() == "true"
):
self.call_tts("posttext", sanitize_text(self.reddit_object["thread_post"]))
idx = None
for idx, comment in track(
enumerate(self.reddit_object["comments"]), "Saving..."
):
# ! Stop creating mp3 files if the length is greater than max length.
if self.length > self.max_length:
break
if not self.tts_module.max_chars:
self.call_tts(f"{idx}", sanitize_text(comment["comment_body"]))
else:
self.split_post(sanitize_text(comment["comment_body"]), idx)
print_substep("Saved Text to MP3 files successfully.", style="bold green")
return self.length, idx
def split_post(self, text: str, idx: int) -> str:
split_files = []
split_text = [
x.group().strip()
for x in re.finditer(
rf" *((.{{0,{self.tts_module.max_chars}}})(\.|.$))", text
)
]
idy = None
for idy, text_cut in enumerate(split_text):
print(f"{idx}-{idy}: {text_cut}\n")
self.call_tts(f"{idx}-{idy}.part", text_cut)
split_files.append(AudioFileClip(f"{self.path}/{idx}-{idy}.part.mp3"))
CompositeAudioClip([concatenate_audioclips(split_files)]).write_audiofile(
f"{self.path}/{idx}.mp3", fps=44100, verbose=False, logger=None
)
for i in range(0, idy + 1):
print(f"Cleaning up {self.path}/{idx}-{i}.part.mp3")
Path(f"{self.path}/{idx}-{i}.part.mp3").unlink()
def call_tts(self, filename: str, text: str):
self.tts_module.run(text=text, filepath=f"{self.path}/{filename}.mp3")
self.length += MP3(f"{self.path}/{filename}.mp3").info.length

@ -0,0 +1,53 @@
import random
import os
import requests
from requests.exceptions import JSONDecodeError
voices = [
"Brian",
"Emma",
"Russell",
"Joey",
"Matthew",
"Joanna",
"Kimberly",
"Amy",
"Geraint",
"Nicole",
"Justin",
"Ivy",
"Kendra",
"Salli",
"Raveena",
]
# valid voices https://lazypy.ro/tts/
class StreamlabsPolly:
def __init__(self):
self.url = "https://streamlabs.com/polly/speak"
self.max_chars = 550
self.voices = voices
def run(self, text, filepath, random_voice: bool = False):
if random_voice:
voice = self.randomvoice()
else:
if not os.getenv("VOICE"):
return ValueError(
f"Please set the environment variable VOICE to a valid voice. options are: {voices}"
)
voice = str(os.getenv("VOICE")).capitalize()
body = {"voice": voice, "text": text, "service": "polly"}
response = requests.post(self.url, data=body)
try:
voice_data = requests.get(response.json()["speak_url"])
with open(filepath, "wb") as f:
f.write(voice_data.content)
except (KeyError, JSONDecodeError):
print("Error occured calling Streamlabs Polly")
def randomvoice(self):
return random.choice(self.voices)

@ -1,24 +0,0 @@
from os import getenv
from dotenv import load_dotenv
from TTS.GTTS import GTTS
from TTS.POLLY import POLLY
from TTS.TikTok import TikTok
from utils.console import print_substep
CHOICE_DIR = {"tiktok": TikTok, "gtts": GTTS, "polly": POLLY}
class TTS:
def __new__(cls):
load_dotenv()
try:
CHOICE = getenv("TTsChoice").casefold()
except AttributeError:
print_substep("None defined. Defaulting to 'polly.'")
CHOICE = "polly"
valid_keys = [key.lower() for key in CHOICE_DIR.keys()]
if CHOICE not in valid_keys:
raise ValueError(f"{CHOICE} is not valid. Please use one of these {valid_keys} options")
return CHOICE_DIR.get(CHOICE)()

@ -4,6 +4,7 @@ from rich.markdown import Markdown
from rich.padding import Padding from rich.padding import Padding
from rich.panel import Panel from rich.panel import Panel
from rich.text import Text from rich.text import Text
from rich.columns import Columns
console = Console() console = Console()
@ -25,3 +26,9 @@ def print_step(text):
def print_substep(text, style=""): def print_substep(text, style=""):
"""Prints a rich info message without the panelling.""" """Prints a rich info message without the panelling."""
console.print(text, style=style) console.print(text, style=style)
def print_table(items):
"""Prints items in a table."""
console.print(Columns([Panel(f"[yellow]{item}", expand=True) for item in items]))

@ -17,6 +17,6 @@ def sanitize_text(text):
# note: not removing apostrophes # note: not removing apostrophes
regex_expr = r"\s['|]|['|]\s|[\^_~@!&;#:\-%“”‘\"%\*/{}\[\]\(\)\\|<>=+]" regex_expr = r"\s['|]|['|]\s|[\^_~@!&;#:\-%“”‘\"%\*/{}\[\]\(\)\\|<>=+]"
result = re.sub(regex_expr, " ", result) result = re.sub(regex_expr, " ", result)
result = result.replace("+", "plus").replace(" ", "+").replace("&", "and")
# remove extra whitespace # remove extra whitespace
return " ".join(result.split()) return " ".join(result.split())

@ -1,20 +1,25 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
from os import getenv import os
from pathlib import Path
import sox
from mutagen import MutagenError
from mutagen.mp3 import MP3, HeaderNotFoundError
from rich.console import Console from rich.console import Console
from rich.progress import track
from TTS.swapper import TTS from TTS.engine_wrapper import TTSEngine
from TTS.GTTS import GTTS
from TTS.streamlabs_polly import StreamlabsPolly
from TTS.aws_polly import AWSPolly
from TTS.TikTok import TikTok
from utils.console import print_table, print_step
from utils.console import print_step, print_substep
from utils.voice import sanitize_text
console = Console() console = Console()
TTSProviders = {
"GoogleTranslate": GTTS,
"AWSPolly": AWSPolly,
"StreamlabsPolly": StreamlabsPolly,
"TikTok": TikTok,
}
VIDEO_LENGTH: int = 40 # secs VIDEO_LENGTH: int = 40 # secs
@ -22,58 +27,35 @@ VIDEO_LENGTH: int = 40 # secs
def save_text_to_mp3(reddit_obj): def save_text_to_mp3(reddit_obj):
"""Saves Text to MP3 files. """Saves Text to MP3 files.
Args: Args:
reddit_obj : The reddit object you received from the reddit API in the askreddit.py file. reddit_obj : The reddit object you received from the reddit API in the askreddit.py file.
""" """
print_step("Saving Text to MP3 files...") env = os.getenv("TTS_PROVIDER", "")
length = 0 if env in TTSProviders:
text_to_mp3 = TTSEngine(env, reddit_obj)
# Create a folder for the mp3 files. else:
Path("assets/temp/mp3").mkdir(parents=True, exist_ok=True) chosen = False
TextToSpeech = TTS() choice = ""
TextToSpeech.tts( while not chosen:
sanitize_text(reddit_obj["thread_title"]), print_step("Please choose one of the following TTS providers: ")
filename="assets/temp/mp3/title.mp3", print_table(TTSProviders)
random_speaker=False, choice = input("\n")
) if choice.casefold() not in map(lambda _: _.casefold(), TTSProviders):
try: print("Unknown Choice")
length += MP3("assets/temp/mp3/title.mp3").info.length else:
except HeaderNotFoundError: # note to self AudioFileClip chosen = True
length += sox.file_info.duration("assets/temp/mp3/title.mp3") text_to_mp3 = TTSEngine(
if getenv("STORYMODE").casefold() == "true": get_case_insensitive_key_value(TTSProviders, choice), reddit_obj
TextToSpeech.tts(
sanitize_text(reddit_obj["thread_content"]),
filename="assets/temp/mp3/story_content.mp3",
random_speaker=False,
) )
# 'story_content'
com = 0
for comment in track((reddit_obj["comments"]), "Saving..."):
# ! Stop creating mp3 files if the length is greater than VIDEO_LENGTH seconds. This can be longer
# but this is just a good_voices starting point
if length > VIDEO_LENGTH:
break
TextToSpeech.tts( return text_to_mp3.run()
sanitize_text(comment["comment_body"]),
filename=f"assets/temp/mp3/{com}.mp3",
random_speaker=False,
)
try:
length += MP3(f"assets/temp/mp3/{com}.mp3").info.length
com += 1
except (HeaderNotFoundError, MutagenError, Exception):
try:
length += sox.file_info.duration(f"assets/temp/mp3/{com}.mp3")
com += 1
except (OSError, IOError):
print(
"would have removed"
f"assets/temp/mp3/{com}.mp3"
f"assets/temp/png/comment_{com}.png"
)
# remove(f"assets/temp/mp3/{com}.mp3")
# remove(f"assets/temp/png/comment_{com}.png")# todo might cause odd un-syncing
print_substep("Saved Text to MP3 files Successfully.", style="bold green") def get_case_insensitive_key_value(input_dict, key):
# ! Return the index, so we know how many screenshots of comments we need to make. return next(
return length, com (
value
for dict_key, value in input_dict.items()
if dict_key.lower() == key.lower()
),
None,
)

Loading…
Cancel
Save