self review: improved typing & logging, removed unused imports, fixes in README

pull/963/head
Drugsosos 3 years ago
parent 01e9b2e4d0
commit acf679bfb6
No known key found for this signature in database
GPG Key ID: 8E35176FE617E28D

@ -33,7 +33,6 @@ The only original thing being done is the editing and gathering of all materials
## Requirements ## Requirements
- Python 3.9+ - Python 3.9+
- Playwright (this should install automatically in installation)
## Installation 👩‍💻 ## Installation 👩‍💻

@ -5,12 +5,20 @@ from gtts import gTTS
class GTTS: class GTTS:
max_chars = 0 max_chars = 0
# voices = []
@staticmethod @staticmethod
async def run( async def run(
text, text,
filepath filepath
) -> None: ) -> None:
"""
Calls for TTS api
Args:
text: text to be voiced over
filepath: name of the audio file
"""
tts = gTTS( tts = gTTS(
text=text, text=text,
lang=settings.config["reddit"]["thread"]["post_lang"] or "en", lang=settings.config["reddit"]["thread"]["post_lang"] or "en",

@ -1,10 +1,8 @@
import base64
from utils import settings from utils import settings
import requests import requests
from requests.adapters import HTTPAdapter, Retry from requests.adapters import HTTPAdapter, Retry
from attr import attrs, attrib from attr import attrs, attrib
from attr.validators import instance_of
from TTS.common import BaseApiTTS, get_random_voice from TTS.common import BaseApiTTS, get_random_voice
@ -74,8 +72,20 @@ class TikTok(BaseApiTTS): # TikTok Text-to-Speech Wrapper
max_chars = 300 max_chars = 300
decode_base64 = True decode_base64 = True
def __attrs_post_init__(self): def make_request(
self.voice = ( self,
text: str,
):
"""
Makes a requests to remote TTS service
Args:
text: text to be voice over
Returns:
Request's response
"""
voice = (
get_random_voice(voices, 'human') get_random_voice(voices, 'human')
if self.random_voice if self.random_voice
else str(settings.config['settings']['tts']['tiktok_voice']).lower() else str(settings.config['settings']['tts']['tiktok_voice']).lower()
@ -83,16 +93,11 @@ class TikTok(BaseApiTTS): # TikTok Text-to-Speech Wrapper
voice.lower() for dict_title in voices for voice in voices[dict_title]] voice.lower() for dict_title in voices for voice in voices[dict_title]]
else get_random_voice(voices, 'human') else get_random_voice(voices, 'human')
) )
def make_request(
self,
text: str,
):
try: try:
r = requests.post( r = requests.post(
self.uri_base, self.uri_base,
params={ params={
'text_speaker': self.voice, 'text_speaker': voice,
'req_text': text, 'req_text': text,
'speaker_map_type': 0, 'speaker_map_type': 0,
}) })
@ -103,6 +108,6 @@ class TikTok(BaseApiTTS): # TikTok Text-to-Speech Wrapper
adapter = HTTPAdapter(max_retries=retry) adapter = HTTPAdapter(max_retries=retry)
session.mount('http://', adapter) session.mount('http://', adapter)
session.mount('https://', adapter) session.mount('https://', adapter)
r = session.post(f'{self.uri_base}{self.voice}&req_text={text}&speaker_map_type=0') r = session.post(f'{self.uri_base}{voice}&req_text={text}&speaker_map_type=0')
# print(r.text) # print(r.text)
return r.json()['data']['v_str'] return r.json()['data']['v_str']

@ -37,7 +37,14 @@ class AWSPolly:
self, self,
text, text,
filepath, filepath,
): ) -> None:
"""
Calls for TTS api
Args:
text: text to be voiced over
filepath: name of the audio file
"""
try: try:
session = Session(profile_name='polly') session = Session(profile_name='polly')
polly = session.client('polly') polly = session.client('polly')

@ -12,6 +12,16 @@ class BaseApiTTS:
text: str, text: str,
max_length: int, max_length: int,
) -> list: ) -> list:
"""
Splits text if it's too long to be a query
Args:
text: text to be sanitized
max_length: maximum length of the query
Returns:
Split text as a list
"""
# Split by comma or dot (else you can lose intonations), if there is non, split by groups of 299 chars # Split by comma or dot (else you can lose intonations), if there is non, split by groups of 299 chars
if '.' in text and all([split_text.__len__() < max_length for split_text in text.split('.')]): if '.' in text and all([split_text.__len__() < max_length for split_text in text.split('.')]):
return text.split('.') return text.split('.')
@ -26,6 +36,13 @@ class BaseApiTTS:
output_text: str, output_text: str,
filepath: str, filepath: str,
) -> None: ) -> None:
"""
Writes and decodes TTS responses in files
Args:
output_text: text to be written
filepath: path/name of the file
"""
decoded_text = base64.b64decode(output_text) if self.decode_base64 else output_text decoded_text = base64.b64decode(output_text) if self.decode_base64 else output_text
with open(filepath, 'wb') as out: with open(filepath, 'wb') as out:
@ -36,6 +53,16 @@ class BaseApiTTS:
text: str, text: str,
filepath: str, filepath: str,
) -> None: ) -> None:
"""
Calls for TTS api and writes audio file
Args:
text: text to be voice over
filepath: path/name of the file
Returns:
"""
output_text = '' output_text = ''
if len(text) > self.max_chars: if len(text) > self.max_chars:
for part in self.text_len_sanitize(text, self.max_chars): for part in self.text_len_sanitize(text, self.max_chars):
@ -50,19 +77,45 @@ def get_random_voice(
voices: Union[list, dict], voices: Union[list, dict],
key: Optional[str] = None, key: Optional[str] = None,
) -> str: ) -> str:
"""
Return random voice from list or dict
Args:
voices: list or dict of voices
key: key of a dict if you are using one
Returns:
random voice as a str
"""
if isinstance(voices, list): if isinstance(voices, list):
return choice(voices) return choice(voices)
else: else:
return choice(voices[key]) return choice(voices[key] if key else list(voices.values())[0])
def audio_length( def audio_length(
path: str, path: str,
) -> float | int: ) -> float | int:
"""
Gets the length of the audio file
Args:
path: audio file path
Returns:
length in seconds as an int
"""
from mutagen.mp3 import MP3 from mutagen.mp3 import MP3
try: try:
audio = MP3(path) audio = MP3(path)
return audio.info.length return audio.info.length
except Exception as e: # TODO add logging except Exception as e:
import logging
logger = logging.getLogger('spam_application')
logger.setLevel(logging.DEBUG)
handler = logging.FileHandler('tts_log', mode='a+', encoding='utf-8')
logger.addHandler(handler)
logger.error('Error occurred in audio_length:', e)
return 0 return 0

@ -39,7 +39,9 @@ class TTSEngine:
) )
def __attrs_post_init__(self): def __attrs_post_init__(self):
# Calls an instance of the tts_module class
self.tts_module = self.tts_module() self.tts_module = self.tts_module()
# Loading settings from the config
self.max_length: int = settings.config['settings']['video_length'] self.max_length: int = settings.config['settings']['video_length']
self.time_before_tts: float = settings.config['settings']['time_before_tts'] self.time_before_tts: float = settings.config['settings']['time_before_tts']
self.time_between_pictures: float = settings.config['settings']['time_between_pictures'] self.time_between_pictures: float = settings.config['settings']['time_between_pictures']
@ -51,7 +53,12 @@ class TTSEngine:
def run( def run(
self self
) -> list: ) -> list:
"""
Voices over comments & title of the submission
Returns:
Indexes of comments to be used in the final video
"""
Path(self.path).mkdir(parents=True, exist_ok=True) Path(self.path).mkdir(parents=True, exist_ok=True)
# This file needs to be removed in case this post does not use post text # This file needs to be removed in case this post does not use post text
@ -87,6 +94,16 @@ class TTSEngine:
filename: str, filename: str,
text: str text: str
) -> bool: ) -> bool:
"""
Calls for TTS api from the factory
Args:
filename: name of audio file w/o .mp3
text: text to be voiced over
Returns:
True if audio files not exceeding the maximum length else false
"""
if not text: if not text:
return False return False
@ -107,6 +124,15 @@ class TTSEngine:
def process_text( def process_text(
text: str, text: str,
) -> str: ) -> str:
"""
Sanitizes text for illegal characters and translates text
Args:
text: text to be sanitized & translated
Returns:
Processed text as a str
"""
lang = settings.config['reddit']['thread']['post_lang'] lang = settings.config['reddit']['thread']['post_lang']
new_text = sanitize_text(text) new_text = sanitize_text(text)
if lang: if lang:

@ -42,6 +42,15 @@ class StreamlabsPolly(BaseApiTTS):
self, self,
text, text,
): ):
"""
Makes a requests to remote TTS service
Args:
text: text to be voice over
Returns:
Request's response
"""
voice = ( voice = (
get_random_voice(voices) get_random_voice(voices)
if self.random_voice if self.random_voice

@ -7,7 +7,6 @@ from utils.cleanup import cleanup
from utils.console import print_markdown, print_step from utils.console import print_markdown, print_step
from utils import settings from utils import settings
# from utils.checker import envUpdate
from video_creation.background import ( from video_creation.background import (
get_background_config, get_background_config,
) )

@ -9,7 +9,7 @@ from utils.console import handle_input
console = Console() console = Console()
config = dict() # calling instance of a dict to calm lint down config = dict() # calling instance of a dict to calm lint down (dict[any] will work as well)
def crawl(obj: dict, func=lambda x, y: print(x, y, end="\n"), path=None): def crawl(obj: dict, func=lambda x, y: print(x, y, end="\n"), path=None):

@ -11,7 +11,7 @@ if sys.version_info[0] >= 3:
def check_ratelimit( def check_ratelimit(
response: Response response: Response,
): ):
""" """
Checks if the response is a ratelimit response. Checks if the response is a ratelimit response.

@ -13,7 +13,7 @@ from moviepy.editor import (
CompositeAudioClip, CompositeAudioClip,
CompositeVideoClip, CompositeVideoClip,
) )
from moviepy.video.io.ffmpeg_tools import ffmpeg_merge_video_audio, ffmpeg_extract_subclip from moviepy.video.io.ffmpeg_tools import ffmpeg_extract_subclip
from rich.console import Console from rich.console import Console
from rich.progress import track from rich.progress import track
@ -35,7 +35,7 @@ def name_normalize(
name = re.sub(r'(\d+)\s?\/\s?(\d+)', r'\1 of \2', name) name = re.sub(r'(\d+)\s?\/\s?(\d+)', r'\1 of \2', name)
name = re.sub(r'(\w+)\s?\/\s?(\w+)', r'\1 or \2', name) name = re.sub(r'(\w+)\s?\/\s?(\w+)', r'\1 or \2', name)
name = re.sub(r'\/', '', name) name = re.sub(r'\/', '', name)
name[:30] name[:30] # the hell this little guy does?
lang = settings.config['reddit']['thread']['post_lang'] lang = settings.config['reddit']['thread']['post_lang']
if lang: if lang:

@ -25,7 +25,7 @@ _exceptions = TypeVar('_exceptions', bound=Optional[Union[type, tuple, list]])
@attrs @attrs
class ExceptionDecorator: class ExceptionDecorator:
""" """
Factory for decorating functions Decorator factory for catching exceptions and writing logs
""" """
exception: Optional[_exceptions] = attrib(default=None) exception: Optional[_exceptions] = attrib(default=None)
__default_exception: _exceptions = attrib(default=BrowserTimeoutError) __default_exception: _exceptions = attrib(default=BrowserTimeoutError)
@ -45,15 +45,17 @@ class ExceptionDecorator:
except Exception as caughtException: except Exception as caughtException:
import logging import logging
logging.basicConfig(filename='.webdriver.log', filemode='a+', logger = logging.getLogger('webdriver_log')
encoding='utf-8', level=logging.ERROR) logger.setLevel(logging.DEBUG)
handler = logging.FileHandler('.webdriver.log', mode='a+', encoding='utf-8')
logger.addHandler(handler)
if isinstance(self.exception, type): if isinstance(self.exception, type):
if not type(caughtException) == self.exception: if not type(caughtException) == self.exception:
logging.error(f'unexpected error - {caughtException}') logger.error(f'unexpected error - {caughtException}')
else: else:
if not type(caughtException) in self.exception: if not type(caughtException) in self.exception:
logging.error(f'unexpected error - {caughtException}') logger.error(f'unexpected error - {caughtException}')
return wrapper return wrapper

@ -29,7 +29,7 @@ def save_text_to_mp3(
""" """
voice = settings.config['settings']['tts']['choice'] voice = settings.config['settings']['tts']['choice']
if str(voice).casefold() not in map(lambda _: _.casefold(), TTSProviders): if voice.casefold() not in map(lambda _: _.casefold(), TTSProviders):
while True: while True:
print_step('Please choose one of the following TTS providers: ') print_step('Please choose one of the following TTS providers: ')
print_table(TTSProviders) print_table(TTSProviders)
@ -45,6 +45,7 @@ def get_case_insensitive_key_value(
input_dict, input_dict,
key, key,
) -> object: ) -> object:
# TODO add a factory later
return next( return next(
(value for dict_key, value in input_dict.items() if dict_key.lower() == key.lower()), (value for dict_key, value in input_dict.items() if dict_key.lower() == key.lower()),
None, None,

Loading…
Cancel
Save