self review: improved typing & logging, removed unused imports, fixes in README

pull/963/head
Drugsosos 3 years ago
parent 01e9b2e4d0
commit acf679bfb6
No known key found for this signature in database
GPG Key ID: 8E35176FE617E28D

@ -33,7 +33,6 @@ The only original thing being done is the editing and gathering of all materials
## Requirements
- Python 3.9+
- Playwright (this should install automatically in installation)
## Installation 👩‍💻

@ -5,12 +5,20 @@ from gtts import gTTS
class GTTS:
max_chars = 0
# voices = []
@staticmethod
async def run(
text,
filepath
) -> None:
"""
Calls for TTS api
Args:
text: text to be voiced over
filepath: name of the audio file
"""
tts = gTTS(
text=text,
lang=settings.config["reddit"]["thread"]["post_lang"] or "en",

@ -1,10 +1,8 @@
import base64
from utils import settings
import requests
from requests.adapters import HTTPAdapter, Retry
from attr import attrs, attrib
from attr.validators import instance_of
from TTS.common import BaseApiTTS, get_random_voice
@ -74,8 +72,20 @@ class TikTok(BaseApiTTS): # TikTok Text-to-Speech Wrapper
max_chars = 300
decode_base64 = True
def __attrs_post_init__(self):
self.voice = (
def make_request(
self,
text: str,
):
"""
Makes a requests to remote TTS service
Args:
text: text to be voice over
Returns:
Request's response
"""
voice = (
get_random_voice(voices, 'human')
if self.random_voice
else str(settings.config['settings']['tts']['tiktok_voice']).lower()
@ -83,16 +93,11 @@ class TikTok(BaseApiTTS): # TikTok Text-to-Speech Wrapper
voice.lower() for dict_title in voices for voice in voices[dict_title]]
else get_random_voice(voices, 'human')
)
def make_request(
self,
text: str,
):
try:
r = requests.post(
self.uri_base,
params={
'text_speaker': self.voice,
'text_speaker': voice,
'req_text': text,
'speaker_map_type': 0,
})
@ -103,6 +108,6 @@ class TikTok(BaseApiTTS): # TikTok Text-to-Speech Wrapper
adapter = HTTPAdapter(max_retries=retry)
session.mount('http://', adapter)
session.mount('https://', adapter)
r = session.post(f'{self.uri_base}{self.voice}&req_text={text}&speaker_map_type=0')
r = session.post(f'{self.uri_base}{voice}&req_text={text}&speaker_map_type=0')
# print(r.text)
return r.json()['data']['v_str']

@ -37,7 +37,14 @@ class AWSPolly:
self,
text,
filepath,
):
) -> None:
"""
Calls for TTS api
Args:
text: text to be voiced over
filepath: name of the audio file
"""
try:
session = Session(profile_name='polly')
polly = session.client('polly')

@ -12,6 +12,16 @@ class BaseApiTTS:
text: str,
max_length: int,
) -> list:
"""
Splits text if it's too long to be a query
Args:
text: text to be sanitized
max_length: maximum length of the query
Returns:
Split text as a list
"""
# Split by comma or dot (else you can lose intonations), if there is non, split by groups of 299 chars
if '.' in text and all([split_text.__len__() < max_length for split_text in text.split('.')]):
return text.split('.')
@ -26,6 +36,13 @@ class BaseApiTTS:
output_text: str,
filepath: str,
) -> None:
"""
Writes and decodes TTS responses in files
Args:
output_text: text to be written
filepath: path/name of the file
"""
decoded_text = base64.b64decode(output_text) if self.decode_base64 else output_text
with open(filepath, 'wb') as out:
@ -36,6 +53,16 @@ class BaseApiTTS:
text: str,
filepath: str,
) -> None:
"""
Calls for TTS api and writes audio file
Args:
text: text to be voice over
filepath: path/name of the file
Returns:
"""
output_text = ''
if len(text) > self.max_chars:
for part in self.text_len_sanitize(text, self.max_chars):
@ -50,19 +77,45 @@ def get_random_voice(
voices: Union[list, dict],
key: Optional[str] = None,
) -> str:
"""
Return random voice from list or dict
Args:
voices: list or dict of voices
key: key of a dict if you are using one
Returns:
random voice as a str
"""
if isinstance(voices, list):
return choice(voices)
else:
return choice(voices[key])
return choice(voices[key] if key else list(voices.values())[0])
def audio_length(
path: str,
) -> float | int:
"""
Gets the length of the audio file
Args:
path: audio file path
Returns:
length in seconds as an int
"""
from mutagen.mp3 import MP3
try:
audio = MP3(path)
return audio.info.length
except Exception as e: # TODO add logging
except Exception as e:
import logging
logger = logging.getLogger('spam_application')
logger.setLevel(logging.DEBUG)
handler = logging.FileHandler('tts_log', mode='a+', encoding='utf-8')
logger.addHandler(handler)
logger.error('Error occurred in audio_length:', e)
return 0

@ -39,7 +39,9 @@ class TTSEngine:
)
def __attrs_post_init__(self):
# Calls an instance of the tts_module class
self.tts_module = self.tts_module()
# Loading settings from the config
self.max_length: int = settings.config['settings']['video_length']
self.time_before_tts: float = settings.config['settings']['time_before_tts']
self.time_between_pictures: float = settings.config['settings']['time_between_pictures']
@ -51,7 +53,12 @@ class TTSEngine:
def run(
self
) -> list:
"""
Voices over comments & title of the submission
Returns:
Indexes of comments to be used in the final video
"""
Path(self.path).mkdir(parents=True, exist_ok=True)
# This file needs to be removed in case this post does not use post text
@ -87,6 +94,16 @@ class TTSEngine:
filename: str,
text: str
) -> bool:
"""
Calls for TTS api from the factory
Args:
filename: name of audio file w/o .mp3
text: text to be voiced over
Returns:
True if audio files not exceeding the maximum length else false
"""
if not text:
return False
@ -107,6 +124,15 @@ class TTSEngine:
def process_text(
text: str,
) -> str:
"""
Sanitizes text for illegal characters and translates text
Args:
text: text to be sanitized & translated
Returns:
Processed text as a str
"""
lang = settings.config['reddit']['thread']['post_lang']
new_text = sanitize_text(text)
if lang:

@ -42,6 +42,15 @@ class StreamlabsPolly(BaseApiTTS):
self,
text,
):
"""
Makes a requests to remote TTS service
Args:
text: text to be voice over
Returns:
Request's response
"""
voice = (
get_random_voice(voices)
if self.random_voice

@ -7,7 +7,6 @@ from utils.cleanup import cleanup
from utils.console import print_markdown, print_step
from utils import settings
# from utils.checker import envUpdate
from video_creation.background import (
get_background_config,
)

@ -9,7 +9,7 @@ from utils.console import handle_input
console = Console()
config = dict() # calling instance of a dict to calm lint down
config = dict() # calling instance of a dict to calm lint down (dict[any] will work as well)
def crawl(obj: dict, func=lambda x, y: print(x, y, end="\n"), path=None):

@ -11,7 +11,7 @@ if sys.version_info[0] >= 3:
def check_ratelimit(
response: Response
response: Response,
):
"""
Checks if the response is a ratelimit response.

@ -13,7 +13,7 @@ from moviepy.editor import (
CompositeAudioClip,
CompositeVideoClip,
)
from moviepy.video.io.ffmpeg_tools import ffmpeg_merge_video_audio, ffmpeg_extract_subclip
from moviepy.video.io.ffmpeg_tools import ffmpeg_extract_subclip
from rich.console import Console
from rich.progress import track
@ -35,7 +35,7 @@ def name_normalize(
name = re.sub(r'(\d+)\s?\/\s?(\d+)', r'\1 of \2', name)
name = re.sub(r'(\w+)\s?\/\s?(\w+)', r'\1 or \2', name)
name = re.sub(r'\/', '', name)
name[:30]
name[:30] # the hell this little guy does?
lang = settings.config['reddit']['thread']['post_lang']
if lang:

@ -25,7 +25,7 @@ _exceptions = TypeVar('_exceptions', bound=Optional[Union[type, tuple, list]])
@attrs
class ExceptionDecorator:
"""
Factory for decorating functions
Decorator factory for catching exceptions and writing logs
"""
exception: Optional[_exceptions] = attrib(default=None)
__default_exception: _exceptions = attrib(default=BrowserTimeoutError)
@ -45,15 +45,17 @@ class ExceptionDecorator:
except Exception as caughtException:
import logging
logging.basicConfig(filename='.webdriver.log', filemode='a+',
encoding='utf-8', level=logging.ERROR)
logger = logging.getLogger('webdriver_log')
logger.setLevel(logging.DEBUG)
handler = logging.FileHandler('.webdriver.log', mode='a+', encoding='utf-8')
logger.addHandler(handler)
if isinstance(self.exception, type):
if not type(caughtException) == self.exception:
logging.error(f'unexpected error - {caughtException}')
logger.error(f'unexpected error - {caughtException}')
else:
if not type(caughtException) in self.exception:
logging.error(f'unexpected error - {caughtException}')
logger.error(f'unexpected error - {caughtException}')
return wrapper

@ -29,7 +29,7 @@ def save_text_to_mp3(
"""
voice = settings.config['settings']['tts']['choice']
if str(voice).casefold() not in map(lambda _: _.casefold(), TTSProviders):
if voice.casefold() not in map(lambda _: _.casefold(), TTSProviders):
while True:
print_step('Please choose one of the following TTS providers: ')
print_table(TTSProviders)
@ -45,6 +45,7 @@ def get_case_insensitive_key_value(
input_dict,
key,
) -> object:
# TODO add a factory later
return next(
(value for dict_key, value in input_dict.items() if dict_key.lower() == key.lower()),
None,

Loading…
Cancel
Save