Updated cld3

pull/214/head
M66B 1 year ago
parent 5d0ee632de
commit 5f9928313b

@ -167,6 +167,7 @@ EmbeddingNetwork::EmbeddingNetwork(const EmbeddingNetworkParams *model)
for (int i = 0; i < model_->embedding_dim_size(); ++i) {
CLD3_DCHECK(offset_sum == model_->concat_offset(i));
offset_sum += model_->embedding_dim(i) * model_->embedding_num_features(i);
(void)offset_sum; // Avoid compiler warning for "unused" variable.
embedding_matrices_.emplace_back(model_->GetEmbeddingMatrix(i));
}

@ -16,6 +16,7 @@ limitations under the License.
#include <cmath>
#include <iostream>
#include <vector>
#include <set>
#include "base.h"
#include "feature_extractor.h"

@ -209,6 +209,34 @@ bool TestMultipleLanguagesInInput() {
<< result.proportion << std::endl;
return false;
}
// Skip over undefined language.
if (result.language == "und")
continue;
if (result.byte_ranges.size() != 1) {
std::cout << " Should only detect one span containing " << result.language
<< std::endl;
return false;
}
// Check that specified byte ranges for language are correct.
int start_index = result.byte_ranges[0].start_index;
int end_index = result.byte_ranges[0].end_index;
std::string byte_ranges_text = text.substr(start_index, end_index - start_index);
if (result.language == "bg") {
if (byte_ranges_text.compare("Този текст е на Български.") != 0) {
std::cout << " Incorrect byte ranges returned for Bulgarian " << std::endl;
return false;
}
} else if (result.language == "en") {
if (byte_ranges_text.compare("This piece of text is in English. ") != 0) {
std::cout << " Incorrect byte ranges returned for English " << std::endl;
return false;
}
} else {
std::cout << " Got language other than English or Bulgarian "
<< std::endl;
return false;
}
}
std::cout << " Success!" << std::endl;
return true;

@ -47,6 +47,9 @@ struct LangChunksStats {
// Number chunks corresponding to the language.
int num_chunks = 0;
// Specifies the byte ranges that language applies to.
std::vector<NNetLanguageIdentifier::SpanInfo> byte_ranges;
};
// Compares two pairs based on their values.
@ -281,8 +284,6 @@ NNetLanguageIdentifier::FindTopNMostFreqLangs(const string &text,
CLD2::LangSpan script_span;
std::unordered_map<string, LangChunksStats> lang_stats;
int total_num_bytes = 0;
Result result;
string language;
int chunk_size = 0; // Use the default.
while (ss.GetOneScriptSpanLower(&script_span)) {
const int num_original_span_bytes = script_span.text_bytes;
@ -298,12 +299,16 @@ NNetLanguageIdentifier::FindTopNMostFreqLangs(const string &text,
total_num_bytes += num_original_span_bytes;
const string selected_text = SelectTextGivenScriptSpan(script_span);
result = FindLanguageOfValidUTF8(selected_text);
language = result.language;
Result result = FindLanguageOfValidUTF8(selected_text);
string language = result.language;
lang_stats[language].byte_sum += num_original_span_bytes;
lang_stats[language].prob_sum +=
result.probability * num_original_span_bytes;
lang_stats[language].num_chunks++;
// Add SpanInfo. Start and end indices are relative to original input.
lang_stats[language].byte_ranges.push_back(SpanInfo(
ss.MapBack(0), ss.MapBack(script_span.text_bytes), result.probability));
}
// Sort the languages based on the number of bytes associated with them.
@ -329,6 +334,7 @@ NNetLanguageIdentifier::FindTopNMostFreqLangs(const string &text,
result.probability = stats.prob_sum / stats.byte_sum;
result.proportion = stats.byte_sum / byte_sum;
result.is_reliable = ResultIsReliable(language, result.probability);
result.byte_ranges = stats.byte_ranges;
results.push_back(result);
}
@ -348,7 +354,7 @@ string NNetLanguageIdentifier::SelectTextGivenBeginAndSize(
const char *text_begin, int text_size) {
string output_text;
// If the size of the input is greater than the maxium number of bytes needed
// If the size of the input is greater than the maximum number of bytes needed
// for a prediction, then concatenate snippets that are equally spread out
// throughout the input.
if (text_size > max_num_bytes_) {

@ -44,6 +44,19 @@ class LanguageIdEmbeddingFeatureExtractor
// Class for detecting the language of a document.
class NNetLanguageIdentifier {
public:
// Holds probability that Span, specified by start/end indices, is a given
// language. The langauge is not stored here; it can be found in Result, which
// holds a vector of SpanInfo.
struct SpanInfo {
SpanInfo(int start_index_val, int end_index_val, float probability_val)
: start_index(start_index_val),
end_index(end_index_val),
probability(probability_val) {}
int start_index = -1;
int end_index = -1;
float probability = 0.0;
};
// Information about a predicted language.
struct Result {
string language = kUnknown;
@ -53,6 +66,9 @@ class NNetLanguageIdentifier {
// Proportion of bytes associated with the language. If FindLanguage is
// called, this variable is set to 1.
float proportion = 0.0;
// Specifies the byte ranges that |language| applies to.
std::vector<SpanInfo> byte_ranges;
};
NNetLanguageIdentifier();

@ -878,7 +878,6 @@ bool ScriptScanner::GetOneScriptSpan(LangSpan* span) {
// copying letters to buffer with single spaces for each run of non-letters
while (take < byte_length_) {
// Copy run of letters in same script (&LS | LS)*
int letter_count = 0; // Keep track of word length
bool need_break = false;
while (take < byte_length_) {
@ -963,7 +962,6 @@ bool ScriptScanner::GetOneScriptSpan(LangSpan* span) {
map2original_.Delete(tlen - plen);
}
++letter_count;
if (put >= kMaxScriptBytes) {
// Buffer is full
span->truncated = true;

@ -33,14 +33,14 @@ static const int kMaxScriptBytes = kMaxScriptBuffer - 32; // Leave some room
static const int kWithinScriptTail = 32; // Stop at word space in last
// N bytes of script buffer
typedef struct {
struct LangSpan {
char* text = nullptr; // Pointer to the span, somewhere
int text_bytes = 0; // Number of bytes of text in the span
int offset = 0; // Offset of start of span in original input buffer
ULScript ulscript = UNKNOWN_ULSCRIPT; // Unicode Letters Script of this span
bool truncated = false; // true if buffer filled up before a
// different script or EOF was found
} LangSpan;
};
static inline bool IsContinuationByte(char c) {
return static_cast<signed char>(c) < -64;
@ -93,7 +93,7 @@ class ScriptScanner {
// again with the first byte of the following range.
int MapBack(int text_offset);
const char* GetBufferStart() {return start_byte_;};
const char* GetBufferStart() {return start_byte_;}
private:
// Skip over tags and non-letters

@ -158,6 +158,20 @@ static const int kHtmlPlaintextFlag = 0x80; // Bit in add byte to distinguish
*
**/
// All intentional fallthroughs in breakpad are in this file, so define
// this macro locally.
// If you ever move this to a .h file, make sure it's defined in a
// private header file: clang suggests the first macro expanding to
// [[clang::fallthrough]] in its diagnostics, so if BP_FALLTHROUGH
// is visible in code depending on breakpad, clang would suggest
// BP_FALLTHROUGH for code depending on breakpad, instead of the
// client code's own fallthrough macro.
#if defined(__clang__)
#define CLD_FALLTHROUGH [[clang::fallthrough]]
#else
#define CLD_FALLTHROUGH
#endif
// Return true if current Tbl pointer is within state0 range
// Note that unsigned compare checks both ends of range simultaneously
static inline bool InStateZero(const UTF8ScanObj* st, const uint8* Tbl) {
@ -715,10 +729,10 @@ static int UTF8GenericReplaceInternal(const UTF8ReplaceObj* st,
goto Do_state_table;
case kExitReplace3: // update 3 bytes to change
dst[-3] = (unsigned char)Tbl[c + (nEntries * 3)];
// Fall into next case
CLD_FALLTHROUGH;
case kExitReplace2: // update 2 bytes to change
dst[-2] = (unsigned char)Tbl[c + (nEntries * 2)];
// Fall into next case
CLD_FALLTHROUGH;
case kExitReplace1: // update 1 byte to change
dst[-1] = (unsigned char)Tbl[c + (nEntries * 1)];
total_changed++;
@ -736,7 +750,7 @@ static int UTF8GenericReplaceInternal(const UTF8ReplaceObj* st,
} else {
offset += ((unsigned char)Tbl[c + (nEntries * 2)] << 8);
}
// Fall into next case
CLD_FALLTHROUGH;
case kExitSpecial: // Apply special fixups [read: hacks]
case kExitReplaceOffset1:
if ((nEntries != 256) && InStateZero(st, Tbl)) {
@ -986,10 +1000,10 @@ static int UTF8GenericReplaceInternalTwoByte(const UTF8ReplaceObj_2* st,
goto Do_state_table_2;
case kExitReplace3_2: // update 3 bytes to change
dst[-3] = (unsigned char)(Tbl[c + (nEntries * 2)] & 0xff);
// Fall into next case
CLD_FALLTHROUGH;
case kExitReplace2_2: // update 2 bytes to change
dst[-2] = (unsigned char)(Tbl[c + (nEntries * 1)] >> 8 & 0xff);
// Fall into next case
CLD_FALLTHROUGH;
case kExitReplace1_2: // update 1 byte to change
dst[-1] = (unsigned char)(Tbl[c + (nEntries * 1)] & 0xff);
total_changed++;
@ -1007,7 +1021,7 @@ static int UTF8GenericReplaceInternalTwoByte(const UTF8ReplaceObj_2* st,
} else {
offset += ((unsigned char)(Tbl[c + (nEntries * 1)] >> 8 & 0xff) << 8);
}
// Fall into next case
CLD_FALLTHROUGH;
case kExitReplaceOffset1_2:
if ((nEntries != 256) && InStateZero_2(st, Tbl)) {
// For space-optimized table, we need multiples of 256 bytes

@ -19,11 +19,11 @@ limitations under the License.
namespace chrome_lang_id {
// Declare registry for the whole Sentence feature functions. NOTE: this is not
// Define registry for the whole Sentence feature functions. NOTE: this is not
// yet set to anything meaningful. It will be set so in NNetLanguageIdentifier
// constructor, *before* we use any feature.
template <>
WholeSentenceFeature::Registry
*RegisterableClass<WholeSentenceFeature>::registry_ = nullptr;
WholeSentenceFeature::Registry*
RegisterableClass<WholeSentenceFeature>::registry_ = nullptr;
} // namespace chrome_lang_id
} // namespace chrome_lang_id

@ -26,9 +26,19 @@ limitations under the License.
namespace chrome_lang_id {
// Feature function that extracts features for the full Sentence.
typedef FeatureFunction<Sentence> WholeSentenceFeature;
typedef FeatureExtractor<Sentence> WholeSentenceExtractor;
using WholeSentenceFeature = FeatureFunction<Sentence>;
using WholeSentenceExtractor = FeatureExtractor<Sentence>;
// Declare registry for the whole Sentence feature functions. This is required
// for clang's -Wundefined-var-template. However, MSVC has a bug which treats
// this declaration as a definition, leading to multiple definition errors, so
// omit this on MSVC.
#if !defined(COMPILER_MSVC)
template <>
WholeSentenceFeature::Registry
*RegisterableClass<WholeSentenceFeature>::registry_;
#endif
} // namespace chrome_lang_id

Loading…
Cancel
Save