You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
79 lines
2.9 KiB
79 lines
2.9 KiB
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
#include "audio_classification/nnet/panns_interface.h"
|
|
|
|
#include "audio_classification/nnet/panns_nnet.h"
|
|
#include "common/base/config.h"
|
|
|
|
namespace ppspeech {
|
|
|
|
void* ClsCreateInstance(const char* conf_path) {
|
|
Config conf(conf_path);
|
|
// cls init
|
|
ppspeech::ClsNnetConf cls_nnet_conf;
|
|
cls_nnet_conf.wav_normal_ = conf.Read("wav_normal", true);
|
|
cls_nnet_conf.wav_normal_type_ =
|
|
conf.Read("wav_normal_type", std::string("linear"));
|
|
cls_nnet_conf.wav_norm_mul_factor_ = conf.Read("wav_norm_mul_factor", 1.0);
|
|
cls_nnet_conf.model_file_path_ = conf.Read("model_path", std::string(""));
|
|
cls_nnet_conf.param_file_path_ = conf.Read("param_path", std::string(""));
|
|
cls_nnet_conf.dict_file_path_ = conf.Read("dict_path", std::string(""));
|
|
cls_nnet_conf.num_cpu_thread_ = conf.Read("num_cpu_thread", 12);
|
|
cls_nnet_conf.samp_freq = conf.Read("samp_freq", 32000);
|
|
cls_nnet_conf.frame_length_ms = conf.Read("frame_length_ms", 32);
|
|
cls_nnet_conf.frame_shift_ms = conf.Read("frame_shift_ms", 10);
|
|
cls_nnet_conf.num_bins = conf.Read("num_bins", 64);
|
|
cls_nnet_conf.low_freq = conf.Read("low_freq", 50);
|
|
cls_nnet_conf.high_freq = conf.Read("high_freq", 14000);
|
|
cls_nnet_conf.dither = conf.Read("dither", 0.0);
|
|
|
|
ppspeech::ClsNnet* cls_model = new ppspeech::ClsNnet();
|
|
int ret = cls_model->Init(cls_nnet_conf);
|
|
return static_cast<void*>(cls_model);
|
|
}
|
|
|
|
// Destroys the ClsNnet behind `instance`.
//
// `instance` is cast back to ppspeech::ClsNnet* and deleted. Passing NULL is
// accepted and does nothing. Always returns 0.
int ClsDestroyInstance(void* instance) {
    // Deleting a null pointer is a no-op, so no explicit NULL guard is needed.
    delete static_cast<ppspeech::ClsNnet*>(instance);
    return 0;
}
|
|
|
|
// Runs the classifier held by `instance` on the wav file at `wav_path`.
//
// `wav_path`, `topk`, `result` and `result_max_len` are forwarded unchanged to
// ClsNnet::Forward. Presumably `result` is an output buffer of capacity
// `result_max_len`; verify against ClsNnet::Forward.
//
// Returns -1 (after printing a message) when `instance` is NULL; otherwise
// returns the status produced by ClsNnet::Forward, so failures inside Forward
// are visible to the caller instead of being reported as success.
// NOTE(review): assumes Forward returns 0 on success, matching this file's
// convention; confirm.
int ClsFeedForward(void* instance,
                   const char* wav_path,
                   int topk,
                   char* result,
                   int result_max_len) {
    ppspeech::ClsNnet* cls_model = static_cast<ppspeech::ClsNnet*>(instance);
    if (cls_model == NULL) {
        printf("instance is null\n");
        return -1;
    }
    return cls_model->Forward(wav_path, topk, result, result_max_len);
}
|
|
|
|
int ClsReset(void* instance) {
|
|
ppspeech::ClsNnet* cls_model = static_cast<ppspeech::ClsNnet*>(instance);
|
|
if (cls_model == NULL) {
|
|
printf("instance is null\n");
|
|
return -1;
|
|
}
|
|
cls_model->Reset();
|
|
return 0;
|
|
}
|
|
} // namespace ppspeech
|