// Copyright 2005-2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the 'License'); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an 'AS IS' BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // See www.openfst.org for extensive documentation on this weighted // finite-state transducer library. // // Recursively replaces FST arcs with other FSTs, returning a PDT. #ifndef FST_EXTENSIONS_PDT_REPLACE_H_ #define FST_EXTENSIONS_PDT_REPLACE_H_ #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace fst { namespace internal { // Hash to paren IDs template struct ReplaceParenHash { size_t operator()(const std::pair &paren) const { static constexpr auto prime = 7853; return paren.first + paren.second * prime; } }; } // namespace internal // Parser types characterize the PDT construction method. When applied to a CFG, // each non-terminal is encoded as a DFA that accepts precisely the RHS's of // productions of that non-terminal. For parsing (rather than just recognition), // production numbers can used as outputs (placed as early as possible) in the // DFAs promoted to DFTs. For more information on the strongly regular // construction, see: // // Mohri, M., and Pereira, F. 1998. Dynamic compilation of weighted context-free // grammars. In Proc. ACL, pages 891-897. enum class PdtParserType : uint8_t { // Top-down construction. Applied to a simple LL(1) grammar (among others), // gives a DPDA. If promoted to a DPDT, with outputs being production // numbers, gives a leftmost derivation. Left recursive grammars are // problematic in use. LEFT, // Top-down construction. Similar to LEFT except bounded-stack // (expandable as an FST) result with regular or, more generally, strongly // regular grammars. Epsilons may replace some parentheses, which may // introduce some non-determinism. LEFT_SR, /* TODO(riley): // Bottom-up construction. Applied to a LR(0) grammar, gives a DPDA. // If promoted to a DPDT, with outputs being the production numbers, // gives the reverse of a rightmost derivation. RIGHT, */ }; template struct PdtReplaceOptions { using Label = typename Arc::Label; explicit PdtReplaceOptions(Label root, PdtParserType type = PdtParserType::LEFT, Label start_paren_labels = kNoLabel, std::string left_paren_prefix = "(_", std::string right_paren_prefix = ")_") : root(root), type(type), start_paren_labels(start_paren_labels), left_paren_prefix(std::move(left_paren_prefix)), right_paren_prefix(std::move(right_paren_prefix)) {} Label root; PdtParserType type; Label start_paren_labels; const std::string left_paren_prefix; const std::string right_paren_prefix; }; // PdtParser: Base PDT parser class common to specific parsers. template class PdtParser { public: using Label = typename Arc::Label; using StateId = typename Arc::StateId; using Weight = typename Arc::Weight; using LabelFstPair = std::pair *>; using LabelPair = std::pair; using LabelStatePair = std::pair; using StateWeightPair = std::pair; using ParenKey = std::pair; using ParenMap = std::unordered_map>; PdtParser(const std::vector &fst_array, const PdtReplaceOptions &opts) : root_(opts.root), start_paren_labels_(opts.start_paren_labels), left_paren_prefix_(std::move(opts.left_paren_prefix)), right_paren_prefix_(std::move(opts.right_paren_prefix)), error_(false) { for (size_t i = 0; i < fst_array.size(); ++i) { if (!CompatSymbols(fst_array[0].second->InputSymbols(), fst_array[i].second->InputSymbols())) { FSTERROR() << "PdtParser: Input symbol table of input FST " << i << " does not match input symbol table of 0th input FST"; error_ = true; } if (!CompatSymbols(fst_array[0].second->OutputSymbols(), fst_array[i].second->OutputSymbols())) { FSTERROR() << "PdtParser: Output symbol table of input FST " << i << " does not match output symbol table of 0th input FST"; error_ = true; } fst_array_.emplace_back(fst_array[i].first, fst_array[i].second->Copy()); // Builds map from non-terminal label to FST ID. label2id_[fst_array[i].first] = i; } } virtual ~PdtParser() { for (auto &pair : fst_array_) delete pair.second; } // Constructs the output PDT, dependent on the derived parser type. virtual void GetParser(MutableFst *ofst, std::vector *parens) = 0; protected: const std::vector &FstArray() const { return fst_array_; } Label Root() const { return root_; } // Maps from non-terminal label to corresponding FST ID, or returns // kNoStateId to signal lookup failure. StateId Label2Id(Label l) const { auto it = label2id_.find(l); return it == label2id_.end() ? kNoStateId : it->second; } // Maps from output state to input FST label, state pair, or returns a // (kNoLabel, kNoStateId) pair to signal lookup failure. LabelStatePair GetLabelStatePair(StateId os) const { if (os >= label_state_pairs_.size()) { static const LabelStatePair no_pair(kNoLabel, kNoLabel); return no_pair; } else { return label_state_pairs_[os]; } } // Maps to output state from input FST (label, state) pair, or returns // kNoStateId to signal lookup failure. StateId GetState(const LabelStatePair &lsp) const { auto it = state_map_.find(lsp); if (it == state_map_.end()) { return kNoStateId; } else { return it->second; } } // Builds single FST combining all referenced input FSTs, leaving in the // non-termnals for now; also tabulates the PDT states that correspond to the // start and final states of the input FSTs. void CreateFst(MutableFst *ofst, std::vector *open_dest, std::vector> *close_src); // Assigns parenthesis labels from total allocated paren IDs. void AssignParenLabels(size_t total_nparens, std::vector *parens) { parens->clear(); for (size_t paren_id = 0; paren_id < total_nparens; ++paren_id) { const auto open_paren = start_paren_labels_ + paren_id; const auto close_paren = open_paren + total_nparens; parens->emplace_back(open_paren, close_paren); } } // Determines how non-terminal instances are assigned parentheses IDs. virtual size_t AssignParenIds(const Fst &ofst, ParenMap *paren_map) const = 0; // Changes a non-terminal transition to an open parenthesis transition // redirected to the PDT state specified in the open_dest argument, when // indexed by the input FST ID for the non-terminal. Adds close parenthesis // transitions (with specified weights) from the PDT states specified in the // close_src argument, when indexed by the input FST ID for the non-terminal, // to the former destination state of the non-terminal transition. The // paren_map argument gives the parenthesis ID for a given non-terminal FST ID // and destination state pair. The close_non_term_weight vector specifies // non-terminals for which the non-terminal arc weight should be applied on // the close parenthesis (multiplying the close_src weight above) rather than // on the open parenthesis. If no paren ID is found, then an epsilon replaces // the parenthesis that would carry the non-terminal arc weight and the other // parenthesis is omitted (appropriate for the strongly-regular case). void AddParensToFst( const std::vector &parens, const ParenMap &paren_map, const std::vector &open_dest, const std::vector> &close_src, const std::vector &close_non_term_weight, MutableFst *ofst); // Ensures that parentheses arcs are added to the symbol table. void AddParensToSymbolTables(const std::vector &parens, MutableFst *ofst); private: std::vector fst_array_; Label root_; // Index to use for the first parenthesis. Label start_paren_labels_; const std::string left_paren_prefix_; const std::string right_paren_prefix_; // Maps from non-terminal label to FST ID. std::unordered_map label2id_; // Given an output state, specifies the input FST (label, state) pair. std::vector label_state_pairs_; // Given an FST (label, state) pair, specifies the output FST state ID. std::map state_map_; bool error_; }; template void PdtParser::CreateFst( MutableFst *ofst, std::vector *open_dest, std::vector> *close_src) { ofst->DeleteStates(); if (error_) { ofst->SetProperties(kError, kError); return; } open_dest->resize(fst_array_.size(), kNoStateId); close_src->resize(fst_array_.size()); // Queue of non-terminals to replace. std::deque