diff --git a/model.py b/model.py index c8766deb..c2e440b3 100644 --- a/model.py +++ b/model.py @@ -120,6 +120,16 @@ class DeepSpeech2Model(object): feeding=feeding_dict) def infer_loss_batch(self, infer_data): + """Model inference. Infer the ctc loss for a batch of speech + utterances. + + :param infer_data: List of utterances to infer, with each utterance a + tuple of audio features and transcription text (empty + string). + :type infer_data: list + :return: List of ctc loss. + :rtype: List of float + """ # define inferer if self._loss_inferer == None: self._loss_inferer = paddle.inference.Inference(