You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
893 lines
31 KiB
893 lines
31 KiB
// fstext/lattice-weight.h
|
|
// Copyright 2009-2012 Microsoft Corporation
|
|
// Johns Hopkins University (author: Daniel Povey)
|
|
|
|
// See ../../COPYING for clarification regarding multiple authors
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
|
|
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
|
|
// MERCHANTABLITY OR NON-INFRINGEMENT.
|
|
// See the Apache 2 License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
#ifndef KALDI_FSTEXT_LATTICE_WEIGHT_H_
|
|
#define KALDI_FSTEXT_LATTICE_WEIGHT_H_
|
|
|
|
#include <algorithm>
|
|
#include <limits>
|
|
#include <string>
|
|
#include <vector>
|
|
|
|
#include "base/kaldi-common.h"
|
|
#include "fst/fstlib.h"
|
|
|
|
namespace fst {
|
|
|
|
// Declare weight type for lattice... will import to namespace kaldi.  It has
// two members, value1_ and value2_, of type BaseFloat (normally equal to
// float).  It is basically the same as the tropical semiring on
// value1_+value2_, except that it keeps track of value1_ and value2_
// separately.  More precisely, it is equivalent to the lexicographic semiring
// on (value1_+value2_), (value1_-value2_).
|
|
|
|
// Forward declarations of the weight class template and of its text-mode
// stream operators, so that the class definition can name the operators as
// friends.
template <typename FloatType>
class LatticeWeightTpl;

template <typename FloatType>
inline std::ostream &operator<<(std::ostream &os,
                                const LatticeWeightTpl<FloatType> &weight);

template <typename FloatType>
inline std::istream &operator>>(std::istream &is,
                                LatticeWeightTpl<FloatType> &weight);
|
|
|
|
template <class FloatType>
|
|
class LatticeWeightTpl {
|
|
public:
|
|
typedef FloatType T; // normally float.
|
|
typedef LatticeWeightTpl ReverseWeight;
|
|
|
|
inline T Value1() const { return value1_; }
|
|
|
|
inline T Value2() const { return value2_; }
|
|
|
|
inline void SetValue1(T f) { value1_ = f; }
|
|
|
|
inline void SetValue2(T f) { value2_ = f; }
|
|
|
|
LatticeWeightTpl() : value1_{}, value2_{} {}
|
|
|
|
LatticeWeightTpl(T a, T b) : value1_(a), value2_(b) {}
|
|
|
|
LatticeWeightTpl(const LatticeWeightTpl &other)
|
|
: value1_(other.value1_), value2_(other.value2_) {}
|
|
|
|
LatticeWeightTpl &operator=(const LatticeWeightTpl &w) {
|
|
value1_ = w.value1_;
|
|
value2_ = w.value2_;
|
|
return *this;
|
|
}
|
|
|
|
LatticeWeightTpl<FloatType> Reverse() const { return *this; }
|
|
|
|
static const LatticeWeightTpl Zero() {
|
|
return LatticeWeightTpl(std::numeric_limits<T>::infinity(),
|
|
std::numeric_limits<T>::infinity());
|
|
}
|
|
|
|
static const LatticeWeightTpl One() { return LatticeWeightTpl(0.0, 0.0); }
|
|
|
|
static const std::string &Type() {
|
|
static const std::string type = (sizeof(T) == 4 ? "lattice4" : "lattice8");
|
|
return type;
|
|
}
|
|
|
|
static const LatticeWeightTpl NoWeight() {
|
|
return LatticeWeightTpl(std::numeric_limits<FloatType>::quiet_NaN(),
|
|
std::numeric_limits<FloatType>::quiet_NaN());
|
|
}
|
|
|
|
bool Member() const {
|
|
// value1_ == value1_ tests for NaN.
|
|
// also test for no -inf, and either both or neither
|
|
// must be +inf, and
|
|
if (value1_ != value1_ || value2_ != value2_) return false; // NaN
|
|
if (value1_ == -std::numeric_limits<T>::infinity() ||
|
|
value2_ == -std::numeric_limits<T>::infinity())
|
|
return false; // -infty not allowed
|
|
if (value1_ == std::numeric_limits<T>::infinity() ||
|
|
value2_ == std::numeric_limits<T>::infinity()) {
|
|
if (value1_ != std::numeric_limits<T>::infinity() ||
|
|
value2_ != std::numeric_limits<T>::infinity())
|
|
return false; // both must be +infty;
|
|
// this is necessary so that the semiring has only one zero.
|
|
}
|
|
return true;
|
|
}
|
|
|
|
LatticeWeightTpl Quantize(float delta = kDelta) const {
|
|
if (value1_ + value2_ == -std::numeric_limits<T>::infinity()) {
|
|
return LatticeWeightTpl(-std::numeric_limits<T>::infinity(),
|
|
-std::numeric_limits<T>::infinity());
|
|
} else if (value1_ + value2_ == std::numeric_limits<T>::infinity()) {
|
|
return LatticeWeightTpl(std::numeric_limits<T>::infinity(),
|
|
std::numeric_limits<T>::infinity());
|
|
} else if (value1_ + value2_ != value1_ + value2_) { // NaN
|
|
return LatticeWeightTpl(value1_ + value2_, value1_ + value2_);
|
|
} else {
|
|
return LatticeWeightTpl(floor(value1_ / delta + 0.5F) * delta,
|
|
floor(value2_ / delta + 0.5F) * delta);
|
|
}
|
|
}
|
|
static constexpr uint64 Properties() {
|
|
return kLeftSemiring | kRightSemiring | kCommutative | kPath | kIdempotent;
|
|
}
|
|
|
|
// This is used in OpenFst for binary I/O. This is OpenFst-style,
|
|
// not Kaldi-style, I/O.
|
|
std::istream &Read(std::istream &strm) {
|
|
// Always read/write as float, even if T is double,
|
|
// so we can use OpenFst-style read/write and still maintain
|
|
// compatibility when compiling with different FloatTypes
|
|
ReadType(strm, &value1_);
|
|
ReadType(strm, &value2_);
|
|
return strm;
|
|
}
|
|
|
|
// This is used in OpenFst for binary I/O. This is OpenFst-style,
|
|
// not Kaldi-style, I/O.
|
|
std::ostream &Write(std::ostream &strm) const {
|
|
WriteType(strm, value1_);
|
|
WriteType(strm, value2_);
|
|
return strm;
|
|
}
|
|
|
|
size_t Hash() const {
|
|
size_t ans;
|
|
union {
|
|
T f;
|
|
size_t s;
|
|
} u;
|
|
u.s = 0;
|
|
u.f = value1_;
|
|
ans = u.s;
|
|
u.f = value2_;
|
|
ans += u.s;
|
|
return ans;
|
|
}
|
|
|
|
protected:
|
|
inline static void WriteFloatType(std::ostream &strm, const T &f) {
|
|
if (f == std::numeric_limits<T>::infinity())
|
|
strm << "Infinity";
|
|
else if (f == -std::numeric_limits<T>::infinity())
|
|
strm << "-Infinity";
|
|
else if (f != f)
|
|
strm << "BadNumber";
|
|
else
|
|
strm << f;
|
|
}
|
|
|
|
// Internal helper function, used in ReadNoParen.
|
|
inline static void ReadFloatType(std::istream &strm, T &f) { // NOLINT
|
|
std::string s;
|
|
strm >> s;
|
|
if (s == "Infinity") {
|
|
f = std::numeric_limits<T>::infinity();
|
|
} else if (s == "-Infinity") {
|
|
f = -std::numeric_limits<T>::infinity();
|
|
} else if (s == "BadNumber") {
|
|
f = std::numeric_limits<T>::quiet_NaN();
|
|
} else {
|
|
char *p;
|
|
f = strtod(s.c_str(), &p);
|
|
if (p < s.c_str() + s.size()) strm.clear(std::ios::badbit);
|
|
}
|
|
}
|
|
|
|
// Reads LatticeWeight when there are no parentheses around pair terms...
|
|
// currently the only form supported.
|
|
inline std::istream &ReadNoParen(std::istream &strm, char separator) {
|
|
int c;
|
|
do {
|
|
c = strm.get();
|
|
} while (isspace(c));
|
|
|
|
std::string s1;
|
|
while (c != separator) {
|
|
if (c == EOF) {
|
|
strm.clear(std::ios::badbit);
|
|
return strm;
|
|
}
|
|
s1 += c;
|
|
c = strm.get();
|
|
}
|
|
std::istringstream strm1(s1);
|
|
ReadFloatType(strm1, value1_); // ReadFloatType is class member function
|
|
// read second element
|
|
ReadFloatType(strm, value2_);
|
|
return strm;
|
|
}
|
|
|
|
friend std::istream &operator>>
|
|
<FloatType>(std::istream &, LatticeWeightTpl<FloatType> &);
|
|
friend std::ostream &operator<<<FloatType>(
|
|
std::ostream &, const LatticeWeightTpl<FloatType> &);
|
|
|
|
private:
|
|
T value1_;
|
|
T value2_;
|
|
};
|
|
|
|
/* ScaleTupleWeight is a function defined for LatticeWeightTpl and
   CompactLatticeWeightTpl that multiplies the pair (value1_, value2_) by a 2x2
   matrix.  Used, for example, in applying acoustic scaling.
 */
|
|
template <class FloatType, class ScaleFloatType>
|
|
inline LatticeWeightTpl<FloatType> ScaleTupleWeight(
|
|
const LatticeWeightTpl<FloatType> &w,
|
|
const std::vector<std::vector<ScaleFloatType> > &scale) {
|
|
// Without the next special case we'd get NaNs from infinity * 0
|
|
if (w.Value1() == std::numeric_limits<FloatType>::infinity())
|
|
return LatticeWeightTpl<FloatType>::Zero();
|
|
return LatticeWeightTpl<FloatType>(
|
|
scale[0][0] * w.Value1() + scale[0][1] * w.Value2(),
|
|
scale[1][0] * w.Value1() + scale[1][1] * w.Value2());
|
|
}
|
|
|
|
/* For testing purposes and in case it's ever useful, we define a similar
|
|
function to apply to LexicographicWeight and the like, templated on
|
|
TropicalWeight<float> etc.; we use PairWeight which is the base class of
|
|
LexicographicWeight.
|
|
*/
|
|
template <class FloatType, class ScaleFloatType>
|
|
inline PairWeight<TropicalWeightTpl<FloatType>, TropicalWeightTpl<FloatType> >
|
|
ScaleTupleWeight(const PairWeight<TropicalWeightTpl<FloatType>,
|
|
TropicalWeightTpl<FloatType> > &w,
|
|
const std::vector<std::vector<ScaleFloatType> > &scale) {
|
|
typedef TropicalWeightTpl<FloatType> BaseType;
|
|
typedef PairWeight<BaseType, BaseType> PairType;
|
|
const BaseType zero = BaseType::Zero();
|
|
// Without the next special case we'd get NaNs from infinity * 0
|
|
if (w.Value1() == zero || w.Value2() == zero) return PairType(zero, zero);
|
|
FloatType f1 = w.Value1().Value(), f2 = w.Value2().Value();
|
|
return PairType(BaseType(scale[0][0] * f1 + scale[0][1] * f2),
|
|
BaseType(scale[1][0] * f1 + scale[1][1] * f2));
|
|
}
|
|
|
|
template <class FloatType>
|
|
inline bool operator==(const LatticeWeightTpl<FloatType> &wa,
|
|
const LatticeWeightTpl<FloatType> &wb) {
|
|
// Volatile qualifier thwarts over-aggressive compiler optimizations
|
|
// that lead to problems esp. with NaturalLess().
|
|
volatile FloatType va1 = wa.Value1(), va2 = wa.Value2(), vb1 = wb.Value1(),
|
|
vb2 = wb.Value2();
|
|
return (va1 == vb1 && va2 == vb2);
|
|
}
|
|
|
|
template <class FloatType>
|
|
inline bool operator!=(const LatticeWeightTpl<FloatType> &wa,
|
|
const LatticeWeightTpl<FloatType> &wb) {
|
|
// Volatile qualifier thwarts over-aggressive compiler optimizations
|
|
// that lead to problems esp. with NaturalLess().
|
|
volatile FloatType va1 = wa.Value1(), va2 = wa.Value2(), vb1 = wb.Value1(),
|
|
vb2 = wb.Value2();
|
|
return (va1 != vb1 || va2 != vb2);
|
|
}
|
|
|
|
// We define a Compare function LatticeWeightTpl even though it's
|
|
// not required by the semiring standard-- it's just more efficient
|
|
// to do it this way rather than using the NaturalLess template.
|
|
|
|
/// Compare returns -1 if w1 < w2, +1 if w1 > w2, and 0 if w1 == w2.
|
|
|
|
template <class FloatType>
|
|
inline int Compare(const LatticeWeightTpl<FloatType> &w1,
|
|
const LatticeWeightTpl<FloatType> &w2) {
|
|
FloatType f1 = w1.Value1() + w1.Value2(), f2 = w2.Value1() + w2.Value2();
|
|
if (f1 < f2) { // having smaller cost means you're larger
|
|
return 1;
|
|
} else if (f1 > f2) { // in the semiring [higher probability]
|
|
return -1;
|
|
} else if (w1.Value1() < w2.Value1()) {
|
|
// mathematically we should be comparing (w1.value1_-w1.value2_ <
|
|
// w2.value1_-w2.value2_) in the next line, but add w1.value1_+w1.value2_ =
|
|
// w2.value1_+w2.value2_ to both sides and divide by two, and we get the
|
|
// simpler equivalent form w1.value1_ < w2.value1_.
|
|
return 1;
|
|
} else if (w1.Value1() > w2.Value1()) {
|
|
return -1;
|
|
} else {
|
|
return 0;
|
|
}
|
|
}
|
|
|
|
template <class FloatType>
|
|
inline LatticeWeightTpl<FloatType> Plus(const LatticeWeightTpl<FloatType> &w1,
|
|
const LatticeWeightTpl<FloatType> &w2) {
|
|
return (Compare(w1, w2) >= 0 ? w1 : w2);
|
|
}
|
|
|
|
// For efficiency, override the NaturalLess template class.
|
|
template <class FloatType>
|
|
class NaturalLess<LatticeWeightTpl<FloatType> > {
|
|
public:
|
|
typedef LatticeWeightTpl<FloatType> Weight;
|
|
|
|
NaturalLess() {}
|
|
|
|
bool operator()(const Weight &w1, const Weight &w2) const {
|
|
// NaturalLess is a negative order (opposite to normal ordering).
|
|
// This operator () corresponds to "<" in the negative order, which
|
|
// corresponds to the ">" in the normal order.
|
|
return (Compare(w1, w2) == 1);
|
|
}
|
|
};
|
|
template <>
|
|
class NaturalLess<LatticeWeightTpl<float> > {
|
|
public:
|
|
typedef LatticeWeightTpl<float> Weight;
|
|
|
|
NaturalLess() {}
|
|
|
|
bool operator()(const Weight &w1, const Weight &w2) const {
|
|
// NaturalLess is a negative order (opposite to normal ordering).
|
|
// This operator () corresponds to "<" in the negative order, which
|
|
// corresponds to the ">" in the normal order.
|
|
return (Compare(w1, w2) == 1);
|
|
}
|
|
};
|
|
template <>
|
|
class NaturalLess<LatticeWeightTpl<double> > {
|
|
public:
|
|
typedef LatticeWeightTpl<double> Weight;
|
|
|
|
NaturalLess() {}
|
|
|
|
bool operator()(const Weight &w1, const Weight &w2) const {
|
|
// NaturalLess is a negative order (opposite to normal ordering).
|
|
// This operator () corresponds to "<" in the negative order, which
|
|
// corresponds to the ">" in the normal order.
|
|
return (Compare(w1, w2) == 1);
|
|
}
|
|
};
|
|
|
|
template <class FloatType>
|
|
inline LatticeWeightTpl<FloatType> Times(
|
|
const LatticeWeightTpl<FloatType> &w1,
|
|
const LatticeWeightTpl<FloatType> &w2) {
|
|
return LatticeWeightTpl<FloatType>(w1.Value1() + w2.Value1(),
|
|
w1.Value2() + w2.Value2());
|
|
}
|
|
|
|
// divide w1 by w2 (on left/right/any doesn't matter as
|
|
// commutative).
|
|
template <class FloatType>
|
|
inline LatticeWeightTpl<FloatType> Divide(const LatticeWeightTpl<FloatType> &w1,
|
|
const LatticeWeightTpl<FloatType> &w2,
|
|
DivideType typ = DIVIDE_ANY) {
|
|
typedef FloatType T;
|
|
T a = w1.Value1() - w2.Value1(), b = w1.Value2() - w2.Value2();
|
|
if (a != a || b != b || a == -std::numeric_limits<T>::infinity() ||
|
|
b == -std::numeric_limits<T>::infinity()) {
|
|
KALDI_WARN << "LatticeWeightTpl::Divide, NaN or invalid number produced. "
|
|
<< "[dividing by zero?] Returning zero";
|
|
return LatticeWeightTpl<T>::Zero();
|
|
}
|
|
if (a == std::numeric_limits<T>::infinity() ||
|
|
b == std::numeric_limits<T>::infinity())
|
|
return LatticeWeightTpl<T>::Zero(); // not a valid number if only one is
|
|
// infinite.
|
|
return LatticeWeightTpl<T>(a, b);
|
|
}
|
|
|
|
template <class FloatType>
|
|
inline bool ApproxEqual(const LatticeWeightTpl<FloatType> &w1,
|
|
const LatticeWeightTpl<FloatType> &w2,
|
|
float delta = kDelta) {
|
|
if (w1.Value1() == w2.Value1() && w1.Value2() == w2.Value2())
|
|
return true; // handles Zero().
|
|
return (fabs((w1.Value1() + w1.Value2()) - (w2.Value1() + w2.Value2())) <=
|
|
delta);
|
|
}
|
|
|
|
template <class FloatType>
|
|
inline std::ostream &operator<<(std::ostream &strm,
|
|
const LatticeWeightTpl<FloatType> &w) {
|
|
LatticeWeightTpl<FloatType>::WriteFloatType(strm, w.Value1());
|
|
CHECK(FLAGS_fst_weight_separator.size() == 1); // NOLINT
|
|
strm << FLAGS_fst_weight_separator[0]; // comma by default;
|
|
// may or may not be settable from Kaldi programs.
|
|
LatticeWeightTpl<FloatType>::WriteFloatType(strm, w.Value2());
|
|
return strm;
|
|
}
|
|
|
|
template <class FloatType>
|
|
inline std::istream &operator>>(std::istream &strm,
|
|
LatticeWeightTpl<FloatType> &w1) {
|
|
CHECK(FLAGS_fst_weight_separator.size() == 1); // NOLINT
|
|
// separator defaults to ','
|
|
return w1.ReadNoParen(strm, FLAGS_fst_weight_separator[0]);
|
|
}
|
|
|
|
// CompactLattice will be an acceptor (accepting the words/output-symbols),
|
|
// with the weights and input-symbol-seqs on the arcs.
|
|
// There must be a total order on W. We assume for the sake of efficiency
|
|
// that there is a function
|
|
// Compare(W w1, W w2) that returns -1 if w1 < w2, +1 if w1 > w2, and
|
|
// zero if w1 == w2, and Plus for type W returns (Compare(w1,w2) >= 0 ? w1 :
|
|
// w2).
|
|
|
|
template <class WeightType, class IntType>
|
|
class CompactLatticeWeightTpl {
|
|
public:
|
|
typedef WeightType W;
|
|
|
|
typedef CompactLatticeWeightTpl<WeightType, IntType> ReverseWeight;
|
|
|
|
// Plus is like LexicographicWeight on the pair (weight_, string_), but where
|
|
// we use standard lexicographic order on string_ [this is not the same as
|
|
// NaturalLess on the StringWeight equivalent, which does not define a
|
|
// total order].
|
|
// Times, Divide obvious... (support both left & right division..)
|
|
// CommonDivisor would need to be coded separately.
|
|
|
|
CompactLatticeWeightTpl() {}
|
|
|
|
CompactLatticeWeightTpl(const WeightType &w, const std::vector<IntType> &s)
|
|
: weight_(w), string_(s) {}
|
|
|
|
CompactLatticeWeightTpl &operator=(
|
|
const CompactLatticeWeightTpl<WeightType, IntType> &w) {
|
|
weight_ = w.weight_;
|
|
string_ = w.string_;
|
|
return *this;
|
|
}
|
|
|
|
const W &Weight() const { return weight_; }
|
|
|
|
const std::vector<IntType> &String() const { return string_; }
|
|
|
|
void SetWeight(const W &w) { weight_ = w; }
|
|
|
|
void SetString(const std::vector<IntType> &s) { string_ = s; }
|
|
|
|
static const CompactLatticeWeightTpl<WeightType, IntType> Zero() {
|
|
return CompactLatticeWeightTpl<WeightType, IntType>(WeightType::Zero(),
|
|
std::vector<IntType>());
|
|
}
|
|
|
|
static const CompactLatticeWeightTpl<WeightType, IntType> One() {
|
|
return CompactLatticeWeightTpl<WeightType, IntType>(WeightType::One(),
|
|
std::vector<IntType>());
|
|
}
|
|
|
|
inline static std::string GetIntSizeString() {
|
|
char buf[2];
|
|
buf[0] = '0' + sizeof(IntType);
|
|
buf[1] = '\0';
|
|
return buf;
|
|
}
|
|
static const std::string &Type() {
|
|
static const std::string type =
|
|
"compact" + WeightType::Type() + GetIntSizeString();
|
|
return type;
|
|
}
|
|
|
|
static const CompactLatticeWeightTpl<WeightType, IntType> NoWeight() {
|
|
return CompactLatticeWeightTpl<WeightType, IntType>(WeightType::NoWeight(),
|
|
std::vector<IntType>());
|
|
}
|
|
|
|
CompactLatticeWeightTpl<WeightType, IntType> Reverse() const {
|
|
size_t s = string_.size();
|
|
std::vector<IntType> v(s);
|
|
for (size_t i = 0; i < s; i++) v[i] = string_[s - i - 1];
|
|
return CompactLatticeWeightTpl<WeightType, IntType>(weight_, v);
|
|
}
|
|
|
|
bool Member() const {
|
|
// a semiring has only one zero, this is the important property
|
|
// we're trying to maintain here. So force string_ to be empty if
|
|
// w_ == zero.
|
|
if (!weight_.Member()) return false;
|
|
if (weight_ == WeightType::Zero())
|
|
return string_.empty();
|
|
else
|
|
return true;
|
|
}
|
|
|
|
CompactLatticeWeightTpl Quantize(float delta = kDelta) const {
|
|
return CompactLatticeWeightTpl(weight_.Quantize(delta), string_);
|
|
}
|
|
|
|
static constexpr uint64 Properties() {
|
|
return kLeftSemiring | kRightSemiring | kPath | kIdempotent;
|
|
}
|
|
|
|
// This is used in OpenFst for binary I/O. This is OpenFst-style,
|
|
// not Kaldi-style, I/O.
|
|
std::istream &Read(std::istream &strm) {
|
|
weight_.Read(strm);
|
|
if (strm.fail()) {
|
|
return strm;
|
|
}
|
|
int32 sz;
|
|
ReadType(strm, &sz);
|
|
if (strm.fail()) {
|
|
return strm;
|
|
}
|
|
if (sz < 0) {
|
|
KALDI_WARN << "Negative string size! Read failure";
|
|
strm.clear(std::ios::badbit);
|
|
return strm;
|
|
}
|
|
string_.resize(sz);
|
|
for (int32 i = 0; i < sz; i++) {
|
|
ReadType(strm, &(string_[i]));
|
|
}
|
|
return strm;
|
|
}
|
|
|
|
// This is used in OpenFst for binary I/O. This is OpenFst-style,
|
|
// not Kaldi-style, I/O.
|
|
std::ostream &Write(std::ostream &strm) const {
|
|
weight_.Write(strm);
|
|
if (strm.fail()) {
|
|
return strm;
|
|
}
|
|
int32 sz = static_cast<int32>(string_.size());
|
|
WriteType(strm, sz);
|
|
for (int32 i = 0; i < sz; i++) WriteType(strm, string_[i]);
|
|
return strm;
|
|
}
|
|
size_t Hash() const {
|
|
size_t ans = weight_.Hash();
|
|
// any weird numbers here are largish primes
|
|
size_t sz = string_.size(), mult = 6967;
|
|
for (size_t i = 0; i < sz; i++) {
|
|
ans += string_[i] * mult;
|
|
mult *= 7499;
|
|
}
|
|
return ans;
|
|
}
|
|
|
|
private:
|
|
W weight_;
|
|
std::vector<IntType> string_;
|
|
};
|
|
|
|
template <class WeightType, class IntType>
|
|
inline bool operator==(const CompactLatticeWeightTpl<WeightType, IntType> &w1,
|
|
const CompactLatticeWeightTpl<WeightType, IntType> &w2) {
|
|
return (w1.Weight() == w2.Weight() && w1.String() == w2.String());
|
|
}
|
|
|
|
template <class WeightType, class IntType>
|
|
inline bool operator!=(const CompactLatticeWeightTpl<WeightType, IntType> &w1,
|
|
const CompactLatticeWeightTpl<WeightType, IntType> &w2) {
|
|
return (w1.Weight() != w2.Weight() || w1.String() != w2.String());
|
|
}
|
|
|
|
template <class WeightType, class IntType>
|
|
inline bool ApproxEqual(const CompactLatticeWeightTpl<WeightType, IntType> &w1,
|
|
const CompactLatticeWeightTpl<WeightType, IntType> &w2,
|
|
float delta = kDelta) {
|
|
return (ApproxEqual(w1.Weight(), w2.Weight(), delta) &&
|
|
w1.String() == w2.String());
|
|
}
|
|
|
|
// Compare is not part of the standard for weight types, but used internally for
|
|
// efficiency. The comparison here first compares the weight; if this is the
|
|
// same, it compares the string. The comparison on strings is: first compare
|
|
// the length, if this is the same, use lexicographical order. We can't just
|
|
// use the lexicographical order because this would destroy the distributive
|
|
// property of multiplication over addition, taking into account that addition
|
|
// uses Compare. The string element of "Compare" isn't super-important in
|
|
// practical terms; it's only needed to ensure that Plus always give consistent
|
|
// answers and is symmetric. It's essentially for tie-breaking, but we need to
|
|
// make sure all the semiring axioms are satisfied otherwise OpenFst might
|
|
// break.
|
|
|
|
template <class WeightType, class IntType>
|
|
inline int Compare(const CompactLatticeWeightTpl<WeightType, IntType> &w1,
|
|
const CompactLatticeWeightTpl<WeightType, IntType> &w2) {
|
|
int c1 = Compare(w1.Weight(), w2.Weight());
|
|
if (c1 != 0) return c1;
|
|
int l1 = w1.String().size(), l2 = w2.String().size();
|
|
// Use opposite order on the string lengths, so that if the costs are the
|
|
// same, the shorter string wins.
|
|
if (l1 > l2)
|
|
return -1;
|
|
else if (l1 < l2)
|
|
return 1;
|
|
for (int i = 0; i < l1; i++) {
|
|
if (w1.String()[i] < w2.String()[i])
|
|
return -1;
|
|
else if (w1.String()[i] > w2.String()[i])
|
|
return 1;
|
|
}
|
|
return 0;
|
|
}
|
|
|
|
// For efficiency, override the NaturalLess template class.
|
|
template <class FloatType, class IntType>
|
|
class NaturalLess<
|
|
CompactLatticeWeightTpl<LatticeWeightTpl<FloatType>, IntType> > {
|
|
public:
|
|
typedef CompactLatticeWeightTpl<LatticeWeightTpl<FloatType>, IntType> Weight;
|
|
|
|
NaturalLess() {}
|
|
|
|
bool operator()(const Weight &w1, const Weight &w2) const {
|
|
// NaturalLess is a negative order (opposite to normal ordering).
|
|
// This operator () corresponds to "<" in the negative order, which
|
|
// corresponds to the ">" in the normal order.
|
|
return (Compare(w1, w2) == 1);
|
|
}
|
|
};
|
|
template <>
|
|
class NaturalLess<CompactLatticeWeightTpl<LatticeWeightTpl<float>, int32> > {
|
|
public:
|
|
typedef CompactLatticeWeightTpl<LatticeWeightTpl<float>, int32> Weight;
|
|
|
|
NaturalLess() {}
|
|
|
|
bool operator()(const Weight &w1, const Weight &w2) const {
|
|
// NaturalLess is a negative order (opposite to normal ordering).
|
|
// This operator () corresponds to "<" in the negative order, which
|
|
// corresponds to the ">" in the normal order.
|
|
return (Compare(w1, w2) == 1);
|
|
}
|
|
};
|
|
template <>
|
|
class NaturalLess<CompactLatticeWeightTpl<LatticeWeightTpl<double>, int32> > {
|
|
public:
|
|
typedef CompactLatticeWeightTpl<LatticeWeightTpl<double>, int32> Weight;
|
|
|
|
NaturalLess() {}
|
|
|
|
bool operator()(const Weight &w1, const Weight &w2) const {
|
|
// NaturalLess is a negative order (opposite to normal ordering).
|
|
// This operator () corresponds to "<" in the negative order, which
|
|
// corresponds to the ">" in the normal order.
|
|
return (Compare(w1, w2) == 1);
|
|
}
|
|
};
|
|
|
|
// Make sure Compare is defined for TropicalWeight, so everything works
|
|
// if we substitute LatticeWeight for TropicalWeight.
|
|
inline int Compare(const TropicalWeight &w1, const TropicalWeight &w2) {
|
|
float f1 = w1.Value(), f2 = w2.Value();
|
|
if (f1 == f2)
|
|
return 0;
|
|
else if (f1 > f2)
|
|
return -1;
|
|
else
|
|
return 1;
|
|
}
|
|
|
|
template <class WeightType, class IntType>
|
|
inline CompactLatticeWeightTpl<WeightType, IntType> Plus(
|
|
const CompactLatticeWeightTpl<WeightType, IntType> &w1,
|
|
const CompactLatticeWeightTpl<WeightType, IntType> &w2) {
|
|
return (Compare(w1, w2) >= 0 ? w1 : w2);
|
|
}
|
|
|
|
template <class WeightType, class IntType>
|
|
inline CompactLatticeWeightTpl<WeightType, IntType> Times(
|
|
const CompactLatticeWeightTpl<WeightType, IntType> &w1,
|
|
const CompactLatticeWeightTpl<WeightType, IntType> &w2) {
|
|
WeightType w = Times(w1.Weight(), w2.Weight());
|
|
if (w == WeightType::Zero()) {
|
|
return CompactLatticeWeightTpl<WeightType, IntType>::Zero();
|
|
// special case to ensure zero is unique
|
|
} else {
|
|
std::vector<IntType> v;
|
|
v.resize(w1.String().size() + w2.String().size());
|
|
typename std::vector<IntType>::iterator iter = v.begin();
|
|
iter = std::copy(w1.String().begin(), w1.String().end(),
|
|
iter); // returns end of first range.
|
|
std::copy(w2.String().begin(), w2.String().end(), iter);
|
|
return CompactLatticeWeightTpl<WeightType, IntType>(w, v);
|
|
}
|
|
}
|
|
|
|
template <class WeightType, class IntType>
|
|
inline CompactLatticeWeightTpl<WeightType, IntType> Divide(
|
|
const CompactLatticeWeightTpl<WeightType, IntType> &w1,
|
|
const CompactLatticeWeightTpl<WeightType, IntType> &w2,
|
|
DivideType div = DIVIDE_ANY) {
|
|
if (w1.Weight() == WeightType::Zero()) {
|
|
if (w2.Weight() != WeightType::Zero()) {
|
|
return CompactLatticeWeightTpl<WeightType, IntType>::Zero();
|
|
} else {
|
|
KALDI_ERR << "Division by zero [0/0]";
|
|
}
|
|
} else if (w2.Weight() == WeightType::Zero()) {
|
|
KALDI_ERR << "Error: division by zero";
|
|
}
|
|
WeightType w = Divide(w1.Weight(), w2.Weight());
|
|
|
|
const std::vector<IntType> v1 = w1.String(), v2 = w2.String();
|
|
if (v2.size() > v1.size()) {
|
|
KALDI_ERR << "Cannot divide, length mismatch";
|
|
}
|
|
typename std::vector<IntType>::const_iterator v1b = v1.begin(),
|
|
v1e = v1.end(),
|
|
v2b = v2.begin(),
|
|
v2e = v2.end();
|
|
if (div == DIVIDE_LEFT) {
|
|
if (!std::equal(v2b, v2e,
|
|
v1b)) { // v2 must be identical to first part of v1.
|
|
KALDI_ERR << "Cannot divide, data mismatch";
|
|
}
|
|
return CompactLatticeWeightTpl<WeightType, IntType>(
|
|
w, std::vector<IntType>(v1b + (v2e - v2b),
|
|
v1e)); // return last part of v1.
|
|
} else if (div == DIVIDE_RIGHT) {
|
|
if (!std::equal(
|
|
v2b, v2e,
|
|
v1e - (v2e - v2b))) { // v2 must be identical to last part of v1.
|
|
KALDI_ERR << "Cannot divide, data mismatch";
|
|
}
|
|
return CompactLatticeWeightTpl<WeightType, IntType>(
|
|
w, std::vector<IntType>(
|
|
v1b, v1e - (v2e - v2b))); // return first part of v1.
|
|
|
|
} else {
|
|
KALDI_ERR << "Cannot divide CompactLatticeWeightTpl with DIVIDE_ANY";
|
|
}
|
|
return CompactLatticeWeightTpl<WeightType,
|
|
IntType>::Zero(); // keep compiler happy.
|
|
}
|
|
|
|
template <class WeightType, class IntType>
|
|
inline std::ostream &operator<<(
|
|
std::ostream &strm, const CompactLatticeWeightTpl<WeightType, IntType> &w) {
|
|
strm << w.Weight();
|
|
CHECK(FLAGS_fst_weight_separator.size() == 1); // NOLINT
|
|
strm << FLAGS_fst_weight_separator[0]; // comma by default.
|
|
for (size_t i = 0; i < w.String().size(); i++) {
|
|
strm << w.String()[i];
|
|
if (i + 1 < w.String().size())
|
|
strm << kStringSeparator; // '_'; defined in string-weight.h in OpenFst
|
|
// code.
|
|
}
|
|
return strm;
|
|
}
|
|
|
|
template <class WeightType, class IntType>
|
|
inline std::istream &operator>>(
|
|
std::istream &strm, CompactLatticeWeightTpl<WeightType, IntType> &w) {
|
|
std::string s;
|
|
strm >> s;
|
|
if (strm.fail()) {
|
|
return strm;
|
|
}
|
|
CHECK(FLAGS_fst_weight_separator.size() == 1); // NOLINT
|
|
size_t pos = s.find_last_of(FLAGS_fst_weight_separator); // normally ","
|
|
if (pos == std::string::npos) {
|
|
strm.clear(std::ios::badbit);
|
|
return strm;
|
|
}
|
|
// get parts of str before and after the separator (default: ',');
|
|
std::string s1(s, 0, pos), s2(s, pos + 1);
|
|
std::istringstream strm1(s1);
|
|
WeightType weight;
|
|
strm1 >> weight;
|
|
w.SetWeight(weight);
|
|
if (strm1.fail() || !strm1.eof()) {
|
|
strm.clear(std::ios::badbit);
|
|
return strm;
|
|
}
|
|
// read string part.
|
|
std::vector<IntType> string;
|
|
const char *c = s2.c_str();
|
|
while (*c != '\0') {
|
|
if (*c == kStringSeparator) // '_'
|
|
c++;
|
|
char *c2;
|
|
int64_t i = strtol(c, &c2, 10);
|
|
if (c2 == c || static_cast<int64_t>(static_cast<IntType>(i)) != i) {
|
|
strm.clear(std::ios::badbit);
|
|
return strm;
|
|
}
|
|
c = c2;
|
|
string.push_back(static_cast<IntType>(i));
|
|
}
|
|
w.SetString(string);
|
|
return strm;
|
|
}
|
|
|
|
template <class BaseWeightType, class IntType>
|
|
class CompactLatticeWeightCommonDivisorTpl {
|
|
public:
|
|
typedef CompactLatticeWeightTpl<BaseWeightType, IntType> Weight;
|
|
|
|
Weight operator()(const Weight &w1, const Weight &w2) const {
|
|
// First find longest common prefix of the strings.
|
|
typename std::vector<IntType>::const_iterator s1b = w1.String().begin(),
|
|
s1e = w1.String().end(),
|
|
s2b = w2.String().begin(),
|
|
s2e = w2.String().end();
|
|
while (s1b < s1e && s2b < s2e && *s1b == *s2b) {
|
|
s1b++;
|
|
s2b++;
|
|
}
|
|
return Weight(Plus(w1.Weight(), w2.Weight()),
|
|
std::vector<IntType>(w1.String().begin(), s1b));
|
|
}
|
|
};
|
|
|
|
/** Scales the pair (a, b) of floating-point weights inside a
|
|
CompactLatticeWeight by premultiplying it (viewed as a vector)
|
|
by a 2x2 matrix "scale".
|
|
Assumes there is a ScaleTupleWeight function that applies to "Weight";
|
|
this currently only works if Weight equals LatticeWeightTpl<FloatType>
|
|
for some FloatType.
|
|
*/
|
|
template <class Weight, class IntType, class ScaleFloatType>
|
|
inline CompactLatticeWeightTpl<Weight, IntType> ScaleTupleWeight(
|
|
const CompactLatticeWeightTpl<Weight, IntType> &w,
|
|
const std::vector<std::vector<ScaleFloatType> > &scale) {
|
|
return CompactLatticeWeightTpl<Weight, IntType>(
|
|
Weight(ScaleTupleWeight(w.Weight(), scale)), w.String());
|
|
}
|
|
|
|
/** Define some ConvertLatticeWeight functions that are used in various lattice
|
|
conversions... make them all templates, some with no arguments, since some
|
|
must be templates.*/
|
|
template <class Float1, class Float2>
|
|
inline void ConvertLatticeWeight(const LatticeWeightTpl<Float1> &w_in,
|
|
LatticeWeightTpl<Float2> *w_out) {
|
|
w_out->SetValue1(w_in.Value1());
|
|
w_out->SetValue2(w_in.Value2());
|
|
}
|
|
|
|
template <class Float1, class Float2, class Int>
|
|
inline void ConvertLatticeWeight(
|
|
const CompactLatticeWeightTpl<LatticeWeightTpl<Float1>, Int> &w_in,
|
|
CompactLatticeWeightTpl<LatticeWeightTpl<Float2>, Int> *w_out) {
|
|
LatticeWeightTpl<Float2> weight2(w_in.Weight().Value1(),
|
|
w_in.Weight().Value2());
|
|
w_out->SetWeight(weight2);
|
|
w_out->SetString(w_in.String());
|
|
}
|
|
|
|
// to convert from Lattice to standard FST
|
|
template <class Float1, class Float2>
|
|
inline void ConvertLatticeWeight(const LatticeWeightTpl<Float1> &w_in,
|
|
TropicalWeightTpl<Float2> *w_out) {
|
|
TropicalWeightTpl<Float2> w1(w_in.Value1());
|
|
TropicalWeightTpl<Float2> w2(w_in.Value2());
|
|
*w_out = Times(w1, w2);
|
|
}
|
|
|
|
template <class Float>
|
|
inline double ConvertToCost(const LatticeWeightTpl<Float> &w) {
|
|
return static_cast<double>(w.Value1()) + static_cast<double>(w.Value2());
|
|
}
|
|
|
|
template <class Float, class Int>
|
|
inline double ConvertToCost(
|
|
const CompactLatticeWeightTpl<LatticeWeightTpl<Float>, Int> &w) {
|
|
return static_cast<double>(w.Weight().Value1()) +
|
|
static_cast<double>(w.Weight().Value2());
|
|
}
|
|
|
|
template <class Float>
|
|
inline double ConvertToCost(const TropicalWeightTpl<Float> &w) {
|
|
return w.Value();
|
|
}
|
|
|
|
} // namespace fst
|
|
|
|
#endif // KALDI_FSTEXT_LATTICE_WEIGHT_H_
|