// Copyright 2005-2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the 'License'); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an 'AS IS' BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // See www.openfst.org for extensive documentation on this weighted // finite-state transducer library. // // Compresses and decompresses unweighted FSTs. #ifndef FST_EXTENSIONS_COMPRESS_COMPRESS_H_ #define FST_EXTENSIONS_COMPRESS_COMPRESS_H_ #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace fst { // Identifies stream data as a vanilla compressed FST. inline constexpr int32_t kCompressMagicNumber = 1858869554; namespace internal { // Expands a Lempel Ziv code and returns the set of code words where // expanded_code[i] is the i^th Lempel Ziv codeword. template [[nodiscard]] bool ExpandLZCode(const std::vector> &code, std::vector> *expanded_code) { expanded_code->resize(code.size()); for (int i = 0; i < code.size(); ++i) { if (code[i].first > i) { LOG(ERROR) << "ExpandLZCode: Not a valid code"; return false; } auto &codeword = (*expanded_code)[i]; if (code[i].first == 0) { codeword.resize(1, code[i].second); } else { const auto &other_codeword = (*expanded_code)[code[i].first - 1]; codeword.resize(other_codeword.size() + 1); std::copy(other_codeword.cbegin(), other_codeword.cend(), codeword.begin()); codeword[other_codeword.size()] = code[i].second; } } return true; } } // namespace internal // Lempel Ziv on data structure Edge, with a less-than operator EdgeLessThan and // an equals operator EdgeEquals. template class LempelZiv { public: LempelZiv() : dict_number_(0), default_edge_() { root_.current_number = dict_number_++; root_.current_edge = default_edge_; decode_vector_.emplace_back(0, default_edge_); } // Encodes a vector input into output. void BatchEncode(const std::vector &input, std::vector> *output); // Decodes codedvector to output, returning false if the index exceeds the // size. [[nodiscard]] bool BatchDecode(const std::vector> &input, std::vector *output); // Decodes a single dictionary element, returning false if the index exceeds // the size. [[nodiscard]] bool SingleDecode(const Var &index, Edge *output) { if (index >= decode_vector_.size()) { LOG(ERROR) << "LempelZiv::SingleDecode: " << "Index exceeded the dictionary size"; return false; } else { *output = decode_vector_[index].second; return true; } } private: struct Node { Var current_number; Edge current_edge; std::map, EdgeLessThan> next_number; }; Node root_; Var dict_number_; std::vector> decode_vector_; Edge default_edge_; }; template void LempelZiv::BatchEncode( const std::vector &input, std::vector> *output) { for (auto it = input.cbegin(); it != input.cend(); ++it) { auto *temp_node = &root_; while (it != input.cend()) { auto next = temp_node->next_number.find(*it); if (next != temp_node->next_number.cend()) { temp_node = next->second.get(); ++it; } else { break; } } if (it == input.cend() && temp_node->current_number != 0) { output->emplace_back(temp_node->current_number, default_edge_); } else if (it != input.cend()) { output->emplace_back(temp_node->current_number, *it); auto new_node = std::make_unique(); new_node->current_number = dict_number_++; new_node->current_edge = *it; temp_node->next_number[*it] = std::move(new_node); } if (it == input.cend()) break; } } template [[nodiscard]] bool LempelZiv::BatchDecode( const std::vector> &input, std::vector *output) { for (const auto &[var, edge] : input) { std::vector temp_output; EdgeEquals InstEdgeEquals; if (InstEdgeEquals(edge, default_edge_) != 1) { decode_vector_.emplace_back(var, edge); temp_output.push_back(edge); } auto temp_integer = var; if (temp_integer >= decode_vector_.size()) { LOG(ERROR) << "LempelZiv::BatchDecode: " << "Index exceeded the dictionary size"; return false; } else { while (temp_integer != 0) { temp_output.push_back(decode_vector_[temp_integer].second); temp_integer = decode_vector_[temp_integer].first; } output->insert(output->cend(), temp_output.rbegin(), temp_output.rend()); } } return true; } template class Compressor { public: using Label = typename Arc::Label; using StateId = typename Arc::StateId; using Weight = typename Arc::Weight; Compressor() = default; // Compresses an FST into a boolean vector code, returning true on success. [[nodiscard]] bool Compress(const Fst &fst, std::ostream &strm); // Decompresses the boolean vector into an FST, returning true on success. [[nodiscard]] bool Decompress(std::istream &strm, std::string_view source, MutableFst *fst); // Computes the BFS order of a FST. void BfsOrder(const ExpandedFst &fst, std::vector *order); // Preprocessing step to convert an FST to a isomorphic FST. void Preprocess(const Fst &fst, MutableFst *preprocessedfst, EncodeMapper *encoder); // Performs Lempel Ziv and outputs a stream of integers. void EncodeProcessedFst(const ExpandedFst &fst, std::ostream &strm); // Decodes FST from the stream. void DecodeProcessedFst(const std::vector &input, MutableFst *fst, bool unweighted); // Writes the boolean file to the stream. void WriteToStream(std::ostream &strm); // Writes the weights to the stream. void WriteWeight(const std::vector &input, std::ostream &strm); void ReadWeight(std::istream &strm, std::vector *output); // Same as fst::Decode, but doesn't remove the final epsilons. void DecodeForCompress(MutableFst *fst, const EncodeMapper &mapper); // Updates buffer_code_. template void WriteToBuffer(CVar input) { std::vector current_code; Elias::DeltaEncode(input, ¤t_code); buffer_code_.insert(buffer_code_.cend(), current_code.begin(), current_code.end()); } private: struct LZLabel { LZLabel() : label(0) {} Label label; }; struct LabelLessThan { bool operator()(const LZLabel &labelone, const LZLabel &labeltwo) const { return labelone.label < labeltwo.label; } }; struct LabelEquals { bool operator()(const LZLabel &labelone, const LZLabel &labeltwo) const { return labelone.label == labeltwo.label; } }; struct Transition { Transition() : nextstate(0), label(0), weight(Weight::Zero()) {} StateId nextstate; Label label; Weight weight; }; struct TransitionLessThan { bool operator()(const Transition &transition_one, const Transition &transition_two) const { if (transition_one.nextstate == transition_two.nextstate) { return transition_one.label < transition_two.label; } else { return transition_one.nextstate < transition_two.nextstate; } } }; struct TransitionEquals { bool operator()(const Transition &transition_one, const Transition &transition_two) const { return transition_one.nextstate == transition_two.nextstate && transition_one.label == transition_two.label; } }; struct OldDictCompare { bool operator()(const std::pair &pair_one, const std::pair &pair_two) const { if (pair_one.second.nextstate == pair_two.second.nextstate) { return pair_one.second.label < pair_two.second.label; } else { return pair_one.second.nextstate < pair_two.second.nextstate; } } }; std::vector buffer_code_; std::vector arc_weight_; std::vector final_weight_; }; template void Compressor::DecodeForCompress(MutableFst *fst, const EncodeMapper &mapper) { ArcMap(fst, EncodeMapper(mapper, DECODE)); fst->SetInputSymbols(mapper.InputSymbols()); fst->SetOutputSymbols(mapper.OutputSymbols()); } template void Compressor::BfsOrder(const ExpandedFst &fst, std::vector *order) { class BfsVisitor { public: // Requires order->size() >= fst.NumStates(). explicit BfsVisitor(std::vector *order) : order_(order) {} void InitVisit(const Fst &fst) {} bool InitState(StateId s, StateId) { order_->at(s) = num_bfs_states_++; return true; } bool WhiteArc(StateId s, const Arc &arc) { return true; } bool GreyArc(StateId s, const Arc &arc) { return true; } bool BlackArc(StateId s, const Arc &arc) { return true; } void FinishState(StateId s) {} void FinishVisit() {} private: std::vector *order_ = nullptr; StateId num_bfs_states_ = 0; }; order->assign(fst.NumStates(), kNoStateId); BfsVisitor visitor(order); FifoQueue queue; Visit(fst, &visitor, &queue, AnyArcFilter()); } template void Compressor::Preprocess(const Fst &fst, MutableFst *preprocessedfst, EncodeMapper *encoder) { *preprocessedfst = fst; if (!preprocessedfst->NumStates()) return; // Relabels the edges and develops a dictionary. Encode(preprocessedfst, encoder); std::vector order; // Finds the BFS sorting order of the FST. BfsOrder(*preprocessedfst, &order); // Reorders the states according to the BFS order. StateSort(preprocessedfst, order); } template void Compressor::EncodeProcessedFst(const ExpandedFst &fst, std::ostream &strm) { std::vector output; LempelZiv dict_new; LempelZiv dict_old; std::vector current_new_input; std::vector current_old_input; std::vector> current_new_output; std::vector> current_old_output; std::vector final_states; const auto number_of_states = fst.NumStates(); StateId seen_states = 0; // Adds the number of states. WriteToBuffer(number_of_states); for (StateId state = 0; state < number_of_states; ++state) { current_new_input.clear(); current_old_input.clear(); current_new_output.clear(); current_old_output.clear(); if (state > seen_states) ++seen_states; // Collects the final states. if (fst.Final(state) != Weight::Zero()) { final_states.push_back(state); final_weight_.push_back(fst.Final(state)); } // Reads the states. for (ArcIterator> aiter(fst, state); !aiter.Done(); aiter.Next()) { const auto &arc = aiter.Value(); if (arc.nextstate > seen_states) { // RILEY: > or >= ? ++seen_states; LZLabel temp_label; temp_label.label = arc.ilabel; arc_weight_.push_back(arc.weight); current_new_input.push_back(temp_label); } else { Transition temp_transition; temp_transition.nextstate = arc.nextstate; temp_transition.label = arc.ilabel; temp_transition.weight = arc.weight; current_old_input.push_back(temp_transition); } } // Adds new states. dict_new.BatchEncode(current_new_input, ¤t_new_output); WriteToBuffer(current_new_output.size()); for (auto it = current_new_output.cbegin(); it != current_new_output.cend(); ++it) { WriteToBuffer(it->first); WriteToBuffer