[Hackathon 7th] 修复 `tal_cs` 测试中 0D tensor to 1D (#3913)

* [Fix] 0D tensor to 1D

* [Update] feat dim
pull/3894/merge
megemini 4 weeks ago committed by GitHub
parent a397ebe207
commit 5e8c727fd6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -75,7 +75,7 @@ class DeepSpeech2Tester_hub():
feat = self.preprocessing(audio, **self.preprocess_args) feat = self.preprocessing(audio, **self.preprocess_args)
logger.info(f"feat shape: {feat.shape}") logger.info(f"feat shape: {feat.shape}")
audio_len = paddle.to_tensor(feat.shape[0]) audio_len = paddle.to_tensor(feat.shape[0]).unsqueeze(0)
audio = paddle.to_tensor(feat, dtype='float32').unsqueeze(axis=0) audio = paddle.to_tensor(feat, dtype='float32').unsqueeze(axis=0)
result_transcripts = self.compute_result_transcripts( result_transcripts = self.compute_result_transcripts(

@ -75,7 +75,7 @@ class U2Infer():
feat = self.preprocessing(audio, **self.preprocess_args) feat = self.preprocessing(audio, **self.preprocess_args)
logger.info(f"feat shape: {feat.shape}") logger.info(f"feat shape: {feat.shape}")
ilen = paddle.to_tensor(feat.shape[0]) ilen = paddle.to_tensor(feat.shape[0]).unsqueeze(0)
xs = paddle.to_tensor(feat, dtype='float32').unsqueeze(0) xs = paddle.to_tensor(feat, dtype='float32').unsqueeze(0)
decode_config = self.config.decode decode_config = self.config.decode
logger.info(f"decode cfg: {decode_config}") logger.info(f"decode cfg: {decode_config}")

@ -78,7 +78,7 @@ class U2Infer():
if self.args.debug: if self.args.debug:
np.savetxt("feat.transform.txt", feat) np.savetxt("feat.transform.txt", feat)
ilen = paddle.to_tensor(feat.shape[0]) ilen = paddle.to_tensor(feat.shape[0]).unsqueeze(0)
xs = paddle.to_tensor(feat, dtype='float32').unsqueeze(0) xs = paddle.to_tensor(feat, dtype='float32').unsqueeze(0)
decode_config = self.config.decode decode_config = self.config.decode
logger.info(f"decode cfg: {decode_config}") logger.info(f"decode cfg: {decode_config}")

Loading…
Cancel
Save