// Copyright 2005-2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the 'License'); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an 'AS IS' BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // See www.openfst.org for extensive documentation on this weighted // finite-state transducer library. // // Scripting API support for FarReader and FarWriter. #ifndef FST_EXTENSIONS_FAR_FAR_CLASS_H_ #define FST_EXTENSIONS_FAR_FAR_CLASS_H_ #include #include #include #include #include #include #include #include #include #include #include namespace fst { namespace script { // FarReader API. // Virtual interface implemented by each concrete FarReaderImpl. // See the FarReader interface in far.h for the exact semantics. class FarReaderImplBase { public: virtual const std::string &ArcType() const = 0; virtual bool Done() const = 0; virtual bool Error() const = 0; virtual const std::string &GetKey() const = 0; virtual const FstClass *GetFstClass() const = 0; virtual bool Find(const std::string &key) = 0; virtual void Next() = 0; virtual void Reset() = 0; virtual FarType Type() const = 0; virtual ~FarReaderImplBase() = default; }; // Templated implementation. template class FarReaderClassImpl : public FarReaderImplBase { public: explicit FarReaderClassImpl(std::string_view source) : reader_(FarReader::Open(source)) {} explicit FarReaderClassImpl(const std::vector &sources) : reader_(FarReader::Open(sources)) {} const std::string &ArcType() const final { return Arc::Type(); } bool Done() const final { return reader_->Done(); } bool Error() const final { return reader_->Error(); } bool Find(const std::string &key) final { return reader_->Find(key); } const FstClass *GetFstClass() const final { fstc_ = std::make_unique(*reader_->GetFst()); return fstc_.get(); } const std::string &GetKey() const final { return reader_->GetKey(); } void Next() final { return reader_->Next(); } void Reset() final { reader_->Reset(); } FarType Type() const final { return reader_->Type(); } const FarReader *GetFarReader() const { return reader_.get(); } FarReader *GetFarReader() { return reader_.get(); } private: std::unique_ptr> reader_; mutable std::unique_ptr fstc_; }; class FarReaderClass; using OpenFarReaderClassArgs = WithReturnValue, const std::vector &>; // Untemplated user-facing class holding a templated pimpl. class FarReaderClass { public: const std::string &ArcType() const { return impl_->ArcType(); } bool Done() const { return impl_->Done(); } // Returns True if the impl is null (i.e., due to read failure). // Attempting to call any other function will result in null dereference. bool Error() const { return (impl_) ? impl_->Error() : true; } bool Find(const std::string &key) { return impl_->Find(key); } const FstClass *GetFstClass() const { return impl_->GetFstClass(); } const std::string &GetKey() const { return impl_->GetKey(); } void Next() { impl_->Next(); } void Reset() { impl_->Reset(); } FarType Type() const { return impl_->Type(); } template const FarReader *GetFarReader() const { if (Arc::Type() != ArcType()) return nullptr; const FarReaderClassImpl *typed_impl = down_cast *>(impl_.get()); return typed_impl->GetFarReader(); } template FarReader *GetFarReader() { if (Arc::Type() != ArcType()) return nullptr; FarReaderClassImpl *typed_impl = down_cast *>(impl_.get()); return typed_impl->GetFarReader(); } template friend void OpenFarReaderClass(OpenFarReaderClassArgs *args); // Defined in the CC. static std::unique_ptr Open( std::string_view source); static std::unique_ptr Open( const std::vector &sources); private: template explicit FarReaderClass(std::unique_ptr> impl) : impl_(std::move(impl)) {} std::unique_ptr impl_; }; // These exist solely for registration purposes; users should call the // static method FarReaderClass::Open instead. template void OpenFarReaderClass(OpenFarReaderClassArgs *args) { auto impl = std::make_unique>(args->args); if (impl->GetFarReader() == nullptr) { // Underlying reader failed to open, so return failure here, too. args->retval = nullptr; } else { args->retval = fst::WrapUnique(new FarReaderClass(std::move(impl))); } } // FarWriter API. // Virtual interface implemented by each concrete FarWriterImpl. class FarWriterImplBase { public: // Unlike the lower-level library, this returns a boolean to signal failure // due to non-conformant arc types. virtual bool Add(const std::string &key, const FstClass &fst) = 0; virtual const std::string &ArcType() const = 0; virtual bool Error() const = 0; virtual FarType Type() const = 0; virtual ~FarWriterImplBase() = default; }; // Templated implementation. template class FarWriterClassImpl : public FarWriterImplBase { public: explicit FarWriterClassImpl(std::string_view source, FarType type = FarType::DEFAULT) : writer_(FarWriter::Create(source, type)) {} bool Add(const std::string &key, const FstClass &fst) final { if (ArcType() != fst.ArcType()) { FSTERROR() << "Cannot write FST with " << fst.ArcType() << " arcs to " << "FAR with " << ArcType() << " arcs"; return false; } writer_->Add(key, *(fst.GetFst())); return true; } const std::string &ArcType() const final { return Arc::Type(); } bool Error() const final { return writer_->Error(); } FarType Type() const final { return writer_->Type(); } const FarWriter *GetFarWriter() const { return writer_.get(); } FarWriter *GetFarWriter() { return writer_.get(); } private: std::unique_ptr> writer_; }; class FarWriterClass; using CreateFarWriterClassInnerArgs = std::pair; using CreateFarWriterClassArgs = WithReturnValue, CreateFarWriterClassInnerArgs>; // Untemplated user-facing class holding a templated pimpl. class FarWriterClass { public: static std::unique_ptr Create( const std::string &source, const std::string &arc_type, FarType type = FarType::DEFAULT); bool Add(const std::string &key, const FstClass &fst) { return impl_->Add(key, fst); } // Returns True if the impl is null (i.e., due to construction failure). // Attempting to call any other function will result in null dereference. bool Error() const { return (impl_) ? impl_->Error() : true; } const std::string &ArcType() const { return impl_->ArcType(); } FarType Type() const { return impl_->Type(); } template const FarWriter *GetFarWriter() const { if (Arc::Type() != ArcType()) return nullptr; const FarWriterClassImpl *typed_impl = down_cast *>(impl_.get()); return typed_impl->GetFarWriter(); } template FarWriter *GetFarWriter() { if (Arc::Type() != ArcType()) return nullptr; FarWriterClassImpl *typed_impl = down_cast *>(impl_.get()); return typed_impl->GetFarWriter(); } template friend void CreateFarWriterClass(CreateFarWriterClassArgs *args); private: template explicit FarWriterClass(std::unique_ptr> impl) : impl_(std::move(impl)) {} std::unique_ptr impl_; }; // This exists solely for registration purposes; users should call the // static method FarWriterClass::Create instead. template void CreateFarWriterClass(CreateFarWriterClassArgs *args) { args->retval = fst::WrapUnique( new FarWriterClass(std::make_unique>( std::get<0>(args->args), std::get<1>(args->args)))); } } // namespace script } // namespace fst #endif // FST_EXTENSIONS_FAR_FAR_CLASS_H_