format,test=doc

pull/1507/head
Hui Zhang 4 years ago
parent 54341c88a6
commit 75098698d8

@@ -51,7 +51,7 @@ def _batch_shuffle(indices, batch_size, epoch, clipped=False):
     """
     rng = np.random.RandomState(epoch)
     shift_len = rng.randint(0, batch_size - 1)
-    batch_indices = list(zip(*[iter(indices[shift_len:])] * batch_size))
+    batch_indices = list(zip(* [iter(indices[shift_len:])] * batch_size))
     rng.shuffle(batch_indices)
     batch_indices = [item for batch in batch_indices for item in batch]
     assert clipped is False
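For readers skimming this hunk: the reformatted line is the core of _batch_shuffle, which groups consecutive indices into fixed-size batches and shuffles whole batches. Below is a minimal standalone sketch of that grouping idiom, assuming plain Python lists; the values are illustrative and not taken from the repository.

import numpy as np

indices = list(range(10))                      # illustrative sample indices
batch_size = 3
epoch = 0

rng = np.random.RandomState(epoch)             # epoch-seeded, so reproducible
shift_len = rng.randint(0, batch_size - 1)     # random head offset
# One iterator repeated batch_size times: zip pulls consecutive items into
# fixed-size tuples and silently drops any trailing remainder.
batches = list(zip(*[iter(indices[shift_len:])] * batch_size))
rng.shuffle(batches)                           # shuffle whole batches, not items
flat = [item for batch in batches for item in batch]
print(batches)
print(flat)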

@@ -33,8 +33,6 @@ from paddlespeech.s2t.modules.decoder import TransformerDecoder
 from paddlespeech.s2t.modules.encoder import ConformerEncoder
 from paddlespeech.s2t.modules.encoder import TransformerEncoder
 from paddlespeech.s2t.modules.loss import LabelSmoothingLoss
-from paddlespeech.s2t.modules.mask import mask_finished_preds
-from paddlespeech.s2t.modules.mask import mask_finished_scores
 from paddlespeech.s2t.modules.mask import subsequent_mask
 from paddlespeech.s2t.utils import checkpoint
 from paddlespeech.s2t.utils import layer_tools
@@ -291,7 +289,7 @@ class U2STBaseModel(nn.Layer):
         device = speech.place
         # Let's assume B = batch_size and N = beam_size
         # 1. Encoder and init hypothesis
         encoder_out, encoder_mask = self._forward_encoder(
             speech, speech_lengths, decoding_chunk_size,
             num_decoding_left_chunks,

@@ -36,4 +36,4 @@ def repeat(N, fn):
     Returns:
         MultiSequential: Repeated model instance.
     """
-    return MultiSequential(*[fn(n) for n in range(N)])
+    return MultiSequential(* [fn(n) for n in range(N)])
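The one-character change above sits inside repeat(N, fn), which builds N copies of a layer and unpacks them into a MultiSequential container. A rough plain-Python illustration of the pattern follows; the names are stand-ins, not the paddlespeech implementation.

def make_block(layer_id):
    # fn receives the layer index, so each copy can be parameterized.
    return "block_%d" % layer_id

def repeat(N, fn):
    # Build N instances and unpack them into a sequential container;
    # a plain tuple stands in for MultiSequential in this sketch.
    return tuple(fn(n) for n in range(N))

print(repeat(3, make_block))   # ('block_0', 'block_1', 'block_2')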

@@ -11,16 +11,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
+import pickle
 import unittest
 import numpy as np
 import paddle
-import pickle
-import os
 from paddle import inference
-from paddlespeech.s2t.models.ds2_online import DeepSpeech2ModelOnline
 from paddlespeech.s2t.models.ds2_online import DeepSpeech2InferModelOnline
+from paddlespeech.s2t.models.ds2_online import DeepSpeech2ModelOnline

 class TestDeepSpeech2ModelOnline(unittest.TestCase):
     def setUp(self):
@@ -185,15 +186,12 @@ class TestDeepSpeech2ModelOnline(unittest.TestCase):
             paddle.allclose(final_state_c_box, final_state_c_box_chk), True)

 class TestDeepSpeech2StaticModelOnline(unittest.TestCase):
     def setUp(self):
         export_prefix = "exp/deepspeech2_online/checkpoints/test_export"
         if not os.path.exists(os.path.dirname(export_prefix)):
             os.makedirs(os.path.dirname(export_prefix), mode=0o755)
         infer_model = DeepSpeech2InferModelOnline(
             feat_size=161,
             dict_size=4233,
             num_conv_layers=2,
@@ -207,27 +205,25 @@ class TestDeepSpeech2StaticModelOnline(unittest.TestCase):
         with open("test_data/static_ds2online_inputs.pickle", "rb") as f:
             self.data_dict = pickle.load(f)
         self.setup_model(export_prefix)

     def setup_model(self, export_prefix):
-        deepspeech_config = inference.Config(
-            export_prefix + ".pdmodel",
-            export_prefix + ".pdiparams")
-        if ('CUDA_VISIBLE_DEVICES' in os.environ.keys() and os.environ['CUDA_VISIBLE_DEVICES'].strip() != ''):
+        deepspeech_config = inference.Config(export_prefix + ".pdmodel",
+                                             export_prefix + ".pdiparams")
+        if ('CUDA_VISIBLE_DEVICES' in os.environ.keys() and
+                os.environ['CUDA_VISIBLE_DEVICES'].strip() != ''):
             deepspeech_config.enable_use_gpu(100, 0)
             deepspeech_config.enable_memory_optim()
         deepspeech_predictor = inference.create_predictor(deepspeech_config)
         self.predictor = deepspeech_predictor

     def test_unit(self):
         input_names = self.predictor.get_input_names()
         audio_handle = self.predictor.get_input_handle(input_names[0])
         audio_len_handle = self.predictor.get_input_handle(input_names[1])
         h_box_handle = self.predictor.get_input_handle(input_names[2])
         c_box_handle = self.predictor.get_input_handle(input_names[3])

         x_chunk = self.data_dict["audio_chunk"]
         x_chunk_lens = self.data_dict["audio_chunk_lens"]
@@ -246,13 +242,9 @@ class TestDeepSpeech2StaticModelOnline(unittest.TestCase):
         c_box_handle.reshape(chunk_state_c_box.shape)
         c_box_handle.copy_from_cpu(chunk_state_c_box)

         output_names = self.predictor.get_output_names()
-        output_handle = self.predictor.get_output_handle(
-            output_names[0])
-        output_lens_handle = self.predictor.get_output_handle(
-            output_names[1])
+        output_handle = self.predictor.get_output_handle(output_names[0])
+        output_lens_handle = self.predictor.get_output_handle(output_names[1])
         output_state_h_handle = self.predictor.get_output_handle(
             output_names[2])
         output_state_c_handle = self.predictor.get_output_handle(
@@ -264,7 +256,7 @@ class TestDeepSpeech2StaticModelOnline(unittest.TestCase):
         chunk_state_h_box = output_state_h_handle.copy_to_cpu()
         chunk_state_c_box = output_state_c_handle.copy_to_cpu()
         return True

 if __name__ == '__main__':
     unittest.main()
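Taken together, the test hunks above exercise the standard Paddle Inference flow: build a Config from the exported .pdmodel/.pdiparams pair, create a predictor, copy NumPy inputs into named handles, run, and copy outputs back. A condensed sketch of that flow, with illustrative paths and shapes (the exported files must already exist, and predictor.run() itself is not shown in the hunks above):

import os
import numpy as np
from paddle import inference

export_prefix = "exp/deepspeech2_online/checkpoints/test_export"
config = inference.Config(export_prefix + ".pdmodel",
                          export_prefix + ".pdiparams")
if os.environ.get('CUDA_VISIBLE_DEVICES', '').strip() != '':
    config.enable_use_gpu(100, 0)   # 100 MB initial GPU memory pool, device 0
    config.enable_memory_optim()
predictor = inference.create_predictor(config)

# Feed one audio chunk: look up input handles by name, then copy data in.
x_chunk = np.zeros((1, 51, 161), dtype='float32')   # illustrative chunk shape
input_names = predictor.get_input_names()
audio_handle = predictor.get_input_handle(input_names[0])
audio_handle.reshape(x_chunk.shape)
audio_handle.copy_from_cpu(x_chunk)
# ... the remaining inputs (chunk lengths, h/c state boxes) are set the same way.

predictor.run()

output_names = predictor.get_output_names()
probs_chunk = predictor.get_output_handle(output_names[0]).copy_to_cpu()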
