replace logger.info with logger.debug in cli, change default log level to INFO

pull/2111/head
TianYuan 2 years ago
parent cf846f9ebc
commit bc93bffbb4

@ -133,11 +133,11 @@ class ASRExecutor(BaseExecutor):
"""
Init model and other resources from a specific path.
"""
logger.info("start to init the model")
logger.debug("start to init the model")
# default max_len: unit:second
self.max_len = 50
if hasattr(self, 'model'):
logger.info('Model had been initialized.')
logger.debug('Model had been initialized.')
return
if cfg_path is None or ckpt_path is None:
@ -151,15 +151,15 @@ class ASRExecutor(BaseExecutor):
self.ckpt_path = os.path.join(
self.res_path,
self.task_resource.res_dict['ckpt_path'] + ".pdparams")
logger.info(self.res_path)
logger.debug(self.res_path)
else:
self.cfg_path = os.path.abspath(cfg_path)
self.ckpt_path = os.path.abspath(ckpt_path + ".pdparams")
self.res_path = os.path.dirname(
os.path.dirname(os.path.abspath(self.cfg_path)))
logger.info(self.cfg_path)
logger.info(self.ckpt_path)
logger.debug(self.cfg_path)
logger.debug(self.ckpt_path)
#Init body.
self.config = CfgNode(new_allowed=True)
@ -216,7 +216,7 @@ class ASRExecutor(BaseExecutor):
max_len = self.config.encoder_conf.max_len
self.max_len = frame_shift_ms * max_len * subsample_rate
logger.info(
logger.debug(
f"The asr server limit max duration len: {self.max_len}")
def preprocess(self, model_type: str, input: Union[str, os.PathLike]):
@ -227,15 +227,15 @@ class ASRExecutor(BaseExecutor):
audio_file = input
if isinstance(audio_file, (str, os.PathLike)):
logger.info("Preprocess audio_file:" + audio_file)
logger.debug("Preprocess audio_file:" + audio_file)
# Get the object for feature extraction
if "deepspeech2" in model_type or "conformer" in model_type or "transformer" in model_type:
logger.info("get the preprocess conf")
logger.debug("get the preprocess conf")
preprocess_conf = self.config.preprocess_config
preprocess_args = {"train": False}
preprocessing = Transformation(preprocess_conf)
logger.info("read the audio file")
logger.debug("read the audio file")
audio, audio_sample_rate = soundfile.read(
audio_file, dtype="int16", always_2d=True)
if self.change_format:
@ -255,7 +255,7 @@ class ASRExecutor(BaseExecutor):
else:
audio = audio[:, 0]
logger.info(f"audio shape: {audio.shape}")
logger.debug(f"audio shape: {audio.shape}")
# fbank
audio = preprocessing(audio, **preprocess_args)
@ -264,19 +264,19 @@ class ASRExecutor(BaseExecutor):
self._inputs["audio"] = audio
self._inputs["audio_len"] = audio_len
logger.info(f"audio feat shape: {audio.shape}")
logger.debug(f"audio feat shape: {audio.shape}")
else:
raise Exception("wrong type")
logger.info("audio feat process success")
logger.debug("audio feat process success")
@paddle.no_grad()
def infer(self, model_type: str):
"""
Model inference and result stored in self.output.
"""
logger.info("start to infer the model to get the output")
logger.debug("start to infer the model to get the output")
cfg = self.config.decode
audio = self._inputs["audio"]
audio_len = self._inputs["audio_len"]
@ -293,7 +293,7 @@ class ASRExecutor(BaseExecutor):
self._outputs["result"] = result_transcripts[0]
elif "conformer" in model_type or "transformer" in model_type:
logger.info(
logger.debug(
f"we will use the transformer like model : {model_type}")
try:
result_transcripts = self.model.decode(
@ -352,7 +352,7 @@ class ASRExecutor(BaseExecutor):
logger.error("Please input the right audio file path")
return False
logger.info("checking the audio file format......")
logger.debug("checking the audio file format......")
try:
audio, audio_sample_rate = soundfile.read(
audio_file, dtype="int16", always_2d=True)
@ -374,7 +374,7 @@ class ASRExecutor(BaseExecutor):
sox input_audio.xx --rate 8k --bits 16 --channels 1 output_audio.wav \n \
")
return False
logger.info("The sample rate is %d" % audio_sample_rate)
logger.debug("The sample rate is %d" % audio_sample_rate)
if audio_sample_rate != self.sample_rate:
logger.warning("The sample rate of the input file is not {}.\n \
The program will resample the wav file to {}.\n \
@ -383,28 +383,28 @@ class ASRExecutor(BaseExecutor):
".format(self.sample_rate, self.sample_rate))
if force_yes is False:
while (True):
logger.info(
logger.debug(
"Whether to change the sample rate and the channel. Y: change the sample. N: exit the prgream."
)
content = input("Input(Y/N):")
if content.strip() == "Y" or content.strip(
) == "y" or content.strip() == "yes" or content.strip(
) == "Yes":
logger.info(
logger.debug(
"change the sampele rate, channel to 16k and 1 channel"
)
break
elif content.strip() == "N" or content.strip(
) == "n" or content.strip() == "no" or content.strip(
) == "No":
logger.info("Exit the program")
logger.debug("Exit the program")
return False
else:
logger.warning("Not regular input, please input again")
self.change_format = True
else:
logger.info("The audio file format is right")
logger.debug("The audio file format is right")
self.change_format = False
return True

@ -92,7 +92,7 @@ class CLSExecutor(BaseExecutor):
Init model and other resources from a specific path.
"""
if hasattr(self, 'model'):
logger.info('Model had been initialized.')
logger.debug('Model had been initialized.')
return
if label_file is None or ckpt_path is None:
@ -135,14 +135,14 @@ class CLSExecutor(BaseExecutor):
Input content can be a text(tts), a file(asr, cls) or a streaming(not supported yet).
"""
feat_conf = self._conf['feature']
logger.info(feat_conf)
logger.debug(feat_conf)
waveform, _ = load(
file=audio_file,
sr=feat_conf['sample_rate'],
mono=True,
dtype='float32')
if isinstance(audio_file, (str, os.PathLike)):
logger.info("Preprocessing audio_file:" + audio_file)
logger.debug("Preprocessing audio_file:" + audio_file)
# Feature extraction
feature_extractor = LogMelSpectrogram(

@ -61,7 +61,7 @@ def _get_unique_endpoints(trainer_endpoints):
continue
ips.add(ip)
unique_endpoints.add(endpoint)
logger.info("unique_endpoints {}".format(unique_endpoints))
logger.debug("unique_endpoints {}".format(unique_endpoints))
return unique_endpoints
@ -96,7 +96,7 @@ def get_path_from_url(url,
# data, and the same ip will only download data once.
unique_endpoints = _get_unique_endpoints(ParallelEnv().trainer_endpoints[:])
if osp.exists(fullpath) and check_exist and _md5check(fullpath, md5sum):
logger.info("Found {}".format(fullpath))
logger.debug("Found {}".format(fullpath))
else:
if ParallelEnv().current_endpoint in unique_endpoints:
fullpath = _download(url, root_dir, md5sum, method=method)
@ -118,7 +118,7 @@ def _get_download(url, fullname):
try:
req = requests.get(url, stream=True)
except Exception as e: # requests.exceptions.ConnectionError
logger.info("Downloading {} from {} failed with exception {}".format(
logger.debug("Downloading {} from {} failed with exception {}".format(
fname, url, str(e)))
return False
@ -190,7 +190,7 @@ def _download(url, path, md5sum=None, method='get'):
fullname = osp.join(path, fname)
retry_cnt = 0
logger.info("Downloading {} from {}".format(fname, url))
logger.debug("Downloading {} from {}".format(fname, url))
while not (osp.exists(fullname) and _md5check(fullname, md5sum)):
if retry_cnt < DOWNLOAD_RETRY_LIMIT:
retry_cnt += 1
@ -209,7 +209,7 @@ def _md5check(fullname, md5sum=None):
if md5sum is None:
return True
logger.info("File {} md5 checking...".format(fullname))
logger.debug("File {} md5 checking...".format(fullname))
md5 = hashlib.md5()
with open(fullname, 'rb') as f:
for chunk in iter(lambda: f.read(4096), b""):
@ -217,8 +217,8 @@ def _md5check(fullname, md5sum=None):
calc_md5sum = md5.hexdigest()
if calc_md5sum != md5sum:
logger.info("File {} md5 check failed, {}(calc) != "
"{}(base)".format(fullname, calc_md5sum, md5sum))
logger.debug("File {} md5 check failed, {}(calc) != "
"{}(base)".format(fullname, calc_md5sum, md5sum))
return False
return True
@ -227,7 +227,7 @@ def _decompress(fname):
"""
Decompress for zip and tar file
"""
logger.info("Decompressing {}...".format(fname))
logger.debug("Decompressing {}...".format(fname))
# For protecting decompressing interupted,
# decompress to fpath_tmp directory firstly, if decompress

@ -88,7 +88,7 @@ class KWSExecutor(BaseExecutor):
Init model and other resources from a specific path.
"""
if hasattr(self, 'model'):
logger.info('Model had been initialized.')
logger.debug('Model had been initialized.')
return
if ckpt_path is None:
@ -141,7 +141,7 @@ class KWSExecutor(BaseExecutor):
assert os.path.isfile(audio_file)
waveform, _ = load(audio_file)
if isinstance(audio_file, (str, os.PathLike)):
logger.info("Preprocessing audio_file:" + audio_file)
logger.debug("Preprocessing audio_file:" + audio_file)
# Feature extraction
waveform = paddle.to_tensor(waveform).unsqueeze(0)

@ -49,7 +49,7 @@ class Logger(object):
self.handler.setFormatter(self.format)
self.logger.addHandler(self.handler)
self.logger.setLevel(logging.DEBUG)
self.logger.setLevel(logging.INFO)
self.logger.propagate = False
def __call__(self, log_level: str, msg: str):

@ -110,7 +110,7 @@ class STExecutor(BaseExecutor):
"""
decompressed_path = download_and_decompress(self.kaldi_bins, MODEL_HOME)
decompressed_path = os.path.abspath(decompressed_path)
logger.info("Kaldi_bins stored in: {}".format(decompressed_path))
logger.debug("Kaldi_bins stored in: {}".format(decompressed_path))
if "LD_LIBRARY_PATH" in os.environ:
os.environ["LD_LIBRARY_PATH"] += f":{decompressed_path}"
else:
@ -128,7 +128,7 @@ class STExecutor(BaseExecutor):
Init model and other resources from a specific path.
"""
if hasattr(self, 'model'):
logger.info('Model had been initialized.')
logger.debug('Model had been initialized.')
return
if cfg_path is None or ckpt_path is None:
@ -140,8 +140,8 @@ class STExecutor(BaseExecutor):
self.ckpt_path = os.path.join(
self.task_resource.res_dir,
self.task_resource.res_dict['ckpt_path'])
logger.info(self.cfg_path)
logger.info(self.ckpt_path)
logger.debug(self.cfg_path)
logger.debug(self.ckpt_path)
res_path = self.task_resource.res_dir
else:
self.cfg_path = os.path.abspath(cfg_path)
@ -192,7 +192,7 @@ class STExecutor(BaseExecutor):
Input content can be a file(wav).
"""
audio_file = os.path.abspath(wav_file)
logger.info("Preprocess audio_file:" + audio_file)
logger.debug("Preprocess audio_file:" + audio_file)
if "fat_st" in model_type:
cmvn = self.config.cmvn_path

@ -98,7 +98,7 @@ class TextExecutor(BaseExecutor):
Init model and other resources from a specific path.
"""
if hasattr(self, 'model'):
logger.info('Model had been initialized.')
logger.debug('Model had been initialized.')
return
self.task = task

@ -173,7 +173,7 @@ class TTSExecutor(BaseExecutor):
Init model and other resources from a specific path.
"""
if hasattr(self, 'am_inference') and hasattr(self, 'voc_inference'):
logger.info('Models had been initialized.')
logger.debug('Models had been initialized.')
return
# am
@ -200,9 +200,9 @@ class TTSExecutor(BaseExecutor):
# must have phones_dict in acoustic
self.phones_dict = os.path.join(
self.am_res_path, self.task_resource.res_dict['phones_dict'])
logger.info(self.am_res_path)
logger.info(self.am_config)
logger.info(self.am_ckpt)
logger.debug(self.am_res_path)
logger.debug(self.am_config)
logger.debug(self.am_ckpt)
else:
self.am_config = os.path.abspath(am_config)
self.am_ckpt = os.path.abspath(am_ckpt)
@ -248,9 +248,9 @@ class TTSExecutor(BaseExecutor):
self.voc_stat = os.path.join(
self.voc_res_path,
self.task_resource.voc_res_dict['speech_stats'])
logger.info(self.voc_res_path)
logger.info(self.voc_config)
logger.info(self.voc_ckpt)
logger.debug(self.voc_res_path)
logger.debug(self.voc_config)
logger.debug(self.voc_ckpt)
else:
self.voc_config = os.path.abspath(voc_config)
self.voc_ckpt = os.path.abspath(voc_ckpt)

@ -117,7 +117,7 @@ class VectorExecutor(BaseExecutor):
# stage 2: read the input data and store them as a list
task_source = self.get_input_source(parser_args.input)
logger.info(f"task source: {task_source}")
logger.debug(f"task source: {task_source}")
# stage 3: process the audio one by one
# we do action according the task type
@ -127,13 +127,13 @@ class VectorExecutor(BaseExecutor):
try:
# extract the speaker audio embedding
if parser_args.task == "spk":
logger.info("do vector spk task")
logger.debug("do vector spk task")
res = self(input_, model, sample_rate, config, ckpt_path,
device)
task_result[id_] = res
elif parser_args.task == "score":
logger.info("do vector score task")
logger.info(f"input content {input_}")
logger.debug("do vector score task")
logger.debug(f"input content {input_}")
if len(input_.split()) != 2:
logger.error(
f"vector score task input {input_} wav num is not two,"
@ -142,7 +142,7 @@ class VectorExecutor(BaseExecutor):
# get the enroll and test embedding
enroll_audio, test_audio = input_.split()
logger.info(
logger.debug(
f"score task, enroll audio: {enroll_audio}, test audio: {test_audio}"
)
enroll_embedding = self(enroll_audio, model, sample_rate,
@ -158,8 +158,8 @@ class VectorExecutor(BaseExecutor):
has_exceptions = True
task_result[id_] = f'{e.__class__.__name__}: {e}'
logger.info("task result as follows: ")
logger.info(f"{task_result}")
logger.debug("task result as follows: ")
logger.debug(f"{task_result}")
# stage 4: process the all the task results
self.process_task_results(parser_args.input, task_result,
@ -207,7 +207,7 @@ class VectorExecutor(BaseExecutor):
"""
if not hasattr(self, "score_func"):
self.score_func = paddle.nn.CosineSimilarity(axis=0)
logger.info("create the cosine score function ")
logger.debug("create the cosine score function ")
score = self.score_func(
paddle.to_tensor(enroll_embedding),
@ -244,7 +244,7 @@ class VectorExecutor(BaseExecutor):
sys.exit(-1)
# stage 1: set the paddle runtime host device
logger.info(f"device type: {device}")
logger.debug(f"device type: {device}")
paddle.device.set_device(device)
# stage 2: read the specific pretrained model
@ -283,7 +283,7 @@ class VectorExecutor(BaseExecutor):
# stage 0: avoid to init the mode again
self.task = task
if hasattr(self, "model"):
logger.info("Model has been initialized")
logger.debug("Model has been initialized")
return
# stage 1: get the model and config path
@ -294,7 +294,7 @@ class VectorExecutor(BaseExecutor):
sample_rate_str = "16k" if sample_rate == 16000 else "8k"
tag = model_type + "-" + sample_rate_str
self.task_resource.set_task_model(tag, version=None)
logger.info(f"load the pretrained model: {tag}")
logger.debug(f"load the pretrained model: {tag}")
# get the model from the pretrained list
# we download the pretrained model and store it in the res_path
self.res_path = self.task_resource.res_dir
@ -312,19 +312,19 @@ class VectorExecutor(BaseExecutor):
self.res_path = os.path.dirname(
os.path.dirname(os.path.abspath(self.cfg_path)))
logger.info(f"start to read the ckpt from {self.ckpt_path}")
logger.info(f"read the config from {self.cfg_path}")
logger.info(f"get the res path {self.res_path}")
logger.debug(f"start to read the ckpt from {self.ckpt_path}")
logger.debug(f"read the config from {self.cfg_path}")
logger.debug(f"get the res path {self.res_path}")
# stage 2: read and config and init the model body
self.config = CfgNode(new_allowed=True)
self.config.merge_from_file(self.cfg_path)
# stage 3: get the model name to instance the model network with dynamic_import
logger.info("start to dynamic import the model class")
logger.debug("start to dynamic import the model class")
model_name = model_type[:model_type.rindex('_')]
model_class = self.task_resource.get_model_class(model_name)
logger.info(f"model name {model_name}")
logger.debug(f"model name {model_name}")
model_conf = self.config.model
backbone = model_class(**model_conf)
model = SpeakerIdetification(
@ -333,11 +333,11 @@ class VectorExecutor(BaseExecutor):
self.model.eval()
# stage 4: load the model parameters
logger.info("start to set the model parameters to model")
logger.debug("start to set the model parameters to model")
model_dict = paddle.load(self.ckpt_path)
self.model.set_state_dict(model_dict)
logger.info("create the model instance success")
logger.debug("create the model instance success")
@paddle.no_grad()
def infer(self, model_type: str):
@ -349,14 +349,14 @@ class VectorExecutor(BaseExecutor):
# stage 0: get the feat and length from _inputs
feats = self._inputs["feats"]
lengths = self._inputs["lengths"]
logger.info("start to do backbone network model forward")
logger.info(
logger.debug("start to do backbone network model forward")
logger.debug(
f"feats shape:{feats.shape}, lengths shape: {lengths.shape}")
# stage 1: get the audio embedding
# embedding from (1, emb_size, 1) -> (emb_size)
embedding = self.model.backbone(feats, lengths).squeeze().numpy()
logger.info(f"embedding size: {embedding.shape}")
logger.debug(f"embedding size: {embedding.shape}")
# stage 2: put the embedding and dim info to _outputs property
# the embedding type is numpy.array
@ -380,12 +380,13 @@ class VectorExecutor(BaseExecutor):
"""
audio_file = input_file
if isinstance(audio_file, (str, os.PathLike)):
logger.info(f"Preprocess audio file: {audio_file}")
logger.debug(f"Preprocess audio file: {audio_file}")
# stage 1: load the audio sample points
# Note: this process must match the training process
waveform, sr = load_audio(audio_file)
logger.info(f"load the audio sample points, shape is: {waveform.shape}")
logger.debug(
f"load the audio sample points, shape is: {waveform.shape}")
# stage 2: get the audio feat
# Note: Now we only support fbank feature
@ -396,9 +397,9 @@ class VectorExecutor(BaseExecutor):
n_mels=self.config.n_mels,
window_size=self.config.window_size,
hop_length=self.config.hop_size)
logger.info(f"extract the audio feat, shape is: {feat.shape}")
logger.debug(f"extract the audio feat, shape is: {feat.shape}")
except Exception as e:
logger.info(f"feat occurs exception {e}")
logger.debug(f"feat occurs exception {e}")
sys.exit(-1)
feat = paddle.to_tensor(feat).unsqueeze(0)
@ -411,11 +412,11 @@ class VectorExecutor(BaseExecutor):
# stage 4: store the feat and length in the _inputs,
# which will be used in other function
logger.info(f"feats shape: {feat.shape}")
logger.debug(f"feats shape: {feat.shape}")
self._inputs["feats"] = feat
self._inputs["lengths"] = lengths
logger.info("audio extract the feat success")
logger.debug("audio extract the feat success")
def _check(self, audio_file: str, sample_rate: int):
"""Check if the model sample match the audio sample rate
@ -441,7 +442,7 @@ class VectorExecutor(BaseExecutor):
logger.error("Please input the right audio file path")
return False
logger.info("checking the aduio file format......")
logger.debug("checking the aduio file format......")
try:
audio, audio_sample_rate = soundfile.read(
audio_file, dtype="float32", always_2d=True)
@ -458,7 +459,7 @@ class VectorExecutor(BaseExecutor):
")
return False
logger.info(f"The sample rate is {audio_sample_rate}")
logger.debug(f"The sample rate is {audio_sample_rate}")
if audio_sample_rate != self.sample_rate:
logger.error("The sample rate of the input file is not {}.\n \
@ -468,6 +469,6 @@ class VectorExecutor(BaseExecutor):
".format(self.sample_rate, self.sample_rate))
sys.exit(-1)
else:
logger.info("The audio file format is right")
logger.debug("The audio file format is right")
return True

@ -16,7 +16,7 @@ import random
import numpy as np
from PIL import Image
from PIL.Image import BICUBIC
from PIL.Image import Resampling
from paddlespeech.s2t.frontend.augmentor.base import AugmentorBase
from paddlespeech.s2t.utils.log import Log
@ -164,9 +164,9 @@ class SpecAugmentor(AugmentorBase):
window) + 1 # 1 ... t - 1
left = Image.fromarray(x[:center]).resize((x.shape[1], warped),
BICUBIC)
Resampling.BICUBIC)
right = Image.fromarray(x[center:]).resize((x.shape[1], t - warped),
BICUBIC)
Resampling.BICUBIC)
if self.inplace:
x[:warped] = left
x[warped:] = right

@ -226,10 +226,10 @@ class TextFeaturizer():
sos_id = vocab_list.index(SOS) if SOS in vocab_list else -1
space_id = vocab_list.index(SPACE) if SPACE in vocab_list else -1
logger.info(f"BLANK id: {blank_id}")
logger.info(f"UNK id: {unk_id}")
logger.info(f"EOS id: {eos_id}")
logger.info(f"SOS id: {sos_id}")
logger.info(f"SPACE id: {space_id}")
logger.info(f"MASKCTC id: {maskctc_id}")
logger.debug(f"BLANK id: {blank_id}")
logger.debug(f"UNK id: {unk_id}")
logger.debug(f"EOS id: {eos_id}")
logger.debug(f"SOS id: {sos_id}")
logger.debug(f"SPACE id: {space_id}")
logger.debug(f"MASKCTC id: {maskctc_id}")
return token2id, id2token, vocab_list, unk_id, eos_id, blank_id

@ -827,7 +827,7 @@ class U2Model(U2DecodeModel):
# encoder
encoder_type = configs.get('encoder', 'transformer')
logger.info(f"U2 Encoder type: {encoder_type}")
logger.debug(f"U2 Encoder type: {encoder_type}")
if encoder_type == 'transformer':
encoder = TransformerEncoder(
input_dim, global_cmvn=global_cmvn, **configs['encoder_conf'])
@ -894,7 +894,7 @@ class U2Model(U2DecodeModel):
if checkpoint_path:
infos = checkpoint.Checkpoint().load_parameters(
model, checkpoint_path=checkpoint_path)
logger.info(f"checkpoint info: {infos}")
logger.debug(f"checkpoint info: {infos}")
layer_tools.summary(model)
return model

@ -37,9 +37,9 @@ class CTCLoss(nn.Layer):
self.loss = nn.CTCLoss(blank=blank, reduction=reduction)
self.batch_average = batch_average
logger.info(
logger.debug(
f"CTCLoss Loss reduction: {reduction}, div-bs: {batch_average}")
logger.info(f"CTCLoss Grad Norm Type: {grad_norm_type}")
logger.debug(f"CTCLoss Grad Norm Type: {grad_norm_type}")
assert grad_norm_type in ('instance', 'batch', 'frame', None)
self.norm_by_times = False
@ -70,7 +70,8 @@ class CTCLoss(nn.Layer):
param = {}
self._kwargs = {k: v for k, v in kwargs.items() if k in param}
_notin = {k: v for k, v in kwargs.items() if k not in param}
logger.info(f"{self.loss} kwargs:{self._kwargs}, not support: {_notin}")
logger.debug(
f"{self.loss} kwargs:{self._kwargs}, not support: {_notin}")
def forward(self, logits, ys_pad, hlens, ys_lens):
"""Compute CTC loss.

@ -17,7 +17,7 @@ import random
import numpy
from PIL import Image
from PIL.Image import BICUBIC
from PIL.Image import Resampling
from paddlespeech.s2t.transform.functional import FuncTrans
@ -46,9 +46,10 @@ def time_warp(x, max_time_warp=80, inplace=False, mode="PIL"):
warped = random.randrange(center - window, center +
window) + 1 # 1 ... t - 1
left = Image.fromarray(x[:center]).resize((x.shape[1], warped), BICUBIC)
left = Image.fromarray(x[:center]).resize((x.shape[1], warped),
Resampling.BICUBIC)
right = Image.fromarray(x[center:]).resize((x.shape[1], t - warped),
BICUBIC)
Resampling.BICUBIC)
if inplace:
x[:warped] = left
x[warped:] = right

@ -94,7 +94,7 @@ def pad_sequence(sequences: List[paddle.Tensor],
for i, tensor in enumerate(sequences):
length = tensor.shape[0]
# use index notation to prevent duplicate references to the tensor
logger.info(
logger.debug(
f"length {length}, out_tensor {out_tensor.shape}, tensor {tensor.shape}"
)
if batch_first:

Loading…
Cancel
Save