fix the bug of 'import path error' for ds2

pull/791/head
huangyuxin 3 years ago
parent 86d08f994b
commit 8d5062702b

@ -83,8 +83,8 @@
"from deepspeech.frontend.utility import read_manifest\n",
"from deepspeech.utils.utility import add_arguments, print_arguments\n",
"\n",
"from deepspeech.models.deepspeech2 import DeepSpeech2Model\n",
"from deepspeech.models.deepspeech2 import DeepSpeech2InferModel\n",
"from deepspeech.models.ds2 import DeepSpeech2Model\n",
"from deepspeech.models.ds2 import DeepSpeech2InferModel\n",
"from deepspeech.io.dataset import ManifestDataset\n",
"\n",
"\n",
@ -669,4 +669,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}

@ -23,7 +23,7 @@ from paddle.io import DataLoader
from deepspeech.exps.deepspeech2.config import get_cfg_defaults
from deepspeech.io.collator import SpeechCollator
from deepspeech.io.dataset import ManifestDataset
from deepspeech.models.deepspeech2 import DeepSpeech2Model
from deepspeech.models.ds2 import DeepSpeech2Model
from deepspeech.training.cli import default_argument_parser
from deepspeech.utils.socket_server import AsrRequestHandler
from deepspeech.utils.socket_server import AsrTCPServer

@ -21,7 +21,7 @@ from paddle.io import DataLoader
from deepspeech.exps.deepspeech2.config import get_cfg_defaults
from deepspeech.io.collator import SpeechCollator
from deepspeech.io.dataset import ManifestDataset
from deepspeech.models.deepspeech2 import DeepSpeech2Model
from deepspeech.models.ds2 import DeepSpeech2Model
from deepspeech.training.cli import default_argument_parser
from deepspeech.utils.socket_server import AsrRequestHandler
from deepspeech.utils.socket_server import AsrTCPServer

@ -21,7 +21,7 @@ from paddle.io import DataLoader
from deepspeech.exps.deepspeech2.config import get_cfg_defaults
from deepspeech.io.collator import SpeechCollator
from deepspeech.io.dataset import ManifestDataset
from deepspeech.models.deepspeech2 import DeepSpeech2Model
from deepspeech.models.ds2 import DeepSpeech2Model
from deepspeech.training.cli import default_argument_parser
from deepspeech.utils import error_rate
from deepspeech.utils.utility import add_arguments

@ -12,9 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import nn
from deepspeech.modules.embedding import PositionalEncoding
from deepspeech.modules.subsampling import Conv2dSubsampling4

@ -26,7 +26,7 @@ from deepspeech.utils.checkpoint import Checkpoint
from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
__all__ = ['DeepSpeech2ModelOnline', 'DeepSpeech2InferModeOnline']
__all__ = ['DeepSpeech2ModelOnline', 'DeepSpeech2InferModelOnline']
class CRNNEncoder(nn.Layer):
@ -68,7 +68,7 @@ class CRNNEncoder(nn.Layer):
rnn_input_size = i_size
else:
rnn_input_size = layernorm_size
if use_gru == True:
if use_gru is True:
self.rnn.append(
nn.GRU(
input_size=rnn_input_size,
@ -113,7 +113,7 @@ class CRNNEncoder(nn.Layer):
if init_state_h_box is not None:
init_state_list = None
if self.use_gru == True:
if self.use_gru is True:
init_state_h_list = paddle.split(
init_state_h_box, self.num_rnn_layers, axis=0)
init_state_list = init_state_h_list
@ -139,7 +139,7 @@ class CRNNEncoder(nn.Layer):
x = self.fc_layers_list[i](x)
x = F.relu(x)
if self.use_gru == True:
if self.use_gru is True:
final_chunk_state_h_box = paddle.concat(
final_chunk_state_list, axis=0)
final_chunk_state_c_box = init_state_c_box #paddle.zeros_like(final_chunk_state_h_box)

@ -146,7 +146,7 @@ class TestDeepSpeech2ModelOnline(unittest.TestCase):
self.assertEqual(paddle.allclose(eouts_by_chk, eouts), True)
self.assertEqual(
paddle.allclose(final_state_h_box, final_state_h_box_chk), True)
if use_gru == False:
if use_gru is False:
self.assertEqual(
paddle.allclose(final_state_c_box, final_state_c_box_chk), True)
@ -177,7 +177,7 @@ class TestDeepSpeech2ModelOnline(unittest.TestCase):
self.assertEqual(paddle.allclose(eouts_by_chk, eouts), True)
self.assertEqual(
paddle.allclose(final_state_h_box, final_state_h_box_chk), True)
if use_gru == False:
if use_gru is False:
self.assertEqual(
paddle.allclose(final_state_c_box, final_state_c_box_chk), True)

Loading…
Cancel
Save