Spaces:
Running
Running
| #!/usr/bin/python3 | |
| # -*- coding: utf-8 -*- | |
| from collections import defaultdict, OrderedDict | |
| import os | |
| from typing import Any, Callable, Dict, Iterable, List, Set | |
| def namespace_match(pattern: str, namespace: str): | |
| """ | |
| Matches a namespace pattern against a namespace string. For example, ``*tags`` matches | |
| ``passage_tags`` and ``question_tags`` and ``tokens`` matches ``tokens`` but not | |
| ``stemmed_tokens``. | |
| """ | |
| if pattern[0] == '*' and namespace.endswith(pattern[1:]): | |
| return True | |
| elif pattern == namespace: | |
| return True | |
| return False | |
| class _NamespaceDependentDefaultDict(defaultdict): | |
| def __init__(self, | |
| non_padded_namespaces: Set[str], | |
| padded_function: Callable[[], Any], | |
| non_padded_function: Callable[[], Any]) -> None: | |
| self._non_padded_namespaces = set(non_padded_namespaces) | |
| self._padded_function = padded_function | |
| self._non_padded_function = non_padded_function | |
| super(_NamespaceDependentDefaultDict, self).__init__() | |
| def __missing__(self, key: str): | |
| if any(namespace_match(pattern, key) for pattern in self._non_padded_namespaces): | |
| value = self._non_padded_function() | |
| else: | |
| value = self._padded_function() | |
| dict.__setitem__(self, key, value) | |
| return value | |
| def add_non_padded_namespaces(self, non_padded_namespaces: Set[str]): | |
| # add non_padded_namespaces which weren't already present | |
| self._non_padded_namespaces.update(non_padded_namespaces) | |
| class _TokenToIndexDefaultDict(_NamespaceDependentDefaultDict): | |
| def __init__(self, non_padded_namespaces: Set[str], padding_token: str, oov_token: str) -> None: | |
| super(_TokenToIndexDefaultDict, self).__init__(non_padded_namespaces, | |
| lambda: {padding_token: 0, oov_token: 1}, | |
| lambda: {}) | |
| class _IndexToTokenDefaultDict(_NamespaceDependentDefaultDict): | |
| def __init__(self, non_padded_namespaces: Set[str], padding_token: str, oov_token: str) -> None: | |
| super(_IndexToTokenDefaultDict, self).__init__(non_padded_namespaces, | |
| lambda: {0: padding_token, 1: oov_token}, | |
| lambda: {}) | |
| DEFAULT_NON_PADDED_NAMESPACES = ("*tags", "*labels") | |
| DEFAULT_PADDING_TOKEN = '[PAD]' | |
| DEFAULT_OOV_TOKEN = '[UNK]' | |
| NAMESPACE_PADDING_FILE = 'non_padded_namespaces.txt' | |
| class Vocabulary(object): | |
| def __init__(self, non_padded_namespaces: Iterable[str] = DEFAULT_NON_PADDED_NAMESPACES): | |
| self._non_padded_namespaces = set(non_padded_namespaces) | |
| self._padding_token = DEFAULT_PADDING_TOKEN | |
| self._oov_token = DEFAULT_OOV_TOKEN | |
| self._token_to_index = _TokenToIndexDefaultDict(self._non_padded_namespaces, | |
| self._padding_token, | |
| self._oov_token) | |
| self._index_to_token = _IndexToTokenDefaultDict(self._non_padded_namespaces, | |
| self._padding_token, | |
| self._oov_token) | |
| def add_token_to_namespace(self, token: str, namespace: str = 'tokens') -> int: | |
| if token not in self._token_to_index[namespace]: | |
| index = len(self._token_to_index[namespace]) | |
| self._token_to_index[namespace][token] = index | |
| self._index_to_token[namespace][index] = token | |
| return index | |
| else: | |
| return self._token_to_index[namespace][token] | |
| def get_index_to_token_vocabulary(self, namespace: str = 'tokens') -> Dict[int, str]: | |
| return self._index_to_token[namespace] | |
| def get_token_to_index_vocabulary(self, namespace: str = 'tokens') -> Dict[str, int]: | |
| return self._token_to_index[namespace] | |
| def get_token_index(self, token: str, namespace: str = 'tokens') -> int: | |
| if token in self._token_to_index[namespace]: | |
| return self._token_to_index[namespace][token] | |
| else: | |
| return self._token_to_index[namespace][self._oov_token] | |
| def get_token_from_index(self, index: int, namespace: str = 'tokens'): | |
| return self._index_to_token[namespace][index] | |
| def get_vocab_size(self, namespace: str = 'tokens') -> int: | |
| return len(self._token_to_index[namespace]) | |
| def save_to_files(self, directory: str): | |
| os.makedirs(directory, exist_ok=True) | |
| with open(os.path.join(directory, NAMESPACE_PADDING_FILE), 'w', encoding='utf-8') as f: | |
| for namespace_str in self._non_padded_namespaces: | |
| f.write('{}\n'.format(namespace_str)) | |
| for namespace, token_to_index in self._token_to_index.items(): | |
| filename = os.path.join(directory, '{}.txt'.format(namespace)) | |
| with open(filename, 'w', encoding='utf-8') as f: | |
| for token, _ in token_to_index.items(): | |
| f.write('{}\n'.format(token)) | |
| def from_files(cls, directory: str) -> 'Vocabulary': | |
| with open(os.path.join(directory, NAMESPACE_PADDING_FILE), 'r', encoding='utf-8') as f: | |
| non_padded_namespaces = [namespace_str.strip() for namespace_str in f] | |
| vocab = cls(non_padded_namespaces=non_padded_namespaces) | |
| for namespace_filename in os.listdir(directory): | |
| if namespace_filename == NAMESPACE_PADDING_FILE: | |
| continue | |
| if namespace_filename.startswith("."): | |
| continue | |
| namespace = namespace_filename.replace('.txt', '') | |
| if any(namespace_match(pattern, namespace) for pattern in non_padded_namespaces): | |
| is_padded = False | |
| else: | |
| is_padded = True | |
| filename = os.path.join(directory, namespace_filename) | |
| vocab.set_from_file(filename, is_padded, namespace=namespace) | |
| return vocab | |
| def set_from_file(self, | |
| filename: str, | |
| is_padded: bool = True, | |
| oov_token: str = DEFAULT_OOV_TOKEN, | |
| namespace: str = "tokens" | |
| ): | |
| if is_padded: | |
| self._token_to_index[namespace] = {self._padding_token: 0} | |
| self._index_to_token[namespace] = {0: self._padding_token} | |
| else: | |
| self._token_to_index[namespace] = {} | |
| self._index_to_token[namespace] = {} | |
| with open(filename, 'r', encoding='utf-8') as f: | |
| index = 1 if is_padded else 0 | |
| for row in f: | |
| token = str(row).strip() | |
| if token == oov_token: | |
| token = self._oov_token | |
| self._token_to_index[namespace][token] = index | |
| self._index_to_token[namespace][index] = token | |
| index += 1 | |
| def convert_tokens_to_ids(self, tokens: List[str], namespace: str = "tokens"): | |
| result = list() | |
| for token in tokens: | |
| idx = self._token_to_index[namespace].get(token) | |
| if idx is None: | |
| idx = self._token_to_index[namespace][self._oov_token] | |
| result.append(idx) | |
| return result | |
| def convert_ids_to_tokens(self, ids: List[int], namespace: str = "tokens"): | |
| result = list() | |
| for idx in ids: | |
| idx = self._index_to_token[namespace][idx] | |
| result.append(idx) | |
| return result | |
| def pad_or_truncate_ids_by_max_length(self, ids: List[int], max_length: int, namespace: str = "tokens"): | |
| pad_idx = self._token_to_index[namespace][self._padding_token] | |
| length = len(ids) | |
| if length > max_length: | |
| result = ids[:max_length] | |
| else: | |
| result = ids + [pad_idx] * (max_length - length) | |
| return result | |
| def demo1(): | |
| import jieba | |
| vocabulary = Vocabulary() | |
| vocabulary.add_token_to_namespace('白天', 'tokens') | |
| vocabulary.add_token_to_namespace('晚上', 'tokens') | |
| text = '不是在白天, 就是在晚上' | |
| tokens = jieba.lcut(text) | |
| print(tokens) | |
| ids = vocabulary.convert_tokens_to_ids(tokens) | |
| print(ids) | |
| padded_idx = vocabulary.pad_or_truncate_ids_by_max_length(ids, 10) | |
| print(padded_idx) | |
| tokens = vocabulary.convert_ids_to_tokens(padded_idx) | |
| print(tokens) | |
| return | |
| if __name__ == '__main__': | |
| demo1() | |