// PaddleSpeech/third_party/pymmseg-cpp/mmseg/mmseg-cpp/dict.cpp

#include <cstdio>
#include <cstdlib>  // calloc, free, atoi
#include <cstring>  // strncmp, strchr, strlen

#include "dict.h"

using namespace std;

namespace rmmseg {

struct Entry {
    Word *word;
    Entry *next;
};

const size_t init_size = 262147;
const size_t max_density = 5;

/*
 * Table of primes of the form 2^n + a, 19 <= n <= 30.
 */
static size_t primes[] = {
    524288 + 21,
    1048576 + 7,
    2097152 + 17,
    4194304 + 15,
    8388608 + 9,
    16777216 + 43,
    33554432 + 35,
    67108864 + 15,
    134217728 + 29,
    268435456 + 3,
    536870912 + 11,
    1073741824 + 85,
};

static size_t n_bins = init_size;
static size_t n_entries = 0;
static Entry **bins =
    static_cast<Entry **>(std::calloc(init_size, sizeof(Entry *)));

/* Return the smallest prime in the table larger than the current
 * number of bins, i.e. the next table size for rehashing. */
static size_t new_size() {
    for (size_t i = 0; i < sizeof(primes) / sizeof(primes[0]); ++i) {
        if (primes[i] > n_bins) {
            return primes[i];
        }
    }
    // TODO: raise exception here
    return n_bins;
}

/* Bob Jenkins' one-at-a-time hash. `str` need not be NUL-terminated;
 * only the first `len` bytes are read. */
static unsigned int hash(const char *str, int len) {
    unsigned int key = 0;
    while (len--) {
        key += *str++;
        key += (key << 10);
        key ^= (key >> 6);
    }
    key += (key << 3);
    key ^= (key >> 11);
    key += (key << 15);
    return key;
}

/* Grow the table to the next prime size and move every existing
 * entry into its new bin. */
static void rehash() {
    size_t new_n_bins = new_size();
    Entry **new_bins =
        static_cast<Entry **>(calloc(new_n_bins, sizeof(Entry *)));
    Entry *entry, *next;
    unsigned int hash_val;
    for (size_t i = 0; i < n_bins; ++i) {
        entry = bins[i];
        while (entry) {
            next = entry->next;
            hash_val =
                hash(entry->word->text, entry->word->nbytes) % new_n_bins;
            entry->next = new_bins[hash_val];
            new_bins[hash_val] = entry;
            entry = next;
        }
    }
    free(bins);
    n_bins = new_n_bins;
    bins = new_bins;
}

namespace dict {

/**
 * str: the base of the string
 * len: length of the string (in bytes)
 *
 * str may be a substring of a big chunk of text and thus not
 * NUL-terminated, so len is necessary here.
 */
Word *get(const char *str, int len) {
    unsigned int h = hash(str, len) % n_bins;
    Entry *entry = bins[h];
    if (!entry) return NULL;
    do {
        if (len == entry->word->nbytes &&
            strncmp(str, entry->word->text, len) == 0)
            return entry->word;
        entry = entry->next;
    } while (entry);
    return NULL;
}
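
/*
 * Illustrative lookup with get() above (a sketch, not part of the
 * original file): `text` and `len` would typically come from the
 * segmenter's current scan position.
 *
 *     Word *w = get(text, len);  // NULL when the token is unknown
 */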

void add(Word *word) {
    unsigned int hash_val = hash(word->text, word->nbytes);
    unsigned int h = hash_val % n_bins;
    Entry *entry = bins[h];
    if (!entry) {
        /* Empty bin: grow the table first if the load factor is too
         * high (the bin index must then be recomputed). */
        if (n_entries / n_bins > max_density) {
            rehash();
            h = hash_val % n_bins;
        }
        entry = static_cast<Entry *>(pool_alloc(sizeof(Entry)));
        entry->word = word;
        entry->next = NULL;
        bins[h] = entry;
        n_entries++;
        return;
    }

    bool done = false;
    do {
        if (word->nbytes == entry->word->nbytes &&
            strncmp(word->text, entry->word->text, word->nbytes) == 0) {
            /* Overwriting. WARNING: the original Word object is
             * permanently lost. This IS a memory leak, because
             * the memory is allocated by pool_alloc. Instead of
             * fixing this, tuning the dictionary file is a better
             * idea.
             */
            entry->word = word;
            done = true;
            break;
        }
        entry = entry->next;
    } while (entry);

    if (!done) {
        /* Not found in the chain: prepend a new entry. */
        entry = static_cast<Entry *>(pool_alloc(sizeof(Entry)));
        entry->word = word;
        entry->next = bins[h];
        bins[h] = entry;
        n_entries++;
    }
}
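
/*
 * Expected line format for the character dictionary (inferred from the
 * parsing below): "<frequency> <character>". Each entry is stored as a
 * word of length 1 with the given frequency.
 */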
bool load_chars(const char *filename) {
    FILE *fp = fopen(filename, "r");
    if (!fp) {
        return false;
    }

    const size_t buf_len = 24;
    char buf[buf_len];
    char *ptr;
    while (fgets(buf, buf_len, fp)) {
        // NOTE: there SHOULD be a newline at the end of the file
        buf[strlen(buf) - 1] = '\0';  // truncate the newline
        ptr = strchr(buf, ' ');
        if (!ptr) continue;  // illegal input
        *ptr = '\0';
        add(make_word(ptr + 1, 1, atoi(buf)));
    }
    fclose(fp);
    return true;
}
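
/*
 * Expected line format for the word dictionary (inferred from the
 * parsing below): "<length> <word>". Each entry is stored with the
 * given length and a frequency of 0.
 */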
bool load_words(const char *filename) {
    FILE *fp = fopen(filename, "r");
    if (!fp) {
        return false;
    }

    const int buf_len = 48;
    char buf[buf_len];
    char *ptr;
    while (fgets(buf, buf_len, fp)) {
        // NOTE: there SHOULD be a newline at the end of the file
        buf[strlen(buf) - 1] = '\0';  // truncate the newline
        ptr = strchr(buf, ' ');
        if (!ptr) continue;  // illegal input
        *ptr = '\0';
        add(make_word(ptr + 1, atoi(buf), 0));
    }
    fclose(fp);
    return true;
}
}  // namespace dict
}  // namespace rmmseg
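
/*
 * Usage sketch (illustrative only, kept out of the build with #if 0):
 * how a caller might load both dictionaries and query an entry. The
 * file names "chars.dic" and "words.dic" are assumptions, not defined
 * in this module.
 */
#if 0
#include <cstring>

#include "dict.h"

int main() {
    if (!rmmseg::dict::load_chars("chars.dic")) return 1;  // assumed path
    if (!rmmseg::dict::load_words("words.dic")) return 1;  // assumed path

    const char *token = "word";
    rmmseg::Word *w =
        rmmseg::dict::get(token, static_cast<int>(std::strlen(token)));
    return w ? 0 : 1;  /* 0 when the token is in the dictionary */
}
#endif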