fixed formatting

pull/1092/head
Supreme-hub 2 years ago
parent 3da1debf83
commit ca97087ac8

@ -62,7 +62,9 @@ noneng = [
class TikTok: # TikTok Text-to-Speech Wrapper class TikTok: # TikTok Text-to-Speech Wrapper
def __init__(self): def __init__(self):
self.URI_BASE = "https://api16-normal-useast5.us.tiktokv.com/media/api/text/speech/invoke/?text_speaker=" self.URI_BASE = (
"https://api16-normal-useast5.us.tiktokv.com/media/api/text/speech/invoke/?text_speaker="
)
self.max_chars = 300 self.max_chars = 300
self.voices = {"human": human, "nonhuman": nonhuman, "noneng": noneng} self.voices = {"human": human, "nonhuman": nonhuman, "noneng": noneng}
@ -79,9 +81,7 @@ class TikTok: # TikTok Text-to-Speech Wrapper
) )
) )
try: try:
r = requests.post( r = requests.post(f"{self.URI_BASE}{voice}&req_text={text}&speaker_map_type=0")
f"{self.URI_BASE}{voice}&req_text={text}&speaker_map_type=0"
)
except requests.exceptions.SSLError: except requests.exceptions.SSLError:
# https://stackoverflow.com/a/47475019/18516611 # https://stackoverflow.com/a/47475019/18516611
session = requests.Session() session = requests.Session()
@ -89,9 +89,7 @@ class TikTok: # TikTok Text-to-Speech Wrapper
adapter = HTTPAdapter(max_retries=retry) adapter = HTTPAdapter(max_retries=retry)
session.mount("http://", adapter) session.mount("http://", adapter)
session.mount("https://", adapter) session.mount("https://", adapter)
r = session.post( r = session.post(f"{self.URI_BASE}{voice}&req_text={text}&speaker_map_type=0")
f"{self.URI_BASE}{voice}&req_text={text}&speaker_map_type=0"
)
# print(r.text) # print(r.text)
vstr = [r.json()["data"]["v_str"]][0] vstr = [r.json()["data"]["v_str"]][0]
b64d = base64.b64decode(vstr) b64d = base64.b64decode(vstr)

@ -40,9 +40,7 @@ class AWSPolly:
raise ValueError( raise ValueError(
f"Please set the TOML variable AWS_VOICE to a valid voice. options are: {voices}" f"Please set the TOML variable AWS_VOICE to a valid voice. options are: {voices}"
) )
voice = str( voice = str(settings.config["settings"]["tts"]["aws_polly_voice"]).capitalize()
settings.config["settings"]["tts"]["aws_polly_voice"]
).capitalize()
try: try:
# Request speech synthesis # Request speech synthesis
response = polly.synthesize_speech( response = polly.synthesize_speech(

@ -59,7 +59,10 @@ class TTSEngine:
self.call_tts("title", process_text(self.reddit_object["thread_title"])) self.call_tts("title", process_text(self.reddit_object["thread_title"]))
processed_text = process_text(self.reddit_object["thread_post"]) processed_text = process_text(self.reddit_object["thread_post"])
if processed_text != "" and settings.config["settings"]["storymode"] == True: if (
processed_text != ""
and settings.config["settings"]["storymode"] == True
):
self.call_tts("posttext", processed_text) self.call_tts("posttext", processed_text)
idx = None idx = None
@ -92,7 +95,7 @@ class TTSEngine:
offset = 0 offset = 0
for idy, text_cut in enumerate(split_text): for idy, text_cut in enumerate(split_text):
# print(f"{idx}-{idy}: {text_cut}\n") # print(f"{idx}-{idy}: {text_cut}\n")
new_text = process_text(text_cut) new_text = process_text(text_cut)
if not new_text or new_text.isspace(): if not new_text or new_text.isspace():
offset += 1 offset += 1
continue continue
@ -117,7 +120,9 @@ class TTSEngine:
# Path(f"{self.path}/{idx}-{i}.part.mp3").unlink() # Path(f"{self.path}/{idx}-{i}.part.mp3").unlink()
def call_tts(self, filename: str, text: str): def call_tts(self, filename: str, text: str):
self.tts_module.run(text, filepath=f"{self.path}/{filename}.mp3") self.tts_module.run(
text, filepath=f"{self.path}/{filename}.mp3"
)
# try: # try:
# self.length += MP3(f"{self.path}/{filename}.mp3").info.length # self.length += MP3(f"{self.path}/{filename}.mp3").info.length
# except (MutagenError, HeaderNotFoundError): # except (MutagenError, HeaderNotFoundError):

@ -2,37 +2,41 @@ import random
import pyttsx3 import pyttsx3
from utils import settings from utils import settings
class pyttsx:
class pyttsx:
def __init__(self): def __init__(self):
self.max_chars = 5000 self.max_chars = 5000
self.voices = [] self.voices = []
def run( def run(
self, self,
text: str , text: str,
filepath: str, filepath: str,
random_voice=False, random_voice=False,
): ):
voice_id = settings.config["settings"]['tts']["python_voice"] voice_id = settings.config["settings"]["tts"]["python_voice"]
voice_num = settings.config["settings"]['tts']["py_voice_num"] voice_num = settings.config["settings"]["tts"]["py_voice_num"]
if (voice_id == "" or voice_num == ""): if voice_id == "" or voice_num == "":
voice_id = 2 voice_id = 2
voice_num = 3 voice_num = 3
raise ValueError("set pyttsx values to a valid value, switching to defaults") raise ValueError(
"set pyttsx values to a valid value, switching to defaults"
)
else: else:
voice_id = int(voice_id) voice_id = int(voice_id)
voice_num = int(voice_num) voice_num = int(voice_num)
for i in range(voice_num): for i in range(voice_num):
self.voices.append(i) self.voices.append(i)
i=+1 i = +1
if random_voice: if random_voice:
voice_id = self.randomvoice() voice_id = self.randomvoice()
engine = pyttsx3.init() engine = pyttsx3.init()
voices = engine.getProperty('voices') voices = engine.getProperty("voices")
engine.setProperty('voice', voices[voice_id].id) #changing index changes voices but ony 0 and 1 are working here engine.setProperty(
"voice", voices[voice_id].id
) # changing index changes voices but ony 0 and 1 are working here
engine.save_to_file(text, f"{filepath}") engine.save_to_file(text, f"{filepath}")
engine.runAndWait() engine.runAndWait()
def randomvoice(self): def randomvoice(self):
return random.choice(self.voices) return random.choice(self.voices)

@ -40,9 +40,7 @@ class StreamlabsPolly:
raise ValueError( raise ValueError(
f"Please set the config variable STREAMLABS_POLLY_VOICE to a valid voice. options are: {voices}" f"Please set the config variable STREAMLABS_POLLY_VOICE to a valid voice. options are: {voices}"
) )
voice = str( voice = str(settings.config["settings"]["tts"]["streamlabs_polly_voice"]).capitalize()
settings.config["settings"]["tts"]["streamlabs_polly_voice"]
).capitalize()
body = {"voice": voice, "text": text, "service": "polly"} body = {"voice": voice, "text": text, "service": "polly"}
response = requests.post(self.url, data=body) response = requests.post(self.url, data=body)
if not check_ratelimit(response): if not check_ratelimit(response):

@ -75,9 +75,7 @@ if __name__ == "__main__":
run_many(config["settings"]["times_to_run"]) run_many(config["settings"]["times_to_run"])
elif len(config["reddit"]["thread"]["post_id"].split("+")) > 1: elif len(config["reddit"]["thread"]["post_id"].split("+")) > 1:
for index, post_id in enumerate( for index, post_id in enumerate(config["reddit"]["thread"]["post_id"].split("+")):
config["reddit"]["thread"]["post_id"].split("+")
):
index += 1 index += 1
print_step( print_step(
f'on the {index}{("st" if index % 10 == 1 else ("nd" if index % 10 == 2 else ("rd" if index % 10 == 3 else "th")))} post of {len(config["reddit"]["thread"]["post_id"].split("+"))}' f'on the {index}{("st" if index % 10 == 1 else ("nd" if index % 10 == 2 else ("rd" if index % 10 == 3 else "th")))} post of {len(config["reddit"]["thread"]["post_id"].split("+"))}'

@ -19,9 +19,7 @@ def get_subreddit_threads(POST_ID: str):
content = {} content = {}
if settings.config["reddit"]["creds"]["2fa"]: if settings.config["reddit"]["creds"]["2fa"]:
print( print("\nEnter your two-factor authentication code from your authenticator app.\n")
"\nEnter your two-factor authentication code from your authenticator app.\n"
)
code = input("> ") code = input("> ")
print() print()
pw = settings.config["reddit"]["creds"]["password"] pw = settings.config["reddit"]["creds"]["password"]
@ -47,9 +45,7 @@ def get_subreddit_threads(POST_ID: str):
]: # note to user. you can have multiple subreddits via reddit.subreddit("redditdev+learnpython") ]: # note to user. you can have multiple subreddits via reddit.subreddit("redditdev+learnpython")
try: try:
subreddit = reddit.subreddit( subreddit = reddit.subreddit(
re.sub( re.sub(r"r\/", "", input("What subreddit would you like to pull from? "))
r"r\/", "", input("What subreddit would you like to pull from? ")
)
# removes the r/ from the input # removes the r/ from the input
) )
except ValueError: except ValueError:
@ -59,9 +55,7 @@ def get_subreddit_threads(POST_ID: str):
sub = settings.config["reddit"]["thread"]["subreddit"] sub = settings.config["reddit"]["thread"]["subreddit"]
print_substep(f"Using subreddit: r/{sub} from TOML config") print_substep(f"Using subreddit: r/{sub} from TOML config")
subreddit_choice = sub subreddit_choice = sub
if ( if (str(subreddit_choice).casefold().startswith("r/")): # removes the r/ from the input
str(subreddit_choice).casefold().startswith("r/")
): # removes the r/ from the input
subreddit_choice = subreddit_choice[2:] subreddit_choice = subreddit_choice[2:]
subreddit = reddit.subreddit( subreddit = reddit.subreddit(
subreddit_choice subreddit_choice
@ -73,9 +67,7 @@ def get_subreddit_threads(POST_ID: str):
settings.config["reddit"]["thread"]["post_id"] settings.config["reddit"]["thread"]["post_id"]
and len(str(settings.config["reddit"]["thread"]["post_id"]).split("+")) == 1 and len(str(settings.config["reddit"]["thread"]["post_id"]).split("+")) == 1
): ):
submission = reddit.submission( submission = reddit.submission(id=settings.config["reddit"]["thread"]["post_id"])
id=settings.config["reddit"]["thread"]["post_id"]
)
else: else:
threads = subreddit.hot(limit=25) threads = subreddit.hot(limit=25)
submission = get_subreddit_undone(threads, subreddit) submission = get_subreddit_undone(threads, subreddit)

@ -14,9 +14,7 @@ def cleanup() -> int:
""" """
if exists("./assets/temp"): if exists("./assets/temp"):
count = 0 count = 0
files = [ files = [f for f in os.listdir(".") if f.endswith(".mp4") and "temp" in f.lower()]
f for f in os.listdir(".") if f.endswith(".mp4") and "temp" in f.lower()
]
count += len(files) count += len(files)
for f in files: for f in files:
os.remove(f) os.remove(f)

@ -49,10 +49,7 @@ def handle_input(
optional=False, optional=False,
): ):
if optional: if optional:
console.print( console.print(message + "\n[green]This is an optional value. Do you want to skip it? (y/n)")
message
+ "\n[green]This is an optional value. Do you want to skip it? (y/n)"
)
if input().casefold().startswith("y"): if input().casefold().startswith("y"):
return default if default is not NotImplemented else "" return default if default is not NotImplemented else ""
if default is not NotImplemented: if default is not NotImplemented:
@ -86,11 +83,7 @@ def handle_input(
console.print("[red]" + err_message) console.print("[red]" + err_message)
continue continue
elif match != "" and re.match(match, user_input) is None: elif match != "" and re.match(match, user_input) is None:
console.print( console.print("[red]" + err_message + "\nAre you absolutely sure it's correct?(y/n)")
"[red]"
+ err_message
+ "\nAre you absolutely sure it's correct?(y/n)"
)
if input().casefold().startswith("y"): if input().casefold().startswith("y"):
break break
continue continue
@ -123,9 +116,5 @@ def handle_input(
if user_input in options: if user_input in options:
return user_input return user_input
console.print( console.print(
"[red bold]" "[red bold]" + err_message + "\nValid options are: " + ", ".join(map(str, options)) + "."
+ err_message )
+ "\nValid options are: "
+ ", ".join(map(str, options))
+ "."
)

@ -54,11 +54,7 @@ def check(value, checks, name):
and not hasattr(value, "__iter__") and not hasattr(value, "__iter__")
and ( and (
("nmin" in checks and checks["nmin"] is not None and value < checks["nmin"]) ("nmin" in checks and checks["nmin"] is not None and value < checks["nmin"])
or ( or ("nmax" in checks and checks["nmax"] is not None and value > checks["nmax"])
"nmax" in checks
and checks["nmax"] is not None
and value > checks["nmax"]
)
) )
): ):
incorrect = True incorrect = True
@ -66,16 +62,8 @@ def check(value, checks, name):
not incorrect not incorrect
and hasattr(value, "__iter__") and hasattr(value, "__iter__")
and ( and (
( ("nmin" in checks and checks["nmin"] is not None and len(value) < checks["nmin"])
"nmin" in checks or ("nmax" in checks and checks["nmax"] is not None and len(value) > checks["nmax"])
and checks["nmin"] is not None
and len(value) < checks["nmin"]
)
or (
"nmax" in checks
and checks["nmax"] is not None
and len(value) > checks["nmax"]
)
) )
): ):
incorrect = True incorrect = True
@ -83,15 +71,9 @@ def check(value, checks, name):
if incorrect: if incorrect:
value = handle_input( value = handle_input(
message=( message=(
( (("[blue]Example: " + str(checks["example"]) + "\n") if "example" in checks else "")
("[blue]Example: " + str(checks["example"]) + "\n")
if "example" in checks
else ""
)
+ "[red]" + "[red]"
+ ("Non-optional ", "Optional ")[ + ("Non-optional ", "Optional ")["optional" in checks and checks["optional"] is True]
"optional" in checks and checks["optional"] is True
]
) )
+ "[#C0CAF5 bold]" + "[#C0CAF5 bold]"
+ str(name) + str(name)
@ -132,9 +114,7 @@ def check_toml(template_file, config_file) -> Tuple[bool, Dict]:
try: try:
template = toml.load(template_file) template = toml.load(template_file)
except Exception as error: except Exception as error:
console.print( console.print(f"[red bold]Encountered error when trying to to load {template_file}: {error}")
f"[red bold]Encountered error when trying to to load {template_file}: {error}"
)
return False return False
try: try:
config = toml.load(config_file) config = toml.load(config_file)

@ -19,9 +19,7 @@ def get_subreddit_undone(submissions: list, subreddit, times_checked=0):
if not exists("./video_creation/data/videos.json"): if not exists("./video_creation/data/videos.json"):
with open("./video_creation/data/videos.json", "w+") as f: with open("./video_creation/data/videos.json", "w+") as f:
json.dump([], f) json.dump([], f)
with open( with open("./video_creation/data/videos.json", "r", encoding="utf-8") as done_vids_raw:
"./video_creation/data/videos.json", "r", encoding="utf-8"
) as done_vids_raw:
done_videos = json.load(done_vids_raw) done_videos = json.load(done_vids_raw)
for submission in submissions: for submission in submissions:
if already_done(done_videos, submission): if already_done(done_videos, submission):
@ -36,9 +34,7 @@ def get_subreddit_undone(submissions: list, subreddit, times_checked=0):
if submission.stickied: if submission.stickied:
print_substep("This post was pinned by moderators. Skipping...") print_substep("This post was pinned by moderators. Skipping...")
continue continue
if submission.num_comments <= int( if submission.num_comments <= int(settings.config["reddit"]["thread"]["min_comments"]):
settings.config["reddit"]["thread"]["min_comments"]
):
print_substep( print_substep(
f'This post has under the specified minimum of comments ({settings.config["reddit"]["thread"]["min_comments"]}). Skipping...' f'This post has under the specified minimum of comments ({settings.config["reddit"]["thread"]["min_comments"]}). Skipping...'
) )
@ -59,8 +55,7 @@ def get_subreddit_undone(submissions: list, subreddit, times_checked=0):
return get_subreddit_undone( return get_subreddit_undone(
subreddit.top( subreddit.top(
time_filter=VALID_TIME_FILTERS[index], time_filter=VALID_TIME_FILTERS[index],limit=(50 if int(index) == 0 else index + 1 * 50),
limit=(50 if int(index) == 0 else index + 1 * 50),
), ),
subreddit, subreddit,
times_checked=index, times_checked=index,

@ -35,18 +35,10 @@ class Video:
return ImageClip(path) return ImageClip(path)
def add_watermark( def add_watermark(
self, self, text, opacity=0.5, duration: int | float = 5, position: Tuple = (0.7, 0.9), fontsize=15,
text,
opacity=0.5,
duration: int | float = 5,
position: Tuple = (0.7, 0.9),
fontsize=15,
): ):
compensation = round( compensation = round(
( (position[0] / ((len(text) * (fontsize / 5) / 1.5) / 100 + position[0] * position[0])),
position[0]
/ ((len(text) * (fontsize / 5) / 1.5) / 100 + position[0] * position[0])
),
ndigits=2, ndigits=2,
) )
position = (compensation, position[1]) position = (compensation, position[1])

@ -20,9 +20,7 @@ def check_done(
Returns: Returns:
Submission|None: Reddit object in args Submission|None: Reddit object in args
""" """
with open( with open("./video_creation/data/videos.json", "r", encoding="utf-8") as done_vids_raw:
"./video_creation/data/videos.json", "r", encoding="utf-8"
) as done_vids_raw:
done_videos = json.load(done_vids_raw) done_videos = json.load(done_vids_raw)
for video in done_videos: for video in done_videos:
if video["id"] == str(redditobj): if video["id"] == str(redditobj):
@ -36,9 +34,7 @@ def check_done(
return redditobj return redditobj
def save_data( def save_data(subreddit: str, filename: str, reddit_title: str, reddit_id: str, credit: str):
subreddit: str, filename: str, reddit_title: str, reddit_id: str, credit: str
):
"""Saves the videos that have already been generated to a JSON file in video_creation/data/videos.json """Saves the videos that have already been generated to a JSON file in video_creation/data/videos.json
Args: Args:

@ -40,9 +40,7 @@ def sleep_until(time):
if sys.version_info[0] >= 3 and time.tzinfo: if sys.version_info[0] >= 3 and time.tzinfo:
end = time.astimezone(timezone.utc).timestamp() end = time.astimezone(timezone.utc).timestamp()
else: else:
zoneDiff = ( zoneDiff = pytime.time() - (datetime.now() - datetime(1970, 1, 1)).total_seconds()
pytime.time() - (datetime.now() - datetime(1970, 1, 1)).total_seconds()
)
end = (time - datetime(1970, 1, 1)).total_seconds() + zoneDiff end = (time - datetime(1970, 1, 1)).total_seconds() + zoneDiff
# Type check # Type check

@ -31,9 +31,7 @@ def get_start_and_end_times(video_length: int, length_of_clip: int) -> Tuple[int
def get_background_config(): def get_background_config():
"""Fetch the background/s configuration""" """Fetch the background/s configuration"""
try: try:
choice = str( choice = str(settings.config["settings"]["background"]["background_choice"]).casefold()
settings.config["settings"]["background"]["background_choice"]
).casefold()
except AttributeError: except AttributeError:
print_substep("No background selected. Picking random background'") print_substep("No background selected. Picking random background'")
choice = None choice = None
@ -58,15 +56,13 @@ def download_background(background_config: Tuple[str, str, str, Any]):
) )
print_substep("Downloading the backgrounds videos... please be patient 🙏 ") print_substep("Downloading the backgrounds videos... please be patient 🙏 ")
print_substep(f"Downloading {filename} from {uri}") print_substep(f"Downloading {filename} from {uri}")
YouTube(uri, on_progress_callback=on_progress).streams.filter( YouTube(uri, on_progress_callback=on_progress).streams.filter(res="1080p").first().download(
res="1080p" "assets/backgrounds", filename=f"{credit}-{filename}"
).first().download("assets/backgrounds", filename=f"{credit}-{filename}") )
print_substep("Background video downloaded successfully! 🎉", style="bold green") print_substep("Background video downloaded successfully! 🎉", style="bold green")
def chop_background_video( def chop_background_video(background_config: Tuple[str, str, str, Any], video_length: int):
background_config: Tuple[str, str, str, Any], video_length: int
):
"""Generates the background footage to be used in the video and writes it to assets/temp/background.mp4 """Generates the background footage to be used in the video and writes it to assets/temp/background.mp4
Args: Args:

@ -75,9 +75,7 @@ def make_final_video(
) )
# Gather all audio clips # Gather all audio clips
audio_clips = [ audio_clips = [AudioFileClip(f"assets/temp/mp3/{i}.mp3") for i in range(number_of_clips)]
AudioFileClip(f"assets/temp/mp3/{i}.mp3") for i in range(number_of_clips)
]
audio_clips.insert(0, AudioFileClip("assets/temp/mp3/title.mp3")) audio_clips.insert(0, AudioFileClip("assets/temp/mp3/title.mp3"))
audio_concat = concatenate_audioclips(audio_clips) audio_concat = concatenate_audioclips(audio_clips)
audio_composite = CompositeAudioClip([audio_concat]) audio_composite = CompositeAudioClip([audio_concat])
@ -87,9 +85,7 @@ def make_final_video(
image_clips = [] image_clips = []
# Gather all images # Gather all images
new_opacity = 1 if opacity is None or float(opacity) >= 1 else float(opacity) new_opacity = 1 if opacity is None or float(opacity) >= 1 else float(opacity)
new_transition = ( new_transition = (0 if transition is None or float(transition) > 2 else float(transition))
0 if transition is None or float(transition) > 2 else float(transition)
)
image_clips.insert( image_clips.insert(
0, 0,
ImageClip("assets/temp/png/title.png") ImageClip("assets/temp/png/title.png")

@ -1,4 +1,3 @@
from distutils.command.config import config
import json import json
from pathlib import Path from pathlib import Path
@ -36,13 +35,9 @@ def download_screenshots_of_reddit_posts(reddit_object: dict, screenshot_num: in
context = browser.new_context() context = browser.new_context()
if settings.config["settings"]["theme"] == "dark": if settings.config["settings"]["theme"] == "dark":
cookie_file = open( cookie_file = open("./video_creation/data/cookie-dark-mode.json", encoding="utf-8")
"./video_creation/data/cookie-dark-mode.json", encoding="utf-8"
)
else: else:
cookie_file = open( cookie_file = open("./video_creation/data/cookie-light-mode.json", encoding="utf-8")
"./video_creation/data/cookie-light-mode.json", encoding="utf-8"
)
cookies = json.load(cookie_file) cookies = json.load(cookie_file)
context.add_cookies(cookies) # load preference cookies context.add_cookies(cookies) # load preference cookies
# Get the thread screenshot # Get the thread screenshot
@ -54,7 +49,7 @@ def download_screenshots_of_reddit_posts(reddit_object: dict, screenshot_num: in
print_substep("Post is NSFW. You are spicy...") print_substep("Post is NSFW. You are spicy...")
page.locator('[data-testid="content-gate"] button').click() page.locator('[data-testid="content-gate"] button').click()
page.wait_for_load_state() # Wait for page to fully load page.wait_for_load_state() # Wait for page to fully load
if page.locator('[data-click-id="text"] button').is_visible(): if page.locator('[data-click-id="text"] button').is_visible():
page.locator( page.locator(
@ -77,9 +72,7 @@ def download_screenshots_of_reddit_posts(reddit_object: dict, screenshot_num: in
else: else:
print_substep("Skipping translation...") print_substep("Skipping translation...")
page.locator('[data-test-id="post-content"]').screenshot( page.locator('[data-test-id="post-content"]').screenshot(path="assets/temp/png/title.png")
path="assets/temp/png/title.png"
)
if storymode: if storymode:
page.locator('[data-click-id="text"]').screenshot( page.locator('[data-click-id="text"]').screenshot(

@ -37,9 +37,7 @@ def save_text_to_mp3(reddit_obj) -> Tuple[int, int]:
voice = settings.config["settings"]["tts"]["voice_choice"] voice = settings.config["settings"]["tts"]["voice_choice"]
if str(voice).casefold() in map(lambda _: _.casefold(), TTSProviders): if str(voice).casefold() in map(lambda _: _.casefold(), TTSProviders):
text_to_mp3 = TTSEngine( text_to_mp3 = TTSEngine(get_case_insensitive_key_value(TTSProviders, voice), reddit_obj)
get_case_insensitive_key_value(TTSProviders, voice), reddit_obj
)
else: else:
while True: while True:
print_step("Please choose one of the following TTS providers: ") print_step("Please choose one of the following TTS providers: ")
@ -48,18 +46,12 @@ def save_text_to_mp3(reddit_obj) -> Tuple[int, int]:
if choice.casefold() in map(lambda _: _.casefold(), TTSProviders): if choice.casefold() in map(lambda _: _.casefold(), TTSProviders):
break break
print("Unknown Choice") print("Unknown Choice")
text_to_mp3 = TTSEngine( text_to_mp3 = TTSEngine(get_case_insensitive_key_value(TTSProviders, choice), reddit_obj)
get_case_insensitive_key_value(TTSProviders, choice), reddit_obj
)
return text_to_mp3.run() return text_to_mp3.run()
def get_case_insensitive_key_value(input_dict, key): def get_case_insensitive_key_value(input_dict, key):
return next( return next(
( (value for dict_key, value in input_dict.items() if dict_key.lower() == key.lower()),
value
for dict_key, value in input_dict.items()
if dict_key.lower() == key.lower()
),
None, None,
) )

Loading…
Cancel
Save