|
|
|
@ -11,7 +11,6 @@
|
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
import argparse
|
|
|
|
|
import re
|
|
|
|
|
|
|
|
|
|
import paddle
|
|
|
|
@ -24,6 +23,8 @@ from paddlespeech.text.models.ernie_linear import ErnieLinear
|
|
|
|
|
# Registry of classifier classes keyed by their name string.
# NOTE(review): presumably looked up by a model name from the config — confirm
# against the constructor, which is not fully visible here.
DefinedClassifier = dict(ErnieLinear=ErnieLinear)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Rhy_predictor():
|
|
|
|
|
def __init__(self, model_path, config_path, punc_path):
|
|
|
|
|
with open(config_path) as f:
|
|
|
|
@ -57,7 +58,7 @@ class Rhy_predictor():
|
|
|
|
|
_inputs['seg_ids'] = tokenized_input['token_type_ids']
|
|
|
|
|
_inputs['seq_len'] = tokenized_input['seq_len']
|
|
|
|
|
return _inputs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_prediction(self, raw_text):
|
|
|
|
|
_inputs = self.preprocess(raw_text, self.tokenizer)
|
|
|
|
|
seq_len = _inputs['seq_len']
|
|
|
|
@ -76,19 +77,19 @@ class Rhy_predictor():
|
|
|
|
|
if l != 0: # Non punc.
|
|
|
|
|
text += self.punc_list[l]
|
|
|
|
|
return text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def make_rhy_dict(self):
    """Build ``self.rhy_dict`` from ``self.punc_list``.

    Every entry of ``self.punc_list`` except the first is mapped to the
    string ``'sp<k>'``, where ``k`` is the entry's index in ``punc_list``
    (so the second entry becomes ``'sp1'``, the third ``'sp2'``, ...).
    Index 0 is skipped; ``get_prediction`` likewise treats label 0 as
    "non punc".

    Returns:
        None. The mapping is stored on the instance as ``self.rhy_dict``;
        if an entry occurs more than once, the last occurrence wins.
    """
    # start=1 keeps the tag number equal to the entry's index in punc_list.
    self.rhy_dict = {
        punc: 'sp' + str(idx)
        for idx, punc in enumerate(self.punc_list[1:], start=1)
    }
|
|
|
|
|
|
|
|
|
|
def pinyin_align(self, pinyins, rhy_pre):
    """Interleave pinyin entries with the tags stored in ``self.rhy_dict``.

    ``rhy_pre`` is walked item by item. An item that is a key of
    ``self.rhy_dict`` is replaced by its mapped value; any other item
    consumes the next unused entry of ``pinyins`` instead.

    Args:
        pinyins: Indexable sequence supplying one entry for each item of
            ``rhy_pre`` that is not a key of ``self.rhy_dict``.
        rhy_pre: Iterable of items to align (for example the characters
            of a string).

    Returns:
        A new list with one element per item of ``rhy_pre``.

    Raises:
        IndexError: If ``pinyins`` has fewer entries than ``rhy_pre`` has
            items that are not keys of ``self.rhy_dict``.
    """
    final_py = []
    # Position of the next unconsumed pinyin; advances only when an item
    # of rhy_pre is not a key of rhy_dict.
    next_py = 0
    for item in rhy_pre:
        if item in self.rhy_dict:
            final_py.append(self.rhy_dict[item])
        else:
            final_py.append(pinyins[next_py])
            next_py += 1
    return final_py
|
|
|
|
|