// Copyright 2005-2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the 'License'); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an 'AS IS' BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // See www.openfst.org for extensive documentation on this weighted // finite-state transducer library. // // Convenience file for including all PDT operations at once, and/or // registering them for new arc types. #ifndef FST_EXTENSIONS_PDT_PDTSCRIPT_H_ #define FST_EXTENSIONS_PDT_PDTSCRIPT_H_ #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include // for ComposeOptions #include #include #include #include #include #include #include #include #include #include namespace fst { namespace script { using PdtComposeArgs = std::tuple> &, MutableFstClass *, const PdtComposeOptions &, bool>; template void Compose(PdtComposeArgs *args) { const Fst &ifst1 = *(std::get<0>(*args).GetFst()); const Fst &ifst2 = *(std::get<1>(*args).GetFst()); MutableFst *ofst = std::get<3>(*args)->GetMutableFst(); // In case Arc::Label is not the same as FstClass::Label, we make a // copy. Truncation may occur if FstClass::Label has more precision than // Arc::Label. std::vector> typed_parens( std::get<2>(*args).size()); std::copy(std::get<2>(*args).begin(), std::get<2>(*args).end(), typed_parens.begin()); if (std::get<5>(*args)) { Compose(ifst1, typed_parens, ifst2, ofst, std::get<4>(*args)); } else { Compose(ifst1, ifst2, typed_parens, ofst, std::get<4>(*args)); } } void Compose(const FstClass &ifst1, const FstClass &ifst2, const std::vector> &parens, MutableFstClass *ofst, const PdtComposeOptions &opts, bool left_pdt); struct PdtExpandOptions { bool connect; bool keep_parentheses; const WeightClass &weight_threshold; PdtExpandOptions(bool c, bool k, const WeightClass &w) : connect(c), keep_parentheses(k), weight_threshold(w) {} }; using PdtExpandArgs = std::tuple> &, MutableFstClass *, const PdtExpandOptions &>; template void Expand(PdtExpandArgs *args) { const Fst &fst = *(std::get<0>(*args).GetFst()); MutableFst *ofst = std::get<2>(*args)->GetMutableFst(); // In case Arc::Label is not the same as FstClass::Label, we make a // copy. Truncation may occur if FstClass::Label has more precision than // Arc::Label. std::vector> typed_parens( std::get<1>(*args).size()); std::copy(std::get<1>(*args).begin(), std::get<1>(*args).end(), typed_parens.begin()); Expand(fst, typed_parens, ofst, fst::PdtExpandOptions( std::get<3>(*args).connect, std::get<3>(*args).keep_parentheses, *(std::get<3>(*args) .weight_threshold.GetWeight()))); } void Expand(const FstClass &ifst, const std::vector> &parens, MutableFstClass *ofst, const PdtExpandOptions &opts); void Expand(const FstClass &ifst, const std::vector> &parens, MutableFstClass *ofst, bool connect, bool keep_parentheses, const WeightClass &weight_threshold); using PdtReplaceArgs = std::tuple> &, MutableFstClass *, std::vector> *, int64_t, PdtParserType, int64_t, const std::string &, const std::string &>; template void Replace(PdtReplaceArgs *args) { const auto &untyped_pairs = std::get<0>(*args); auto size = untyped_pairs.size(); std::vector *>> typed_pairs( size); for (size_t i = 0; i < size; ++i) { typed_pairs[i].first = untyped_pairs[i].first; typed_pairs[i].second = untyped_pairs[i].second->GetFst(); } MutableFst *ofst = std::get<1>(*args)->GetMutableFst(); std::vector> typed_parens; const PdtReplaceOptions opts(std::get<3>(*args), std::get<4>(*args), std::get<5>(*args), std::get<6>(*args), std::get<7>(*args)); Replace(typed_pairs, ofst, &typed_parens, opts); // Copies typed parens into arg3. std::get<2>(*args)->resize(typed_parens.size()); std::copy(typed_parens.begin(), typed_parens.end(), std::get<2>(*args)->begin()); } void Replace(const std::vector> &pairs, MutableFstClass *ofst, std::vector> *parens, int64_t root, PdtParserType parser_type = PdtParserType::LEFT, int64_t start_paren_labels = kNoLabel, const std::string &left_paren_prefix = "(_", const std::string &right_paren_prefix = "_)"); using PdtReverseArgs = std::tuple> &, MutableFstClass *>; template void Reverse(PdtReverseArgs *args) { const Fst &fst = *(std::get<0>(*args).GetFst()); MutableFst *ofst = std::get<2>(*args)->GetMutableFst(); // In case Arc::Label is not the same as FstClass::Label, we make a // copy. Truncation may occur if FstClass::Label has more precision than // Arc::Label. std::vector> typed_parens( std::get<1>(*args).size()); std::copy(std::get<1>(*args).begin(), std::get<1>(*args).end(), typed_parens.begin()); Reverse(fst, typed_parens, ofst); } void Reverse(const FstClass &ifst, const std::vector> &, MutableFstClass *ofst); // PDT SHORTESTPATH struct PdtShortestPathOptions { QueueType queue_type; bool keep_parentheses; bool path_gc; explicit PdtShortestPathOptions(QueueType qt = FIFO_QUEUE, bool kp = false, bool gc = true) : queue_type(qt), keep_parentheses(kp), path_gc(gc) {} }; using PdtShortestPathArgs = std::tuple> &, MutableFstClass *, const PdtShortestPathOptions &>; template void ShortestPath(PdtShortestPathArgs *args) { const Fst &fst = *(std::get<0>(*args).GetFst()); MutableFst *ofst = std::get<2>(*args)->GetMutableFst(); const PdtShortestPathOptions &opts = std::get<3>(*args); // In case Arc::Label is not the same as FstClass::Label, we make a // copy. Truncation may occur if FstClass::Label has more precision than // Arc::Label. std::vector> typed_parens( std::get<1>(*args).size()); std::copy(std::get<1>(*args).begin(), std::get<1>(*args).end(), typed_parens.begin()); switch (opts.queue_type) { default: FSTERROR() << "Unknown queue type: " << opts.queue_type; [[fallthrough]]; case FIFO_QUEUE: { using Queue = FifoQueue; fst::PdtShortestPathOptions spopts(opts.keep_parentheses, opts.path_gc); ShortestPath(fst, typed_parens, ofst, spopts); return; } case LIFO_QUEUE: { using Queue = LifoQueue; fst::PdtShortestPathOptions spopts(opts.keep_parentheses, opts.path_gc); ShortestPath(fst, typed_parens, ofst, spopts); return; } case STATE_ORDER_QUEUE: { using Queue = StateOrderQueue; fst::PdtShortestPathOptions spopts(opts.keep_parentheses, opts.path_gc); ShortestPath(fst, typed_parens, ofst, spopts); return; } } } void ShortestPath( const FstClass &ifst, const std::vector> &parens, MutableFstClass *ofst, const PdtShortestPathOptions &opts = PdtShortestPathOptions()); // PRINT INFO using PdtInfoArgs = std::pair> &>; template void Info(PdtInfoArgs *args) { const Fst &fst = *(std::get<0>(*args).GetFst()); // In case Arc::Label is not the same as FstClass::Label, we make a // copy. Truncation may occur if FstClass::Label has more precision than // Arc::Label. std::vector> typed_parens( std::get<1>(*args).size()); std::copy(std::get<1>(*args).begin(), std::get<1>(*args).end(), typed_parens.begin()); PdtInfo pdtinfo(fst, typed_parens); pdtinfo.Print(); } void Info(const FstClass &ifst, const std::vector> &parens); } // namespace script } // namespace fst #define REGISTER_FST_PDT_OPERATIONS(ArcType) \ REGISTER_FST_OPERATION(PdtCompose, ArcType, PdtComposeArgs); \ REGISTER_FST_OPERATION(PdtExpand, ArcType, PdtExpandArgs); \ REGISTER_FST_OPERATION(PdtReplace, ArcType, PdtReplaceArgs); \ REGISTER_FST_OPERATION(PdtReverse, ArcType, PdtReverseArgs); \ REGISTER_FST_OPERATION(PdtShortestPath, ArcType, PdtShortestPathArgs); \ REGISTER_FST_OPERATION(PrintPdtInfo, ArcType, PrintPdtInfoArgs) #endif // FST_EXTENSIONS_PDT_PDTSCRIPT_H_