// Copyright 2018 Google Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License.! #ifndef SENTENCEPIECE_TRAINER_H_ #define SENTENCEPIECE_TRAINER_H_ #include #include #include #include "sentencepiece_processor.h" namespace sentencepiece { class TrainerSpec; class NormalizerSpec; class ModelProto; namespace pretokenizer { class PretokenizerForTrainingInterface; } // namespace pretokenizer namespace normalizer { class Normalizer; } // namespace normalizer // Iterator over the training sentences. // Training sentences are loaded sequentially as follows: // // for (; !it.done(); it.Next()) { // const std::string &s = it.value(); // } // RETURN_IF_ERROR(it.status()); // class SentenceIterator { public: virtual ~SentenceIterator() {} // Returns true if iteration finishes (including error case). // Uses SentenceIterator::status() method to know whether // all sentences are loaded successfully. virtual bool done() const = 0; virtual void Next() = 0; virtual const std::string &value() const = 0; virtual util::Status status() const = 0; }; class SentencePieceTrainer { public: // Trains SentencePiece model with `trainer_spec`. // Default `normalizer_spec` is used. // When `sentence_iterator` is passed, load sentences from the iterator. static util::Status Train(const TrainerSpec &trainer_spec, SentenceIterator *sentence_iterator = nullptr, std::string *serialized_model_proto = nullptr); // Trains SentencePiece model with `trainer_spec` and // `normalizer_spec`. // When `sentence_iterator` is passed, load sentences from the iterator. static util::Status Train(const TrainerSpec &trainer_spec, const NormalizerSpec &normalizer_spec, SentenceIterator *sentence_iterator = nullptr, std::string *serialized_model_proto = nullptr); // Trains SentencePiece model with `trainer_spec`, `normalizer_spec` // and `denormalizer_spec`. // When `sentence_iterator` is passed, load sentences from the iterator. static util::Status Train(const TrainerSpec &trainer_spec, const NormalizerSpec &normalizer_spec, const NormalizerSpec &denormalizer_spec, SentenceIterator *sentence_iterator = nullptr, std::string *serialized_model_proto = nullptr); // Trains SentencePiece model with command-line string in `args`, // e.g., // '--input=data --model_prefix=m --vocab_size=8192 model_type=unigram' // When `sentence_iterator` is passed, load sentences from the iterator. static util::Status Train(absl::string_view args, SentenceIterator *sentence_iterator = nullptr, std::string *serialized_model_proto = nullptr); // Trains SentencePiece model with mapin `kwargs`. // e.g., {{"input", "data"}, {"model_prefix, "m"}, {"vocab_size", "8192"}...} static util::Status Train( const std::unordered_map &kwargs, SentenceIterator *sentence_iterator = nullptr, std::string *serialized_model_proto = nullptr); // The same as above, but passes the list of sentences. static util::Status Train(absl::string_view args, const std::vector &sentences, std::string *serialized_model_proto = nullptr); // The same as above, but passes the list of sentences. static util::Status Train( const std::unordered_map &kwargs, const std::vector &sentences, std::string *serialized_model_proto = nullptr); // Handy function to make a normalizer spec from the pre-compiled // normalization name. Do not use this method in production as it crashes // When `name` is invalid. Useful for unittesting. static NormalizerSpec GetNormalizerSpec(absl::string_view name); // Populates necessary fields (precompiled_charmap) from // `NormalizerSpec::name` or `NormalizerSpec::normalization_rule_tsv`. static util::Status PopulateNormalizerSpec(NormalizerSpec *normalizer_spec, bool is_denormalizer = false); // Overrides `trainer_spec`, `normalizer_spec`, `denormalizer_spec` with the // std::unordered_map in `kargs`. static util::Status MergeSpecsFromArgs( const std::unordered_map &kwargs, TrainerSpec *trainer_spec, NormalizerSpec *normalizer_spec, NormalizerSpec *denormalizer_spec); // Overrides `trainer_spec`, `normalizer_spec`, `denormalizer_spec` with the // command line flags in `args`. static util::Status MergeSpecsFromArgs(absl::string_view args, TrainerSpec *trainer_spec, NormalizerSpec *normalizer_spec, NormalizerSpec *denormalizer_spec); // Injects global pre-tokenizer that are applied in training time. // Pretokenizer is only used for extracting pieces. // TODO(taku): It would be better to inject per `trainer_spec`. static util::Status SetPretokenizerForTraining( const pretokenizer::PretokenizerForTrainingInterface *pretokenizer); // Returns the current pretokenizer. if no pretokenizer is defined, returns // nullptr. static const pretokenizer::PretokenizerForTrainingInterface * GetPretokenizerForTraining(); // Helper function to set `field_name=value` in `message`. // When `field_name` is repeated, multiple values can be passed // with comma-separated values. `field_name` must not be a nested message. // The body of these functions are automatically generated with // data/gen_spec_parser.pl static util::Status SetProtoField(absl::string_view name, absl::string_view value, TrainerSpec *message); static util::Status SetProtoField(absl::string_view name, absl::string_view value, NormalizerSpec *message); // Populates model type from string representation, e.g., "bpe". // Supported model: "unigram", "bpe", "word", "char". static util::Status PopulateModelTypeFromString(absl::string_view type, TrainerSpec *trainer_spec); private: SentencePieceTrainer() {} ~SentencePieceTrainer() {} }; class SentencePieceNormalizer { public: SentencePieceNormalizer(); virtual ~SentencePieceNormalizer(); virtual util::Status Load(std::unique_ptr model_proto); virtual util::Status Load(absl::string_view filename); virtual util::Status LoadFromSerializedProto(absl::string_view serialized); virtual util::Status LoadFromRuleTSV(absl::string_view filename); virtual util::Status LoadFromRuleName(absl::string_view name); virtual util::Status Normalize(absl::string_view input, std::string *normalized) const; virtual util::Status Normalize(absl::string_view input, std::string *normalized, std::vector *norm_to_orig) const; virtual std::string Normalize(absl::string_view input) const; virtual NormalizerSpec *mutable_normalizer_spec() const; virtual std::string serialized_model_proto() const; private: std::unique_ptr normalizer_; std::unique_ptr model_proto_; }; // Converts the utf8 byte spans into Unicode char span. void ConvertToUnicodeAlignment(absl::string_view orig, absl::string_view norm, std::vector *norm_to_orig); } // namespace sentencepiece #endif // SENTENCEPIECE_TRAINER_H_