Merge branch 'develop' into develop

pull/1278/head
commit 318cc9e539 by TianYuan, committed via GitHub

@@ -31,6 +31,7 @@ from ..log import logger
 from ..utils import cli_register
 from ..utils import download_and_decompress
 from ..utils import MODEL_HOME
+from ..utils import stats_wrapper
 from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
 from paddlespeech.s2t.transform.transformation import Transformation
 from paddlespeech.s2t.utils.dynamic_import import dynamic_import
@@ -425,6 +426,7 @@ class ASRExecutor(BaseExecutor):
             logger.exception(e)
             return False

+    @stats_wrapper
     def __call__(self,
                  audio_file: os.PathLike,
                  model: str='conformer_wenetspeech',
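
The same import and one-line @stats_wrapper decoration is repeated for each CLI executor below (cls, st, text, tts); the decorator itself is defined in the utils.py hunk further down. Decorating __call__ is sugar for rebinding the method by hand, roughly (sketch only, not part of the diff):

    # Equivalent to placing @stats_wrapper above the def: the decorator
    # receives the undecorated __call__ and returns a wrapper that reports
    # usage stats, then delegates to the original method.
    ASRExecutor.__call__ = stats_wrapper(ASRExecutor.__call__)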

@@ -26,6 +26,7 @@ from ..log import logger
 from ..utils import cli_register
 from ..utils import download_and_decompress
 from ..utils import MODEL_HOME
+from ..utils import stats_wrapper
 from paddleaudio import load
 from paddleaudio.features import LogMelSpectrogram
 from paddlespeech.s2t.utils.dynamic_import import dynamic_import
@@ -245,6 +246,7 @@ class CLSExecutor(BaseExecutor):
             logger.exception(e)
             return False

+    @stats_wrapper
     def __call__(self,
                  audio_file: os.PathLike,
                  model: str='panns_cnn14',

@@ -30,6 +30,7 @@ from ..log import logger
 from ..utils import cli_register
 from ..utils import download_and_decompress
 from ..utils import MODEL_HOME
+from ..utils import stats_wrapper
 from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
 from paddlespeech.s2t.utils.dynamic_import import dynamic_import
 from paddlespeech.s2t.utils.utility import UpdateConfig
@@ -334,6 +335,7 @@ class STExecutor(BaseExecutor):
             logger.exception(e)
             return False

+    @stats_wrapper
     def __call__(self,
                  audio_file: os.PathLike,
                  model: str='fat_st_ted',

@@ -26,6 +26,7 @@ from ..log import logger
 from ..utils import cli_register
 from ..utils import download_and_decompress
 from ..utils import MODEL_HOME
+from ..utils import stats_wrapper

 __all__ = ['TextExecutor']
@@ -272,6 +273,7 @@ class TextExecutor(BaseExecutor):
             logger.exception(e)
             return False

+    @stats_wrapper
     def __call__(
             self,
             text: str,

@@ -29,6 +29,7 @@ from ..log import logger
 from ..utils import cli_register
 from ..utils import download_and_decompress
 from ..utils import MODEL_HOME
+from ..utils import stats_wrapper
 from paddlespeech.s2t.utils.dynamic_import import dynamic_import
 from paddlespeech.t2s.frontend import English
 from paddlespeech.t2s.frontend.zh_frontend import Frontend
@@ -645,6 +646,7 @@ class TTSExecutor(BaseExecutor):
             logger.exception(e)
             return False

+    @stats_wrapper
     def __call__(self,
                  text: str,
                  am: str='fastspeech2_csmsc',

@@ -11,22 +11,36 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import hashlib
+import inspect
+import json
 import os
 import tarfile
+import threading
+import time
+import uuid
 import zipfile
 from typing import Any
 from typing import Dict

+import paddle
+import paddleaudio
+import requests
+import yaml
 from paddle.framework import load

 from . import download
+from .. import __version__
 from .entry import commands

+requests.adapters.DEFAULT_RETRIES = 3
+
 __all__ = [
     'cli_register',
     'get_command',
     'download_and_decompress',
     'load_state_dict_from_url',
+    'stats_wrapper',
 ]
@@ -101,6 +115,13 @@ def download_and_decompress(archive: Dict[str, str], path: str) -> os.PathLike:
         if not os.path.isdir(uncompress_path):
             download._decompress(filepath)
     else:
+        StatsWorker(
+            task='download',
+            version=__version__,
+            extra_info={
+                'download_url': archive['url'],
+                'paddle_version': paddle.__version__
+            }).start()
         uncompress_path = download.get_path_from_url(archive['url'], path,
                                                      archive['md5'])
@@ -146,3 +167,171 @@ def _get_sub_home(directory):

 PPSPEECH_HOME = _get_paddlespcceh_home()
 MODEL_HOME = _get_sub_home('models')
+CONF_HOME = _get_sub_home('conf')
+
+
+def _md5(text: str):
+    '''Calculate the md5 value of the input text.'''
+    md5code = hashlib.md5(text.encode())
+    return md5code.hexdigest()
+
+
+class ConfigCache:
+    def __init__(self):
+        self._data = {}
+        self._initialize()
+        self.file = os.path.join(CONF_HOME, 'cache.yaml')
+        if not os.path.exists(self.file):
+            self.flush()
+            return
+
+        with open(self.file, 'r') as file:
+            try:
+                cfg = yaml.load(file, Loader=yaml.FullLoader)
+                self._data.update(cfg)
+            except:
+                self.flush()
+
+    @property
+    def cache_info(self):
+        return self._data['cache_info']
+
+    def _initialize(self):
+        # Set default configuration values.
+        cache_info = _md5(str(uuid.uuid1())[-12:]) + "-" + str(int(time.time()))
+        self._data['cache_info'] = cache_info
+
+    def flush(self):
+        '''Flush the current configuration into the configuration file.'''
+        with open(self.file, 'w') as file:
+            cfg = json.loads(json.dumps(self._data))
+            yaml.dump(cfg, file)
+
+
+stats_api = "http://paddlepaddle.org.cn/paddlehub/stat"
+cache_info = ConfigCache().cache_info
+
+
+class StatsWorker(threading.Thread):
+    def __init__(self,
+                 task="asr",
+                 model=None,
+                 version=__version__,
+                 extra_info={}):
+        threading.Thread.__init__(self)
+        self._task = task
+        self._model = model
+        self._version = version
+        self._extra_info = extra_info
+
+    def run(self):
+        params = {
+            'task': self._task,
+            'version': self._version,
+            'from': 'ppspeech'
+        }
+        if self._model:
+            params['model'] = self._model
+
+        self._extra_info.update({
+            'cache_info': cache_info,
+        })
+        params.update({"extra": json.dumps(self._extra_info)})
+
+        try:
+            requests.get(stats_api, params)
+        except Exception:
+            pass
+
+        return
+
+
+def _note_one_stat(cls_name, params={}):
+    task = cls_name.replace('Executor', '').lower()  # XXExecutor
+    extra_info = {
+        'paddle_version': paddle.__version__,
+    }
+
+    if 'model' in params:
+        model = params['model']
+    else:
+        model = None
+
+    if 'audio_file' in params:
+        try:
+            _, sr = paddleaudio.load(params['audio_file'])
+        except Exception:
+            sr = -1
+
+    if task == 'asr':
+        extra_info.update({
+            'lang': params['lang'],
+            'inp_sr': sr,
+            'model_sr': params['sample_rate'],
+        })
+    elif task == 'st':
+        extra_info.update({
+            'lang': params['src_lang'] + '-' + params['tgt_lang'],
+            'inp_sr': sr,
+            'model_sr': params['sample_rate'],
+        })
+    elif task == 'tts':
+        model = params['am']
+        extra_info.update({
+            'lang': params['lang'],
+            'vocoder': params['voc'],
+        })
+    elif task == 'cls':
+        extra_info.update({
+            'inp_sr': sr,
+        })
+    elif task == 'text':
+        extra_info.update({
+            'sub_task': params['task'],
+            'lang': params['lang'],
+        })
+    else:
+        return
+
+    StatsWorker(
+        task=task,
+        model=model,
+        version=__version__,
+        extra_info=extra_info, ).start()
+
+
+def _parse_args(func, *args, **kwargs):
+    # FullArgSpec(args, varargs, varkw, defaults, kwonlyargs, kwonlydefaults, annotations)
+    argspec = inspect.getfullargspec(func)
+    keys = argspec[0]
+    if keys[0] == 'self':  # Remove self pointer.
+        keys = keys[1:]
+    default_values = argspec[3]
+    values = [None] * (len(keys) - len(default_values))
+    values.extend(list(default_values))
+    params = dict(zip(keys, values))
+
+    for idx, v in enumerate(args):
+        params[keys[idx]] = v
+    for k, v in kwargs.items():
+        params[k] = v
+
+    return params
+
+
+def stats_wrapper(executor_func):
+    def _warpper(self, *args, **kwargs):
+        try:
+            _note_one_stat(
+                type(self).__name__,
+                _parse_args(executor_func, *args, **kwargs))
+        except Exception:
+            pass
+        return executor_func(self, *args, **kwargs)
+
+    return _warpper
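
Taken together, these helpers implement fire-and-forget telemetry: stats_wrapper intercepts an executor call, _parse_args reconstructs the full argument-name-to-value mapping (defaults included) via inspect.getfullargspec, _note_one_stat assembles task-specific metadata, and StatsWorker posts it from a background thread so a network failure can never affect the actual inference. A self-contained sketch of the _parse_args mechanics (Greeter is a hypothetical class, not part of the diff):

    import inspect

    def parse_args(func, *args, **kwargs):
        # Same logic as _parse_args in the diff: merge positionals, kwargs
        # and declared defaults into one name -> value dict, dropping self.
        argspec = inspect.getfullargspec(func)
        keys = argspec[0]
        if keys[0] == 'self':
            keys = keys[1:]
        defaults = argspec[3] or ()
        values = [None] * (len(keys) - len(defaults)) + list(defaults)
        params = dict(zip(keys, values))
        for idx, v in enumerate(args):
            params[keys[idx]] = v
        params.update(kwargs)
        return params

    class Greeter:  # hypothetical executor-like class
        def __call__(self, text, lang='zh', sample_rate=16000):
            pass

    print(parse_args(Greeter.__call__, 'hello', sample_rate=8000))
    # {'text': 'hello', 'lang': 'zh', 'sample_rate': 8000}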

@@ -238,7 +238,9 @@ class U2Trainer(Trainer):
                 preprocess_conf=config.preprocess_config,
                 n_iter_processes=config.num_workers,
                 subsampling_factor=1,
-                num_encs=1)
+                num_encs=1,
+                dist_sampler=False,
+                shortest_first=False)

             self.valid_loader = BatchDataLoader(
                 json_file=config.dev_manifest,
@@ -257,7 +259,9 @@ class U2Trainer(Trainer):
                 preprocess_conf=config.preprocess_config,
                 n_iter_processes=config.num_workers,
                 subsampling_factor=1,
-                num_encs=1)
+                num_encs=1,
+                dist_sampler=False,
+                shortest_first=False)
             logger.info("Setup train/valid Dataloader!")
         else:
             decode_batch_size = config.get('decode', dict()).get(

@@ -78,7 +78,8 @@ class BatchDataLoader():
                  load_aux_input: bool=False,
                  load_aux_output: bool=False,
                  num_encs: int=1,
-                 dist_sampler: bool=False):
+                 dist_sampler: bool=False,
+                 shortest_first: bool=False):
         self.json_file = json_file
         self.train_mode = train_mode
         self.use_sortagrad = sortagrad == -1 or sortagrad > 0
@@ -97,6 +98,7 @@ class BatchDataLoader():
         self.load_aux_input = load_aux_input
         self.load_aux_output = load_aux_output
         self.dist_sampler = dist_sampler
+        self.shortest_first = shortest_first

         # read json data
         with jsonlines.open(json_file, 'r') as reader:
@@ -113,7 +115,7 @@ class BatchDataLoader():
             maxlen_out,
             minibatches,  # for debug
             min_batch_size=mini_batch_size,
-            shortest_first=self.use_sortagrad,
+            shortest_first=self.shortest_first or self.use_sortagrad,
             count=batch_count,
             batch_bins=batch_bins,
             batch_frames_in=batch_frames_in,
@@ -149,13 +151,13 @@ class BatchDataLoader():
                 self.reader)

         if self.dist_sampler:
-            self.sampler = DistributedBatchSampler(
+            self.batch_sampler = DistributedBatchSampler(
                 dataset=self.dataset,
                 batch_size=1,
                 shuffle=not self.use_sortagrad if self.train_mode else False,
                 drop_last=False, )
         else:
-            self.sampler = BatchSampler(
+            self.batch_sampler = BatchSampler(
                 dataset=self.dataset,
                 batch_size=1,
                 shuffle=not self.use_sortagrad if self.train_mode else False,
@@ -163,7 +165,7 @@ class BatchDataLoader():

         self.dataloader = DataLoader(
             dataset=self.dataset,
-            batch_sampler=self.sampler,
+            batch_sampler=self.batch_sampler,
             collate_fn=batch_collate,
             num_workers=self.n_iter_processes, )
@@ -194,5 +196,6 @@ class BatchDataLoader():
         echo += f"load_aux_input: {self.load_aux_input}, "
         echo += f"load_aux_output: {self.load_aux_output}, "
         echo += f"dist_sampler: {self.dist_sampler}, "
+        echo += f"shortest_first: {self.shortest_first}, "
         echo += f"file: {self.json_file}"
         return echo
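
Net effect of these changes: shortest_first=True now forces shortest-utterance-first batch ordering even when sortagrad is disabled (the loader ORs it with use_sortagrad), and the sampler attribute is renamed to batch_sampler to match the keyword that paddle.io.DataLoader consumes. A hedged sketch of a call site (values are illustrative; constructor arguments not shown are assumed to keep their defaults):

    loader = BatchDataLoader(
        json_file='data/manifest.train',  # illustrative manifest path
        train_mode=True,
        sortagrad=0,                      # sortagrad disabled ...
        dist_sampler=True,                # shard batches across ranks
        shortest_first=True)              # ... but still sort shortest-first
    batch_sampler = loader.batch_sampler  # renamed from loader.sampler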

@@ -39,10 +39,6 @@ except ImportError:
 except Exception as e:
     logger.info("paddlespeech_ctcdecoders not installed!")

-#try:
-#except Exception as e:
-#    logger.info("ctcdecoder not installed!")
-
 __all__ = ['CTCDecoder']

@@ -67,9 +67,10 @@ class WarmupLR(LRScheduler):
         super().__init__(learning_rate, last_epoch, verbose)

     def __repr__(self):
-        return f"{self.__class__.__name__}(warmup_steps={self.warmup_steps})"
+        return f"{self.__class__.__name__}(warmup_steps={self.warmup_steps}, lr={self.base_lr}, last_epoch={self.last_epoch})"

     def get_lr(self):
+        # self.last_epoch starts from zero
         step_num = self.last_epoch + 1
         return self.base_lr * self.warmup_steps**0.5 * min(
             step_num**-0.5, step_num * self.warmup_steps**-1.5)
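
For context, get_lr implements the Noam schedule from "Attention Is All You Need": the rate ramps linearly for warmup_steps steps, peaks at exactly base_lr when step_num == warmup_steps (both min() operands equal warmup_steps**-0.5 there), then decays as the inverse square root of the step. A quick numeric check, assuming base_lr=1.0 and warmup_steps=4:

    base_lr, warmup_steps = 1.0, 4

    def lr(step_num):
        return base_lr * warmup_steps**0.5 * min(step_num**-0.5,
                                                 step_num * warmup_steps**-1.5)

    for step in (1, 2, 4, 16):
        print(step, round(lr(step), 3))
    # 1 0.25   (ramp: step * warmup_steps**-1.5 is the smaller term)
    # 2 0.5
    # 4 1.0    (peak: both terms equal warmup_steps**-0.5)
    # 16 0.5   (decay: step**-0.5 is the smaller term)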

@@ -222,7 +222,7 @@ class Trainer():
         batch_sampler = self.train_loader.batch_sampler
         if isinstance(batch_sampler, paddle.io.DistributedBatchSampler):
             logger.debug(
-                f"train_loader.batch_sample set epoch: {self.epoch}")
+                f"train_loader.batch_sample.set_epoch: {self.epoch}")
             batch_sampler.set_epoch(self.epoch)

     def before_train(self):
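
This log-message fix sits in the hook that reseeds the distributed sampler at the start of each epoch; without set_epoch, every epoch would replay the same shuffle order on every rank. A minimal sketch of the pattern (dataset and sizes are illustrative, not from the diff):

    import paddle

    dataset = paddle.io.TensorDataset(
        [paddle.arange(8, dtype='float32').unsqueeze(1)])
    sampler = paddle.io.DistributedBatchSampler(
        dataset, batch_size=2, shuffle=True)
    for epoch in range(3):
        sampler.set_epoch(epoch)  # distinct, rank-consistent order per epoch
        for batch_indices in sampler:
            pass  # consume index batches for this epoch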
