fix the asr online client bug, return None, test=doc

pull/1704/head
xiongxinlei 2 years ago
parent babac27a79
commit 48fa84bee9

@ -317,8 +317,6 @@ class BaseEncoder(nn.Layer):
outputs = [] outputs = []
offset = 0 offset = 0
# Feed forward overlap input step by step # Feed forward overlap input step by step
print(f"context: {context}")
print(f"stride: {stride}")
for cur in range(0, num_frames - context + 1, stride): for cur in range(0, num_frames - context + 1, stride):
end = min(cur + decoding_window, num_frames) end = min(cur + decoding_window, num_frames)
chunk_xs = xs[:, cur:end, :] chunk_xs = xs[:, cur:end, :]

@ -35,3 +35,16 @@
```bash ```bash
paddlespeech_client cls --server_ip 127.0.0.1 --port 8090 --input input.wav paddlespeech_client cls --server_ip 127.0.0.1 --port 8090 --input input.wav
``` ```
## Online ASR Server
### Launch online asr server
```
paddlespeech_server start --config_file conf/ws_conformer_application.yaml
```
### Access online asr server
```
paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8090 --input input_16k.wav
```

@ -35,3 +35,17 @@
```bash ```bash
paddlespeech_client cls --server_ip 127.0.0.1 --port 8090 --input input.wav paddlespeech_client cls --server_ip 127.0.0.1 --port 8090 --input input.wav
``` ```
## 流式ASR
### 启动流式语音识别服务
```
paddlespeech_server start --config_file conf/ws_conformer_application.yaml
```
### 访问流式语音识别服务
```
paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8090 --input zh.wav
```

@ -277,11 +277,12 @@ class ASRClientExecutor(BaseExecutor):
lang=lang, lang=lang,
audio_format=audio_format) audio_format=audio_format)
time_end = time.time() time_end = time.time()
logger.info(res.json()) logger.info(res)
logger.info("Response time %f s." % (time_end - time_start)) logger.info("Response time %f s." % (time_end - time_start))
return True return True
except Exception as e: except Exception as e:
logger.error("Failed to speech recognition.") logger.error("Failed to speech recognition.")
logger.error(e)
return False return False
@stats_wrapper @stats_wrapper
@ -299,9 +300,10 @@ class ASRClientExecutor(BaseExecutor):
logging.info("asr websocket client start") logging.info("asr websocket client start")
handler = ASRAudioHandler(server_ip, port) handler = ASRAudioHandler(server_ip, port)
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
loop.run_until_complete(handler.run(input)) res = loop.run_until_complete(handler.run(input))
logging.info("asr websocket client finished") logging.info("asr websocket client finished")
return res['asr_results']
@cli_client_register( @cli_client_register(
name='paddlespeech_client.cls', description='visit cls service') name='paddlespeech_client.cls', description='visit cls service')

@ -473,7 +473,7 @@ class PaddleASRConnectionHanddler:
ctc_probs = self.model.ctc.log_softmax(ys) # (1, maxlen, vocab_size) ctc_probs = self.model.ctc.log_softmax(ys) # (1, maxlen, vocab_size)
ctc_probs = ctc_probs.squeeze(0) ctc_probs = ctc_probs.squeeze(0)
self.searcher.search(None, ctc_probs, self.cached_feat.place) self.searcher.search(ctc_probs, self.cached_feat.place)
self.hyps = self.searcher.get_one_best_hyps() self.hyps = self.searcher.get_one_best_hyps()
assert self.cached_feat.shape[0] == 1 assert self.cached_feat.shape[0] == 1
@ -823,7 +823,7 @@ class ASRServerExecutor(ASRExecutor):
ctc_probs = self.model.ctc.log_softmax( ctc_probs = self.model.ctc.log_softmax(
encoder_out) # (1, maxlen, vocab_size) encoder_out) # (1, maxlen, vocab_size)
ctc_probs = ctc_probs.squeeze(0) ctc_probs = ctc_probs.squeeze(0)
self.searcher.search(xs, ctc_probs, xs.place) self.searcher.search(ctc_probs, xs.place)
# update the one best result # update the one best result
self.hyps = self.searcher.get_one_best_hyps() self.hyps = self.searcher.get_one_best_hyps()

@ -24,19 +24,18 @@ class CTCPrefixBeamSearch:
"""Implement the ctc prefix beam search """Implement the ctc prefix beam search
Args: Args:
config (_type_): _description_ config (yacs.config.CfgNode): _description_
""" """
self.config = config self.config = config
self.reset() self.reset()
def search(self, xs, ctc_probs, device, blank_id=0): def search(self, ctc_probs, device, blank_id=0):
"""ctc prefix beam search method decode a chunk feature """ctc prefix beam search method decode a chunk feature
Args: Args:
xs (paddle.Tensor): feature data xs (paddle.Tensor): feature data
ctc_probs (paddle.Tensor): the ctc probability of all the tokens ctc_probs (paddle.Tensor): the ctc probability of all the tokens
encoder_out (paddle.Tensor): _description_ device (paddle.fluid.core_avx.Place): the feature host device, such as CUDAPlace(0).
encoder_mask (_type_): _description_
blank_id (int, optional): the blank id in the vocab. Defaults to 0. blank_id (int, optional): the blank id in the vocab. Defaults to 0.
Returns: Returns:
@ -45,7 +44,6 @@ class CTCPrefixBeamSearch:
# decode # decode
logger.info("start to ctc prefix search") logger.info("start to ctc prefix search")
# device = xs.place
batch_size = 1 batch_size = 1
beam_size = self.config.beam_size beam_size = self.config.beam_size
maxlen = ctc_probs.shape[0] maxlen = ctc_probs.shape[0]

@ -34,10 +34,9 @@ class ASRAudioHandler:
def read_wave(self, wavfile_path: str): def read_wave(self, wavfile_path: str):
samples, sample_rate = soundfile.read(wavfile_path, dtype='int16') samples, sample_rate = soundfile.read(wavfile_path, dtype='int16')
x_len = len(samples) x_len = len(samples)
# chunk_stride = 40 * 16 #40ms, sample_rate = 16kHz
chunk_size = 80 * 16 #80ms, sample_rate = 16kHz
if x_len % chunk_size != 0: chunk_size = 85 * 16 #80ms, sample_rate = 16kHz
if x_len % chunk_size!= 0:
padding_len_x = chunk_size - x_len % chunk_size padding_len_x = chunk_size - x_len % chunk_size
else: else:
padding_len_x = 0 padding_len_x = 0
@ -48,7 +47,6 @@ class ASRAudioHandler:
assert (x_len + padding_len_x) % chunk_size == 0 assert (x_len + padding_len_x) % chunk_size == 0
num_chunk = (x_len + padding_len_x) / chunk_size num_chunk = (x_len + padding_len_x) / chunk_size
num_chunk = int(num_chunk) num_chunk = int(num_chunk)
for i in range(0, num_chunk): for i in range(0, num_chunk):
start = i * chunk_size start = i * chunk_size
end = start + chunk_size end = start + chunk_size
@ -82,7 +80,6 @@ class ASRAudioHandler:
msg = json.loads(msg) msg = json.loads(msg)
logging.info("receive msg={}".format(msg)) logging.info("receive msg={}".format(msg))
result = msg
# finished # finished
audio_info = json.dumps( audio_info = json.dumps(
{ {
@ -98,8 +95,8 @@ class ASRAudioHandler:
# decode the bytes to str # decode the bytes to str
msg = json.loads(msg) msg = json.loads(msg)
logging.info("receive msg={}".format(msg)) logging.info("final receive msg={}".format(msg))
result = msg
return result return result

Loading…
Cancel
Save