// Copyright 2016 Google Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License.! #ifndef BPE_MODEL_TRAINER_H_ #define BPE_MODEL_TRAINER_H_ #include #include #include #include #include "sentencepiece_model.pb.h" #include "third_party/absl/container/btree_set.h" #include "third_party/absl/container/flat_hash_map.h" #include "trainer_interface.h" namespace sentencepiece { namespace bpe { // Trainer class for BPE model. class Trainer : public TrainerInterface { public: Trainer(const TrainerSpec &trainer_spec, const NormalizerSpec &normalizer_spec, const NormalizerSpec &denormalizer_spec) : TrainerInterface::TrainerInterface(trainer_spec, normalizer_spec, denormalizer_spec) {} util::Status Train() override; private: // Symbol represents a character or symbol bigram. struct Symbol { const Symbol *left; // left symbol in bigram const Symbol *right; // right symbol in bigram string_util::UnicodeText chars; // all flattend chracter sequence bool is_unk; // true if this symbol is unknown. uint64_t fp; // fingerprint of this symbol. uint64_t freq; // frequency of this symbol. // Position list. Use set so that we can keep the order of occurrence. // See EncodePos/DecodePos. absl::btree_set positions; bool IsBigram() const { return left != nullptr && right != nullptr; } std::string ToString() const; Symbol() : left(nullptr), right(nullptr), is_unk(false), fp(0), freq(0) {} }; struct Position { int sid; // sentence id int left; // left symbol index int right; // right symbol index }; // Encodes sid, left and right bigram index into uint64_t. // Encoded value keeps the order of sid, left and right. static uint64_t EncodePos(int sid, int l, int r) { CHECK_GE(l, 0); CHECK_GE(r, 0); CHECK_LE(l, std::numeric_limits::max()); CHECK_LE(r, std::numeric_limits::max()); const uint64_t n = (static_cast(sid) << 32) | (static_cast(l) << 16) | r; return n; } // Decodes sid, left and right bigram index from uint64_t. static Position DecodePos(uint64_t n) { Position p; p.sid = n >> 32; p.left = (n >> 16) & 0xffff; p.right = n & 0xffff; return p; } // Gets unary (character) symbol from the char code |c|. // The return value is cached. Symbol *GetCharSymbol(char32 c); // Gets symbol pair from left/right symbols. The return value is cached. Symbol *GetPairSymbol(const Symbol *left, const Symbol *right); // Computes the frequency of |symbol| and update symbol->freq field. void ComputeFreq(Symbol *symbol) const; // Returns the valid index before symbols_[sid][index]. int GetNextIndex(int sid, int index) const; // Returns the valid index after symbols_[sid][index]. int GetPrevIndex(int sid, int index) const; // Makes a new bigram from [symbols_[sid][left], symbols_[sid][right]] and // Adds it to symbols_cache_ and active_symbols_. void AddNewPair(int sid, int left, int right); // Resets the fequency of bigram [symbols_[sid][left] symbols_[sid][right]], // if this bigram is not |best|. void ResetFreq(int sid, int left, int right, const Symbol *best); // Updates |active_symbols_| by copying the top 5% frequent symbols in // symbols_cache_. void UpdateActiveSymbols(); // All unique symbols. Key is a fingerprint of Symbol. absl::flat_hash_map symbols_cache_; // Set of symbols from which we find the best symbol in each iteration. absl::btree_set active_symbols_; // Stores symbols allocated in heap so that we can delete them at onece. std::vector allocated_; // Sentences. symbols_[sid][index] stores a symbol in sentence_[sid][index]. std::vector> symbols_; }; } // namespace bpe } // namespace sentencepiece #endif // BPE_MODEL_TRAINER_H_