style: format with python-black

pull/861/head
Callum Leslie 3 years ago
parent ac28e72017
commit 6a16c9f3e7
No known key found for this signature in database
GPG Key ID: D382C4AFEECEAA90

@ -12,7 +12,11 @@ class GTTS:
self.voices = [] self.voices = []
def run(self, text, filepath): def run(self, text, filepath):
tts = gTTS(text=text, lang=settings.config["reddit"]["thread"]["post_lang"] or "en", slow=False) tts = gTTS(
text=text,
lang=settings.config["reddit"]["thread"]["post_lang"] or "en",
slow=False,
)
tts.save(filepath) tts.save(filepath)
def randomvoice(self): def randomvoice(self):

@ -62,9 +62,7 @@ noneng = [
class TikTok: # TikTok Text-to-Speech Wrapper class TikTok: # TikTok Text-to-Speech Wrapper
def __init__(self): def __init__(self):
self.URI_BASE = ( self.URI_BASE = "https://api16-normal-useast5.us.tiktokv.com/media/api/text/speech/invoke/?text_speaker="
"https://api16-normal-useast5.us.tiktokv.com/media/api/text/speech/invoke/?text_speaker="
)
self.max_chars = 300 self.max_chars = 300
self.voices = {"human": human, "nonhuman": nonhuman, "noneng": noneng} self.voices = {"human": human, "nonhuman": nonhuman, "noneng": noneng}
@ -75,10 +73,15 @@ class TikTok: # TikTok Text-to-Speech Wrapper
voice = ( voice = (
self.randomvoice() self.randomvoice()
if random_voice if random_voice
else (settings.config["settings"]["tts"]["tiktok_voice"] or random.choice(self.voices["human"])) else (
settings.config["settings"]["tts"]["tiktok_voice"]
or random.choice(self.voices["human"])
)
) )
try: try:
r = requests.post(f"{self.URI_BASE}{voice}&req_text={text}&speaker_map_type=0") r = requests.post(
f"{self.URI_BASE}{voice}&req_text={text}&speaker_map_type=0"
)
except requests.exceptions.SSLError: except requests.exceptions.SSLError:
# https://stackoverflow.com/a/47475019/18516611 # https://stackoverflow.com/a/47475019/18516611
session = requests.Session() session = requests.Session()
@ -86,7 +89,9 @@ class TikTok: # TikTok Text-to-Speech Wrapper
adapter = HTTPAdapter(max_retries=retry) adapter = HTTPAdapter(max_retries=retry)
session.mount("http://", adapter) session.mount("http://", adapter)
session.mount("https://", adapter) session.mount("https://", adapter)
r = session.post(f"{self.URI_BASE}{voice}&req_text={text}&speaker_map_type=0") r = session.post(
f"{self.URI_BASE}{voice}&req_text={text}&speaker_map_type=0"
)
# print(r.text) # print(r.text)
vstr = [r.json()["data"]["v_str"]][0] vstr = [r.json()["data"]["v_str"]][0]
b64d = base64.b64decode(vstr) b64d = base64.b64decode(vstr)

@ -39,7 +39,9 @@ class AWSPolly:
return ValueError( return ValueError(
f"Please set the environment variable AWS_VOICE to a valid voice. options are: {voices}" f"Please set the environment variable AWS_VOICE to a valid voice. options are: {voices}"
) )
voice = str(settings.config["settings"]["tts"]["aws_polly_voice"]).capitalize() voice = str(
settings.config["settings"]["tts"]["aws_polly_voice"]
).capitalize()
try: try:
# Request speech synthesis # Request speech synthesis
response = polly.synthesize_speech( response = polly.synthesize_speech(

@ -56,11 +56,16 @@ class TTSEngine:
print_step("Saving Text to MP3 files...") print_step("Saving Text to MP3 files...")
self.call_tts("title", self.reddit_object["thread_title"]) self.call_tts("title", self.reddit_object["thread_title"])
if self.reddit_object["thread_post"] != "" and settings.config["settings"]["storymode"] == True: if (
self.reddit_object["thread_post"] != ""
and settings.config["settings"]["storymode"] == True
):
self.call_tts("posttext", self.reddit_object["thread_post"]) self.call_tts("posttext", self.reddit_object["thread_post"])
idx = None idx = None
for idx, comment in track(enumerate(self.reddit_object["comments"]), "Saving..."): for idx, comment in track(
enumerate(self.reddit_object["comments"]), "Saving..."
):
# ! Stop creating mp3 files if the length is greater than max length. # ! Stop creating mp3 files if the length is greater than max length.
if self.length > self.max_length: if self.length > self.max_length:
break break
@ -76,7 +81,9 @@ class TTSEngine:
split_files = [] split_files = []
split_text = [ split_text = [
x.group().strip() x.group().strip()
for x in re.finditer(rf" *((.{{0,{self.tts_module.max_chars}}})(\.|.$))", text) for x in re.finditer(
rf" *((.{{0,{self.tts_module.max_chars}}})(\.|.$))", text
)
] ]
idy = None idy = None
@ -94,12 +101,14 @@ class TTSEngine:
Path(name).unlink() Path(name).unlink()
# for i in range(0, idy + 1): # for i in range(0, idy + 1):
# print(f"Cleaning up {self.path}/{idx}-{i}.part.mp3") # print(f"Cleaning up {self.path}/{idx}-{i}.part.mp3")
# Path(f"{self.path}/{idx}-{i}.part.mp3").unlink() # Path(f"{self.path}/{idx}-{i}.part.mp3").unlink()
def call_tts(self, filename: str, text: str): def call_tts(self, filename: str, text: str):
self.tts_module.run(text=process_text(text), filepath=f"{self.path}/{filename}.mp3") self.tts_module.run(
text=process_text(text), filepath=f"{self.path}/{filename}.mp3"
)
# try: # try:
# self.length += MP3(f"{self.path}/{filename}.mp3").info.length # self.length += MP3(f"{self.path}/{filename}.mp3").info.length
# except (MutagenError, HeaderNotFoundError): # except (MutagenError, HeaderNotFoundError):
@ -108,6 +117,7 @@ class TTSEngine:
self.length += clip.duration self.length += clip.duration
clip.close() clip.close()
def process_text(text: str): def process_text(text: str):
lang = settings.config["reddit"]["thread"]["post_lang"] lang = settings.config["reddit"]["thread"]["post_lang"]
new_text = sanitize_text(text) new_text = sanitize_text(text)

@ -39,7 +39,9 @@ class StreamlabsPolly:
return ValueError( return ValueError(
f"Please set the environment variable STREAMLABS_VOICE to a valid voice. options are: {voices}" f"Please set the environment variable STREAMLABS_VOICE to a valid voice. options are: {voices}"
) )
voice = str(settings.config["settings"]["tts"]["streamlabs_polly_voice"]).capitalize() voice = str(
settings.config["settings"]["tts"]["streamlabs_polly_voice"]
).capitalize()
body = {"voice": voice, "text": text, "service": "polly"} body = {"voice": voice, "text": text, "service": "polly"}
response = requests.post(self.url, data=body) response = requests.post(self.url, data=body)
try: try:
@ -55,4 +57,3 @@ class StreamlabsPolly:
def randomvoice(self): def randomvoice(self):
return random.choice(self.voices) return random.choice(self.voices)

@ -59,7 +59,9 @@ if __name__ == "__main__":
run_many(config["settings"]["times_to_run"]) run_many(config["settings"]["times_to_run"])
elif len(config["reddit"]["thread"]["post_id"].split("+")) > 1: elif len(config["reddit"]["thread"]["post_id"].split("+")) > 1:
for index, post_id in enumerate(config["reddit"]["thread"]["post_id"].split("+")): for index, post_id in enumerate(
config["reddit"]["thread"]["post_id"].split("+")
):
index += 1 index += 1
print_step( print_step(
f'on the {index}{("st" if index%10 == 1 else ("nd" if index%10 == 2 else ("rd" if index%10 == 3 else "th")))} post of {len(config["reddit"]["thread"]["post_id"].split("+"))}' f'on the {index}{("st" if index%10 == 1 else ("nd" if index%10 == 2 else ("rd" if index%10 == 3 else "th")))} post of {len(config["reddit"]["thread"]["post_id"].split("+"))}'

@ -18,7 +18,9 @@ def get_subreddit_threads(POST_ID: str):
content = {} content = {}
if settings.config["reddit"]["creds"]["2fa"] == True: if settings.config["reddit"]["creds"]["2fa"] == True:
print("\nEnter your two-factor authentication code from your authenticator app.\n") print(
"\nEnter your two-factor authentication code from your authenticator app.\n"
)
code = input("> ") code = input("> ")
print() print()
pw = settings.config["reddit"]["creds"]["password"] pw = settings.config["reddit"]["creds"]["password"]
@ -29,7 +31,7 @@ def get_subreddit_threads(POST_ID: str):
if username.casefold().startswith("u/"): if username.casefold().startswith("u/"):
username = username[2:] username = username[2:]
reddit = praw.Reddit( reddit = praw.Reddit(
client_id= settings.config["reddit"]["creds"]["client_id"], client_id=settings.config["reddit"]["creds"]["client_id"],
client_secret=settings.config["reddit"]["creds"]["client_secret"], client_secret=settings.config["reddit"]["creds"]["client_secret"],
user_agent="Accessing Reddit threads", user_agent="Accessing Reddit threads",
username=username, username=username,
@ -39,10 +41,14 @@ def get_subreddit_threads(POST_ID: str):
# Ask user for subreddit input # Ask user for subreddit input
print_step("Getting subreddit threads...") print_step("Getting subreddit threads...")
if not settings.config["reddit"]["thread"]["subreddit"]: # note to user. you can have multiple subreddits via reddit.subreddit("redditdev+learnpython") if not settings.config["reddit"]["thread"][
"subreddit"
]: # note to user. you can have multiple subreddits via reddit.subreddit("redditdev+learnpython")
try: try:
subreddit = reddit.subreddit( subreddit = reddit.subreddit(
re.sub(r"r\/", "", input("What subreddit would you like to pull from? ")) re.sub(
r"r\/", "", input("What subreddit would you like to pull from? ")
)
# removes the r/ from the input # removes the r/ from the input
) )
except ValueError: except ValueError:
@ -52,7 +58,9 @@ def get_subreddit_threads(POST_ID: str):
sub = settings.config["reddit"]["thread"]["subreddit"] sub = settings.config["reddit"]["thread"]["subreddit"]
print_substep(f"Using subreddit: r/{sub} from TOML config") print_substep(f"Using subreddit: r/{sub} from TOML config")
subreddit_choice = sub subreddit_choice = sub
if subreddit_choice.casefold().startswith("r/"): # removes the r/ from the input if subreddit_choice.casefold().startswith(
"r/"
): # removes the r/ from the input
subreddit_choice = subreddit_choice[2:] subreddit_choice = subreddit_choice[2:]
subreddit = reddit.subreddit( subreddit = reddit.subreddit(
subreddit_choice subreddit_choice
@ -60,8 +68,13 @@ def get_subreddit_threads(POST_ID: str):
if POST_ID: # would only be called if there are multiple queued posts if POST_ID: # would only be called if there are multiple queued posts
submission = reddit.submission(id=POST_ID) submission = reddit.submission(id=POST_ID)
elif settings.config["reddit"]["thread"]["post_id"] and len(settings.config["reddit"]["thread"]["post_id"].split("+")) == 1: elif (
submission = reddit.submission(id=settings.config["reddit"]["thread"]["post_id"]) settings.config["reddit"]["thread"]["post_id"]
and len(settings.config["reddit"]["thread"]["post_id"].split("+")) == 1
):
submission = reddit.submission(
id=settings.config["reddit"]["thread"]["post_id"]
)
else: else:
threads = subreddit.hot(limit=25) threads = subreddit.hot(limit=25)
@ -90,7 +103,9 @@ def get_subreddit_threads(POST_ID: str):
if top_level_comment.body in ["[removed]", "[deleted]"]: if top_level_comment.body in ["[removed]", "[deleted]"]:
continue # # see https://github.com/JasonLovesDoggo/RedditVideoMakerBot/issues/78 continue # # see https://github.com/JasonLovesDoggo/RedditVideoMakerBot/issues/78
if not top_level_comment.stickied: if not top_level_comment.stickied:
if len(top_level_comment.body) <= int(settings.config["reddit"]["thread"]["max_comment_length"]): if len(top_level_comment.body) <= int(
settings.config["reddit"]["thread"]["max_comment_length"]
):
if ( if (
top_level_comment.author is not None top_level_comment.author is not None
): # if errors occur with this change to if not. ): # if errors occur with this change to if not.

@ -10,7 +10,9 @@ def cleanup() -> int:
""" """
if exists("./assets/temp"): if exists("./assets/temp"):
count = 0 count = 0
files = [f for f in os.listdir(".") if f.endswith(".mp4") and "temp" in f.lower()] files = [
f for f in os.listdir(".") if f.endswith(".mp4") and "temp" in f.lower()
]
count += len(files) count += len(files)
for f in files: for f in files:
os.remove(f) os.remove(f)

@ -49,7 +49,10 @@ def handle_input(
optional=False, optional=False,
): ):
if optional: if optional:
console.print(message + "\n[green]This is an optional value. Do you want to skip it? (y/n)") console.print(
message
+ "\n[green]This is an optional value. Do you want to skip it? (y/n)"
)
if input().casefold().startswith("y"): if input().casefold().startswith("y"):
return None return None
if default is not NotImplemented: if default is not NotImplemented:
@ -84,7 +87,9 @@ def handle_input(
continue continue
elif match != "" and re.match(match, user_input) is None: elif match != "" and re.match(match, user_input) is None:
console.print( console.print(
"[red]" + err_message + "\nAre you absolutely sure it's correct?(y/n)" "[red]"
+ err_message
+ "\nAre you absolutely sure it's correct?(y/n)"
) )
if input().casefold().startswith("y"): if input().casefold().startswith("y"):
break break

@ -54,7 +54,11 @@ def check(value, checks, name):
and not hasattr(value, "__iter__") and not hasattr(value, "__iter__")
and ( and (
("nmin" in checks and checks["nmin"] is not None and value < checks["nmin"]) ("nmin" in checks and checks["nmin"] is not None and value < checks["nmin"])
or ("nmax" in checks and checks["nmax"] is not None and value > checks["nmax"]) or (
"nmax" in checks
and checks["nmax"] is not None
and value > checks["nmax"]
)
) )
): ):
incorrect = True incorrect = True
@ -62,8 +66,16 @@ def check(value, checks, name):
not incorrect not incorrect
and hasattr(value, "__iter__") and hasattr(value, "__iter__")
and ( and (
("nmin" in checks and checks["nmin"] is not None and len(value) < checks["nmin"]) (
or ("nmax" in checks and checks["nmax"] is not None and len(value) > checks["nmax"]) "nmin" in checks
and checks["nmin"] is not None
and len(value) < checks["nmin"]
)
or (
"nmax" in checks
and checks["nmax"] is not None
and len(value) > checks["nmax"]
)
) )
): ):
incorrect = True incorrect = True
@ -71,7 +83,11 @@ def check(value, checks, name):
if incorrect: if incorrect:
value = handle_input( value = handle_input(
message=( message=(
(("[blue]Example: " + str(checks["example"]) + "\n") if "example" in checks else "") (
("[blue]Example: " + str(checks["example"]) + "\n")
if "example" in checks
else ""
)
+ "[red]" + "[red]"
+ ("Non-optional ", "Optional ")[ + ("Non-optional ", "Optional ")[
"optional" in checks and checks["optional"] is True "optional" in checks and checks["optional"] is True
@ -84,7 +100,9 @@ def check(value, checks, name):
check_type=eval(checks["type"]) if "type" in checks else False, check_type=eval(checks["type"]) if "type" in checks else False,
default=checks["default"] if "default" in checks else NotImplemented, default=checks["default"] if "default" in checks else NotImplemented,
match=checks["regex"] if "regex" in checks else "", match=checks["regex"] if "regex" in checks else "",
err_message=checks["input_error"] if "input_error" in checks else "Incorrect input", err_message=checks["input_error"]
if "input_error" in checks
else "Incorrect input",
nmin=checks["nmin"] if "nmin" in checks else None, nmin=checks["nmin"] if "nmin" in checks else None,
nmax=checks["nmax"] if "nmax" in checks else None, nmax=checks["nmax"] if "nmax" in checks else None,
oob_error=checks["oob_error"] oob_error=checks["oob_error"]

@ -15,7 +15,9 @@ def get_subreddit_undone(submissions: list, subreddit):
""" """
# recursively checks if the top submission in the list was already done. # recursively checks if the top submission in the list was already done.
with open("./video_creation/data/videos.json", "r", encoding="utf-8") as done_vids_raw: with open(
"./video_creation/data/videos.json", "r", encoding="utf-8"
) as done_vids_raw:
done_videos = json.load(done_vids_raw) done_videos = json.load(done_vids_raw)
for submission in submissions: for submission in submissions:
if already_done(done_videos, submission): if already_done(done_videos, submission):

@ -20,7 +20,9 @@ def check_done(
Returns: Returns:
Dict[str]|None: Reddit object in args Dict[str]|None: Reddit object in args
""" """
with open("./video_creation/data/videos.json", "r", encoding="utf-8") as done_vids_raw: with open(
"./video_creation/data/videos.json", "r", encoding="utf-8"
) as done_vids_raw:
done_videos = json.load(done_vids_raw) done_videos = json.load(done_vids_raw)
for video in done_videos: for video in done_videos:
if video["id"] == str(redditobj): if video["id"] == str(redditobj):

@ -52,7 +52,9 @@ def download_background():
"assets/backgrounds", filename=f"{credit}-{filename}" "assets/backgrounds", filename=f"{credit}-{filename}"
) )
print_substep("Background videos downloaded successfully! 🎉", style="bold green") print_substep(
"Background videos downloaded successfully! 🎉", style="bold green"
)
def chop_background_video(video_length: int) -> str: def chop_background_video(video_length: int) -> str:

@ -21,12 +21,15 @@ from utils.cleanup import cleanup
from utils.console import print_step, print_substep from utils.console import print_step, print_substep
from utils.videos import save_data from utils.videos import save_data
from utils import settings from utils import settings
console = Console() console = Console()
W, H = 1080, 1920 W, H = 1080, 1920
def make_final_video(number_of_clips: int, length: int, reddit_obj: dict, background_credit: str): def make_final_video(
number_of_clips: int, length: int, reddit_obj: dict, background_credit: str
):
"""Gathers audio clips, gathers all screenshots, stitches them together and saves the final video to assets/temp """Gathers audio clips, gathers all screenshots, stitches them together and saves the final video to assets/temp
Args: Args:
@ -46,7 +49,9 @@ def make_final_video(number_of_clips: int, length: int, reddit_obj: dict, backgr
) )
# Gather all audio clips # Gather all audio clips
audio_clips = [AudioFileClip(f"assets/temp/mp3/{i}.mp3") for i in range(number_of_clips)] audio_clips = [
AudioFileClip(f"assets/temp/mp3/{i}.mp3") for i in range(number_of_clips)
]
audio_clips.insert(0, AudioFileClip("assets/temp/mp3/title.mp3")) audio_clips.insert(0, AudioFileClip("assets/temp/mp3/title.mp3"))
audio_concat = concatenate_audioclips(audio_clips) audio_concat = concatenate_audioclips(audio_clips)
audio_composite = CompositeAudioClip([audio_concat]) audio_composite = CompositeAudioClip([audio_concat])
@ -63,7 +68,7 @@ def make_final_video(number_of_clips: int, length: int, reddit_obj: dict, backgr
.set_duration(audio_clips[0].duration) .set_duration(audio_clips[0].duration)
.set_position("center") .set_position("center")
.resize(width=W - 100) .resize(width=W - 100)
.set_opacity(new_opacity) .set_opacity(new_opacity),
) )
for i in range(0, number_of_clips): for i in range(0, number_of_clips):
@ -85,7 +90,9 @@ def make_final_video(number_of_clips: int, length: int, reddit_obj: dict, backgr
# .set_opacity(float(opacity)), # .set_opacity(float(opacity)),
# ) # )
# else: # else:
image_concat = concatenate_videoclips(image_clips).set_position(("center", "center")) image_concat = concatenate_videoclips(image_clips).set_position(
("center", "center")
)
image_concat.audio = audio_composite image_concat.audio = audio_composite
final = CompositeVideoClip([background_clip, image_concat]) final = CompositeVideoClip([background_clip, image_concat])
title = re.sub(r"[^\w\s-]", "", reddit_obj["thread_title"]) title = re.sub(r"[^\w\s-]", "", reddit_obj["thread_title"])
@ -108,7 +115,10 @@ def make_final_video(number_of_clips: int, length: int, reddit_obj: dict, backgr
threads=multiprocessing.cpu_count(), threads=multiprocessing.cpu_count(),
) )
ffmpeg_tools.ffmpeg_extract_subclip( ffmpeg_tools.ffmpeg_extract_subclip(
"assets/temp/temp.mp4", 0, final.duration, targetname=f"results/{subreddit}/{filename}" "assets/temp/temp.mp4",
0,
final.duration,
targetname=f"results/{subreddit}/{filename}",
) )
# os.remove("assets/temp/temp.mp4") # os.remove("assets/temp/temp.mp4")

@ -35,9 +35,13 @@ def download_screenshots_of_reddit_posts(reddit_object: dict, screenshot_num: in
context = browser.new_context() context = browser.new_context()
if settings.config["settings"]["theme"] == "dark": if settings.config["settings"]["theme"] == "dark":
cookie_file = open("./video_creation/data/cookie-dark-mode.json", encoding="utf-8") cookie_file = open(
"./video_creation/data/cookie-dark-mode.json", encoding="utf-8"
)
else: else:
cookie_file = open("./video_creation/data/cookie-light-mode.json", encoding="utf-8") cookie_file = open(
"./video_creation/data/cookie-light-mode.json", encoding="utf-8"
)
cookies = json.load(cookie_file) cookies = json.load(cookie_file)
context.add_cookies(cookies) # load preference cookies context.add_cookies(cookies) # load preference cookies
# Get the thread screenshot # Get the thread screenshot
@ -57,7 +61,10 @@ def download_screenshots_of_reddit_posts(reddit_object: dict, screenshot_num: in
if settings.config["reddit"]["thread"]["post_lang"]: if settings.config["reddit"]["thread"]["post_lang"]:
print_substep("Translating post...") print_substep("Translating post...")
texts_in_tl = ts.google(reddit_object["thread_title"], to_language=settings.config["reddit"]["thread"]["post_lang"]) texts_in_tl = ts.google(
reddit_object["thread_title"],
to_language=settings.config["reddit"]["thread"]["post_lang"],
)
page.evaluate( page.evaluate(
"tl_content => document.querySelector('[data-test-id=\"post-content\"] > div:nth-child(3) > div > div').textContent = tl_content", "tl_content => document.querySelector('[data-test-id=\"post-content\"] > div:nth-child(3) > div > div').textContent = tl_content",
@ -66,7 +73,9 @@ def download_screenshots_of_reddit_posts(reddit_object: dict, screenshot_num: in
else: else:
print_substep("Skipping translation...") print_substep("Skipping translation...")
page.locator('[data-test-id="post-content"]').screenshot(path="assets/temp/png/title.png") page.locator('[data-test-id="post-content"]').screenshot(
path="assets/temp/png/title.png"
)
if storymode: if storymode:
page.locator('[data-click-id="text"]').screenshot( page.locator('[data-click-id="text"]').screenshot(
@ -89,7 +98,8 @@ def download_screenshots_of_reddit_posts(reddit_object: dict, screenshot_num: in
if settings.config["reddit"]["thread"]["post_lang"]: if settings.config["reddit"]["thread"]["post_lang"]:
comment_tl = ts.google( comment_tl = ts.google(
comment["comment_body"], to_language=settings.config["reddit"]["thread"]["post_lang"] comment["comment_body"],
to_language=settings.config["reddit"]["thread"]["post_lang"],
) )
page.evaluate( page.evaluate(
'([tl_content, tl_id]) => document.querySelector(`#t1_${tl_id} > div:nth-child(2) > div > div[data-testid="comment"] > div`).textContent = tl_content', '([tl_content, tl_id]) => document.querySelector(`#t1_${tl_id} > div:nth-child(2) > div > div[data-testid="comment"] > div`).textContent = tl_content',

@ -35,7 +35,9 @@ def save_text_to_mp3(reddit_obj) -> Tuple[int, int]:
voice = settings.config["settings"]["tts"]["choice"] voice = settings.config["settings"]["tts"]["choice"]
if voice.casefold() in map(lambda _: _.casefold(), TTSProviders): if voice.casefold() in map(lambda _: _.casefold(), TTSProviders):
text_to_mp3 = TTSEngine(get_case_insensitive_key_value(TTSProviders, voice), reddit_obj) text_to_mp3 = TTSEngine(
get_case_insensitive_key_value(TTSProviders, voice), reddit_obj
)
else: else:
while True: while True:
print_step("Please choose one of the following TTS providers: ") print_step("Please choose one of the following TTS providers: ")
@ -44,13 +46,19 @@ def save_text_to_mp3(reddit_obj) -> Tuple[int, int]:
if choice.casefold() in map(lambda _: _.casefold(), TTSProviders): if choice.casefold() in map(lambda _: _.casefold(), TTSProviders):
break break
print("Unknown Choice") print("Unknown Choice")
text_to_mp3 = TTSEngine(get_case_insensitive_key_value(TTSProviders, choice), reddit_obj) text_to_mp3 = TTSEngine(
get_case_insensitive_key_value(TTSProviders, choice), reddit_obj
)
return text_to_mp3.run() return text_to_mp3.run()
def get_case_insensitive_key_value(input_dict, key): def get_case_insensitive_key_value(input_dict, key):
return next( return next(
(value for dict_key, value in input_dict.items() if dict_key.lower() == key.lower()), (
value
for dict_key, value in input_dict.items()
if dict_key.lower() == key.lower()
),
None, None,
) )

Loading…
Cancel
Save