// fstext/lattice-weight.h // Copyright 2009-2012 Microsoft Corporation // Johns Hopkins University (author: Daniel Povey) // See ../../COPYING for clarification regarding multiple authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, // MERCHANTABLITY OR NON-INFRINGEMENT. // See the Apache 2 License for the specific language governing permissions and // limitations under the License. #ifndef KALDI_FSTEXT_LATTICE_WEIGHT_H_ #define KALDI_FSTEXT_LATTICE_WEIGHT_H_ #include #include #include #include #include "base/kaldi-common.h" #include "fst/fstlib.h" namespace fst { // Declare weight type for lattice... will import to namespace kaldi. has two // members, value1_ and value2_, of type BaseFloat (normally equals float). It // is basically the same as the tropical semiring on value1_+value2_, except it // keeps track of a and b separately. More precisely, it is equivalent to the // lexicographic semiring on (value1_+value2_), (value1_-value2_) template class LatticeWeightTpl; template inline std::ostream &operator<<(std::ostream &strm, const LatticeWeightTpl &w); template inline std::istream &operator>>(std::istream &strm, LatticeWeightTpl &w); template class LatticeWeightTpl { public: typedef FloatType T; // normally float. typedef LatticeWeightTpl ReverseWeight; inline T Value1() const { return value1_; } inline T Value2() const { return value2_; } inline void SetValue1(T f) { value1_ = f; } inline void SetValue2(T f) { value2_ = f; } LatticeWeightTpl() : value1_{}, value2_{} {} LatticeWeightTpl(T a, T b) : value1_(a), value2_(b) {} LatticeWeightTpl(const LatticeWeightTpl &other) : value1_(other.value1_), value2_(other.value2_) {} LatticeWeightTpl &operator=(const LatticeWeightTpl &w) { value1_ = w.value1_; value2_ = w.value2_; return *this; } LatticeWeightTpl Reverse() const { return *this; } static const LatticeWeightTpl Zero() { return LatticeWeightTpl(std::numeric_limits::infinity(), std::numeric_limits::infinity()); } static const LatticeWeightTpl One() { return LatticeWeightTpl(0.0, 0.0); } static const std::string &Type() { static const std::string type = (sizeof(T) == 4 ? "lattice4" : "lattice8"); return type; } static const LatticeWeightTpl NoWeight() { return LatticeWeightTpl(std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN()); } bool Member() const { // value1_ == value1_ tests for NaN. // also test for no -inf, and either both or neither // must be +inf, and if (value1_ != value1_ || value2_ != value2_) return false; // NaN if (value1_ == -std::numeric_limits::infinity() || value2_ == -std::numeric_limits::infinity()) return false; // -infty not allowed if (value1_ == std::numeric_limits::infinity() || value2_ == std::numeric_limits::infinity()) { if (value1_ != std::numeric_limits::infinity() || value2_ != std::numeric_limits::infinity()) return false; // both must be +infty; // this is necessary so that the semiring has only one zero. } return true; } LatticeWeightTpl Quantize(float delta = kDelta) const { if (value1_ + value2_ == -std::numeric_limits::infinity()) { return LatticeWeightTpl(-std::numeric_limits::infinity(), -std::numeric_limits::infinity()); } else if (value1_ + value2_ == std::numeric_limits::infinity()) { return LatticeWeightTpl(std::numeric_limits::infinity(), std::numeric_limits::infinity()); } else if (value1_ + value2_ != value1_ + value2_) { // NaN return LatticeWeightTpl(value1_ + value2_, value1_ + value2_); } else { return LatticeWeightTpl(floor(value1_ / delta + 0.5F) * delta, floor(value2_ / delta + 0.5F) * delta); } } static constexpr uint64 Properties() { return kLeftSemiring | kRightSemiring | kCommutative | kPath | kIdempotent; } // This is used in OpenFst for binary I/O. This is OpenFst-style, // not Kaldi-style, I/O. std::istream &Read(std::istream &strm) { // Always read/write as float, even if T is double, // so we can use OpenFst-style read/write and still maintain // compatibility when compiling with different FloatTypes ReadType(strm, &value1_); ReadType(strm, &value2_); return strm; } // This is used in OpenFst for binary I/O. This is OpenFst-style, // not Kaldi-style, I/O. std::ostream &Write(std::ostream &strm) const { WriteType(strm, value1_); WriteType(strm, value2_); return strm; } size_t Hash() const { size_t ans; union { T f; size_t s; } u; u.s = 0; u.f = value1_; ans = u.s; u.f = value2_; ans += u.s; return ans; } protected: inline static void WriteFloatType(std::ostream &strm, const T &f) { if (f == std::numeric_limits::infinity()) strm << "Infinity"; else if (f == -std::numeric_limits::infinity()) strm << "-Infinity"; else if (f != f) strm << "BadNumber"; else strm << f; } // Internal helper function, used in ReadNoParen. inline static void ReadFloatType(std::istream &strm, T &f) { // NOLINT std::string s; strm >> s; if (s == "Infinity") { f = std::numeric_limits::infinity(); } else if (s == "-Infinity") { f = -std::numeric_limits::infinity(); } else if (s == "BadNumber") { f = std::numeric_limits::quiet_NaN(); } else { char *p; f = strtod(s.c_str(), &p); if (p < s.c_str() + s.size()) strm.clear(std::ios::badbit); } } // Reads LatticeWeight when there are no parentheses around pair terms... // currently the only form supported. inline std::istream &ReadNoParen(std::istream &strm, char separator) { int c; do { c = strm.get(); } while (isspace(c)); std::string s1; while (c != separator) { if (c == EOF) { strm.clear(std::ios::badbit); return strm; } s1 += c; c = strm.get(); } std::istringstream strm1(s1); ReadFloatType(strm1, value1_); // ReadFloatType is class member function // read second element ReadFloatType(strm, value2_); return strm; } friend std::istream &operator>> (std::istream &, LatticeWeightTpl &); friend std::ostream &operator<<( std::ostream &, const LatticeWeightTpl &); private: T value1_; T value2_; }; /* ScaleTupleWeight is a function defined for LatticeWeightTpl and CompactLatticeWeightTpl that mutliplies the pair (value1_, value2_) by a 2x2 matrix. Used, for example, in applying acoustic scaling. */ template inline LatticeWeightTpl ScaleTupleWeight( const LatticeWeightTpl &w, const std::vector > &scale) { // Without the next special case we'd get NaNs from infinity * 0 if (w.Value1() == std::numeric_limits::infinity()) return LatticeWeightTpl::Zero(); return LatticeWeightTpl( scale[0][0] * w.Value1() + scale[0][1] * w.Value2(), scale[1][0] * w.Value1() + scale[1][1] * w.Value2()); } /* For testing purposes and in case it's ever useful, we define a similar function to apply to LexicographicWeight and the like, templated on TropicalWeight etc.; we use PairWeight which is the base class of LexicographicWeight. */ template inline PairWeight, TropicalWeightTpl > ScaleTupleWeight(const PairWeight, TropicalWeightTpl > &w, const std::vector > &scale) { typedef TropicalWeightTpl BaseType; typedef PairWeight PairType; const BaseType zero = BaseType::Zero(); // Without the next special case we'd get NaNs from infinity * 0 if (w.Value1() == zero || w.Value2() == zero) return PairType(zero, zero); FloatType f1 = w.Value1().Value(), f2 = w.Value2().Value(); return PairType(BaseType(scale[0][0] * f1 + scale[0][1] * f2), BaseType(scale[1][0] * f1 + scale[1][1] * f2)); } template inline bool operator==(const LatticeWeightTpl &wa, const LatticeWeightTpl &wb) { // Volatile qualifier thwarts over-aggressive compiler optimizations // that lead to problems esp. with NaturalLess(). volatile FloatType va1 = wa.Value1(), va2 = wa.Value2(), vb1 = wb.Value1(), vb2 = wb.Value2(); return (va1 == vb1 && va2 == vb2); } template inline bool operator!=(const LatticeWeightTpl &wa, const LatticeWeightTpl &wb) { // Volatile qualifier thwarts over-aggressive compiler optimizations // that lead to problems esp. with NaturalLess(). volatile FloatType va1 = wa.Value1(), va2 = wa.Value2(), vb1 = wb.Value1(), vb2 = wb.Value2(); return (va1 != vb1 || va2 != vb2); } // We define a Compare function LatticeWeightTpl even though it's // not required by the semiring standard-- it's just more efficient // to do it this way rather than using the NaturalLess template. /// Compare returns -1 if w1 < w2, +1 if w1 > w2, and 0 if w1 == w2. template inline int Compare(const LatticeWeightTpl &w1, const LatticeWeightTpl &w2) { FloatType f1 = w1.Value1() + w1.Value2(), f2 = w2.Value1() + w2.Value2(); if (f1 < f2) { // having smaller cost means you're larger return 1; } else if (f1 > f2) { // in the semiring [higher probability] return -1; } else if (w1.Value1() < w2.Value1()) { // mathematically we should be comparing (w1.value1_-w1.value2_ < // w2.value1_-w2.value2_) in the next line, but add w1.value1_+w1.value2_ = // w2.value1_+w2.value2_ to both sides and divide by two, and we get the // simpler equivalent form w1.value1_ < w2.value1_. return 1; } else if (w1.Value1() > w2.Value1()) { return -1; } else { return 0; } } template inline LatticeWeightTpl Plus(const LatticeWeightTpl &w1, const LatticeWeightTpl &w2) { return (Compare(w1, w2) >= 0 ? w1 : w2); } // For efficiency, override the NaturalLess template class. template class NaturalLess > { public: typedef LatticeWeightTpl Weight; NaturalLess() {} bool operator()(const Weight &w1, const Weight &w2) const { // NaturalLess is a negative order (opposite to normal ordering). // This operator () corresponds to "<" in the negative order, which // corresponds to the ">" in the normal order. return (Compare(w1, w2) == 1); } }; template <> class NaturalLess > { public: typedef LatticeWeightTpl Weight; NaturalLess() {} bool operator()(const Weight &w1, const Weight &w2) const { // NaturalLess is a negative order (opposite to normal ordering). // This operator () corresponds to "<" in the negative order, which // corresponds to the ">" in the normal order. return (Compare(w1, w2) == 1); } }; template <> class NaturalLess > { public: typedef LatticeWeightTpl Weight; NaturalLess() {} bool operator()(const Weight &w1, const Weight &w2) const { // NaturalLess is a negative order (opposite to normal ordering). // This operator () corresponds to "<" in the negative order, which // corresponds to the ">" in the normal order. return (Compare(w1, w2) == 1); } }; template inline LatticeWeightTpl Times( const LatticeWeightTpl &w1, const LatticeWeightTpl &w2) { return LatticeWeightTpl(w1.Value1() + w2.Value1(), w1.Value2() + w2.Value2()); } // divide w1 by w2 (on left/right/any doesn't matter as // commutative). template inline LatticeWeightTpl Divide(const LatticeWeightTpl &w1, const LatticeWeightTpl &w2, DivideType typ = DIVIDE_ANY) { typedef FloatType T; T a = w1.Value1() - w2.Value1(), b = w1.Value2() - w2.Value2(); if (a != a || b != b || a == -std::numeric_limits::infinity() || b == -std::numeric_limits::infinity()) { KALDI_WARN << "LatticeWeightTpl::Divide, NaN or invalid number produced. " << "[dividing by zero?] Returning zero"; return LatticeWeightTpl::Zero(); } if (a == std::numeric_limits::infinity() || b == std::numeric_limits::infinity()) return LatticeWeightTpl::Zero(); // not a valid number if only one is // infinite. return LatticeWeightTpl(a, b); } template inline bool ApproxEqual(const LatticeWeightTpl &w1, const LatticeWeightTpl &w2, float delta = kDelta) { if (w1.Value1() == w2.Value1() && w1.Value2() == w2.Value2()) return true; // handles Zero(). return (fabs((w1.Value1() + w1.Value2()) - (w2.Value1() + w2.Value2())) <= delta); } template inline std::ostream &operator<<(std::ostream &strm, const LatticeWeightTpl &w) { LatticeWeightTpl::WriteFloatType(strm, w.Value1()); CHECK(FLAGS_fst_weight_separator.size() == 1); // NOLINT strm << FLAGS_fst_weight_separator[0]; // comma by default; // may or may not be settable from Kaldi programs. LatticeWeightTpl::WriteFloatType(strm, w.Value2()); return strm; } template inline std::istream &operator>>(std::istream &strm, LatticeWeightTpl &w1) { CHECK(FLAGS_fst_weight_separator.size() == 1); // NOLINT // separator defaults to ',' return w1.ReadNoParen(strm, FLAGS_fst_weight_separator[0]); } // CompactLattice will be an acceptor (accepting the words/output-symbols), // with the weights and input-symbol-seqs on the arcs. // There must be a total order on W. We assume for the sake of efficiency // that there is a function // Compare(W w1, W w2) that returns -1 if w1 < w2, +1 if w1 > w2, and // zero if w1 == w2, and Plus for type W returns (Compare(w1,w2) >= 0 ? w1 : // w2). template class CompactLatticeWeightTpl { public: typedef WeightType W; typedef CompactLatticeWeightTpl ReverseWeight; // Plus is like LexicographicWeight on the pair (weight_, string_), but where // we use standard lexicographic order on string_ [this is not the same as // NaturalLess on the StringWeight equivalent, which does not define a // total order]. // Times, Divide obvious... (support both left & right division..) // CommonDivisor would need to be coded separately. CompactLatticeWeightTpl() {} CompactLatticeWeightTpl(const WeightType &w, const std::vector &s) : weight_(w), string_(s) {} CompactLatticeWeightTpl &operator=( const CompactLatticeWeightTpl &w) { weight_ = w.weight_; string_ = w.string_; return *this; } const W &Weight() const { return weight_; } const std::vector &String() const { return string_; } void SetWeight(const W &w) { weight_ = w; } void SetString(const std::vector &s) { string_ = s; } static const CompactLatticeWeightTpl Zero() { return CompactLatticeWeightTpl(WeightType::Zero(), std::vector()); } static const CompactLatticeWeightTpl One() { return CompactLatticeWeightTpl(WeightType::One(), std::vector()); } inline static std::string GetIntSizeString() { char buf[2]; buf[0] = '0' + sizeof(IntType); buf[1] = '\0'; return buf; } static const std::string &Type() { static const std::string type = "compact" + WeightType::Type() + GetIntSizeString(); return type; } static const CompactLatticeWeightTpl NoWeight() { return CompactLatticeWeightTpl(WeightType::NoWeight(), std::vector()); } CompactLatticeWeightTpl Reverse() const { size_t s = string_.size(); std::vector v(s); for (size_t i = 0; i < s; i++) v[i] = string_[s - i - 1]; return CompactLatticeWeightTpl(weight_, v); } bool Member() const { // a semiring has only one zero, this is the important property // we're trying to maintain here. So force string_ to be empty if // w_ == zero. if (!weight_.Member()) return false; if (weight_ == WeightType::Zero()) return string_.empty(); else return true; } CompactLatticeWeightTpl Quantize(float delta = kDelta) const { return CompactLatticeWeightTpl(weight_.Quantize(delta), string_); } static constexpr uint64 Properties() { return kLeftSemiring | kRightSemiring | kPath | kIdempotent; } // This is used in OpenFst for binary I/O. This is OpenFst-style, // not Kaldi-style, I/O. std::istream &Read(std::istream &strm) { weight_.Read(strm); if (strm.fail()) { return strm; } int32 sz; ReadType(strm, &sz); if (strm.fail()) { return strm; } if (sz < 0) { KALDI_WARN << "Negative string size! Read failure"; strm.clear(std::ios::badbit); return strm; } string_.resize(sz); for (int32 i = 0; i < sz; i++) { ReadType(strm, &(string_[i])); } return strm; } // This is used in OpenFst for binary I/O. This is OpenFst-style, // not Kaldi-style, I/O. std::ostream &Write(std::ostream &strm) const { weight_.Write(strm); if (strm.fail()) { return strm; } int32 sz = static_cast(string_.size()); WriteType(strm, sz); for (int32 i = 0; i < sz; i++) WriteType(strm, string_[i]); return strm; } size_t Hash() const { size_t ans = weight_.Hash(); // any weird numbers here are largish primes size_t sz = string_.size(), mult = 6967; for (size_t i = 0; i < sz; i++) { ans += string_[i] * mult; mult *= 7499; } return ans; } private: W weight_; std::vector string_; }; template inline bool operator==(const CompactLatticeWeightTpl &w1, const CompactLatticeWeightTpl &w2) { return (w1.Weight() == w2.Weight() && w1.String() == w2.String()); } template inline bool operator!=(const CompactLatticeWeightTpl &w1, const CompactLatticeWeightTpl &w2) { return (w1.Weight() != w2.Weight() || w1.String() != w2.String()); } template inline bool ApproxEqual(const CompactLatticeWeightTpl &w1, const CompactLatticeWeightTpl &w2, float delta = kDelta) { return (ApproxEqual(w1.Weight(), w2.Weight(), delta) && w1.String() == w2.String()); } // Compare is not part of the standard for weight types, but used internally for // efficiency. The comparison here first compares the weight; if this is the // same, it compares the string. The comparison on strings is: first compare // the length, if this is the same, use lexicographical order. We can't just // use the lexicographical order because this would destroy the distributive // property of multiplication over addition, taking into account that addition // uses Compare. The string element of "Compare" isn't super-important in // practical terms; it's only needed to ensure that Plus always give consistent // answers and is symmetric. It's essentially for tie-breaking, but we need to // make sure all the semiring axioms are satisfied otherwise OpenFst might // break. template inline int Compare(const CompactLatticeWeightTpl &w1, const CompactLatticeWeightTpl &w2) { int c1 = Compare(w1.Weight(), w2.Weight()); if (c1 != 0) return c1; int l1 = w1.String().size(), l2 = w2.String().size(); // Use opposite order on the string lengths, so that if the costs are the // same, the shorter string wins. if (l1 > l2) return -1; else if (l1 < l2) return 1; for (int i = 0; i < l1; i++) { if (w1.String()[i] < w2.String()[i]) return -1; else if (w1.String()[i] > w2.String()[i]) return 1; } return 0; } // For efficiency, override the NaturalLess template class. template class NaturalLess< CompactLatticeWeightTpl, IntType> > { public: typedef CompactLatticeWeightTpl, IntType> Weight; NaturalLess() {} bool operator()(const Weight &w1, const Weight &w2) const { // NaturalLess is a negative order (opposite to normal ordering). // This operator () corresponds to "<" in the negative order, which // corresponds to the ">" in the normal order. return (Compare(w1, w2) == 1); } }; template <> class NaturalLess, int32> > { public: typedef CompactLatticeWeightTpl, int32> Weight; NaturalLess() {} bool operator()(const Weight &w1, const Weight &w2) const { // NaturalLess is a negative order (opposite to normal ordering). // This operator () corresponds to "<" in the negative order, which // corresponds to the ">" in the normal order. return (Compare(w1, w2) == 1); } }; template <> class NaturalLess, int32> > { public: typedef CompactLatticeWeightTpl, int32> Weight; NaturalLess() {} bool operator()(const Weight &w1, const Weight &w2) const { // NaturalLess is a negative order (opposite to normal ordering). // This operator () corresponds to "<" in the negative order, which // corresponds to the ">" in the normal order. return (Compare(w1, w2) == 1); } }; // Make sure Compare is defined for TropicalWeight, so everything works // if we substitute LatticeWeight for TropicalWeight. inline int Compare(const TropicalWeight &w1, const TropicalWeight &w2) { float f1 = w1.Value(), f2 = w2.Value(); if (f1 == f2) return 0; else if (f1 > f2) return -1; else return 1; } template inline CompactLatticeWeightTpl Plus( const CompactLatticeWeightTpl &w1, const CompactLatticeWeightTpl &w2) { return (Compare(w1, w2) >= 0 ? w1 : w2); } template inline CompactLatticeWeightTpl Times( const CompactLatticeWeightTpl &w1, const CompactLatticeWeightTpl &w2) { WeightType w = Times(w1.Weight(), w2.Weight()); if (w == WeightType::Zero()) { return CompactLatticeWeightTpl::Zero(); // special case to ensure zero is unique } else { std::vector v; v.resize(w1.String().size() + w2.String().size()); typename std::vector::iterator iter = v.begin(); iter = std::copy(w1.String().begin(), w1.String().end(), iter); // returns end of first range. std::copy(w2.String().begin(), w2.String().end(), iter); return CompactLatticeWeightTpl(w, v); } } template inline CompactLatticeWeightTpl Divide( const CompactLatticeWeightTpl &w1, const CompactLatticeWeightTpl &w2, DivideType div = DIVIDE_ANY) { if (w1.Weight() == WeightType::Zero()) { if (w2.Weight() != WeightType::Zero()) { return CompactLatticeWeightTpl::Zero(); } else { KALDI_ERR << "Division by zero [0/0]"; } } else if (w2.Weight() == WeightType::Zero()) { KALDI_ERR << "Error: division by zero"; } WeightType w = Divide(w1.Weight(), w2.Weight()); const std::vector v1 = w1.String(), v2 = w2.String(); if (v2.size() > v1.size()) { KALDI_ERR << "Cannot divide, length mismatch"; } typename std::vector::const_iterator v1b = v1.begin(), v1e = v1.end(), v2b = v2.begin(), v2e = v2.end(); if (div == DIVIDE_LEFT) { if (!std::equal(v2b, v2e, v1b)) { // v2 must be identical to first part of v1. KALDI_ERR << "Cannot divide, data mismatch"; } return CompactLatticeWeightTpl( w, std::vector(v1b + (v2e - v2b), v1e)); // return last part of v1. } else if (div == DIVIDE_RIGHT) { if (!std::equal( v2b, v2e, v1e - (v2e - v2b))) { // v2 must be identical to last part of v1. KALDI_ERR << "Cannot divide, data mismatch"; } return CompactLatticeWeightTpl( w, std::vector( v1b, v1e - (v2e - v2b))); // return first part of v1. } else { KALDI_ERR << "Cannot divide CompactLatticeWeightTpl with DIVIDE_ANY"; } return CompactLatticeWeightTpl::Zero(); // keep compiler happy. } template inline std::ostream &operator<<( std::ostream &strm, const CompactLatticeWeightTpl &w) { strm << w.Weight(); CHECK(FLAGS_fst_weight_separator.size() == 1); // NOLINT strm << FLAGS_fst_weight_separator[0]; // comma by default. for (size_t i = 0; i < w.String().size(); i++) { strm << w.String()[i]; if (i + 1 < w.String().size()) strm << kStringSeparator; // '_'; defined in string-weight.h in OpenFst // code. } return strm; } template inline std::istream &operator>>( std::istream &strm, CompactLatticeWeightTpl &w) { std::string s; strm >> s; if (strm.fail()) { return strm; } CHECK(FLAGS_fst_weight_separator.size() == 1); // NOLINT size_t pos = s.find_last_of(FLAGS_fst_weight_separator); // normally "," if (pos == std::string::npos) { strm.clear(std::ios::badbit); return strm; } // get parts of str before and after the separator (default: ','); std::string s1(s, 0, pos), s2(s, pos + 1); std::istringstream strm1(s1); WeightType weight; strm1 >> weight; w.SetWeight(weight); if (strm1.fail() || !strm1.eof()) { strm.clear(std::ios::badbit); return strm; } // read string part. std::vector string; const char *c = s2.c_str(); while (*c != '\0') { if (*c == kStringSeparator) // '_' c++; char *c2; int64_t i = strtol(c, &c2, 10); if (c2 == c || static_cast(static_cast(i)) != i) { strm.clear(std::ios::badbit); return strm; } c = c2; string.push_back(static_cast(i)); } w.SetString(string); return strm; } template class CompactLatticeWeightCommonDivisorTpl { public: typedef CompactLatticeWeightTpl Weight; Weight operator()(const Weight &w1, const Weight &w2) const { // First find longest common prefix of the strings. typename std::vector::const_iterator s1b = w1.String().begin(), s1e = w1.String().end(), s2b = w2.String().begin(), s2e = w2.String().end(); while (s1b < s1e && s2b < s2e && *s1b == *s2b) { s1b++; s2b++; } return Weight(Plus(w1.Weight(), w2.Weight()), std::vector(w1.String().begin(), s1b)); } }; /** Scales the pair (a, b) of floating-point weights inside a CompactLatticeWeight by premultiplying it (viewed as a vector) by a 2x2 matrix "scale". Assumes there is a ScaleTupleWeight function that applies to "Weight"; this currently only works if Weight equals LatticeWeightTpl for some FloatType. */ template inline CompactLatticeWeightTpl ScaleTupleWeight( const CompactLatticeWeightTpl &w, const std::vector > &scale) { return CompactLatticeWeightTpl( Weight(ScaleTupleWeight(w.Weight(), scale)), w.String()); } /** Define some ConvertLatticeWeight functions that are used in various lattice conversions... make them all templates, some with no arguments, since some must be templates.*/ template inline void ConvertLatticeWeight(const LatticeWeightTpl &w_in, LatticeWeightTpl *w_out) { w_out->SetValue1(w_in.Value1()); w_out->SetValue2(w_in.Value2()); } template inline void ConvertLatticeWeight( const CompactLatticeWeightTpl, Int> &w_in, CompactLatticeWeightTpl, Int> *w_out) { LatticeWeightTpl weight2(w_in.Weight().Value1(), w_in.Weight().Value2()); w_out->SetWeight(weight2); w_out->SetString(w_in.String()); } // to convert from Lattice to standard FST template inline void ConvertLatticeWeight(const LatticeWeightTpl &w_in, TropicalWeightTpl *w_out) { TropicalWeightTpl w1(w_in.Value1()); TropicalWeightTpl w2(w_in.Value2()); *w_out = Times(w1, w2); } template inline double ConvertToCost(const LatticeWeightTpl &w) { return static_cast(w.Value1()) + static_cast(w.Value2()); } template inline double ConvertToCost( const CompactLatticeWeightTpl, Int> &w) { return static_cast(w.Weight().Value1()) + static_cast(w.Weight().Value2()); } template inline double ConvertToCost(const TropicalWeightTpl &w) { return w.Value(); } } // namespace fst #endif // KALDI_FSTEXT_LATTICE_WEIGHT_H_