Spaces:
Sleeping
Sleeping
| import unicodedata | |
| import glob | |
| import torch | |
| import random | |
| ALL_LETTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz .,;'-" | |
| N_LETTERS = len(ALL_LETTERS) | |
| def load_data(): | |
| category_lines = {} | |
| all_categories = [] | |
| def find_files(path): | |
| return glob.glob(path) | |
| def read_lines(filename): | |
| lines = open(filename, encoding='utf-8').read().strip().split('\n') | |
| return [unicode_to_ascii(line) for line in lines] | |
| for filename in find_files('data/names/*.txt'): | |
| category = filename.split('/')[-1].split('.')[0] | |
| all_categories.append(category) | |
| lines = read_lines(filename) | |
| category_lines[category] = lines | |
| return category_lines, all_categories | |
| def letter_to_index(letter): | |
| return ALL_LETTERS.find(letter) | |
| def letter_to_tensor(letter): | |
| tensor = torch.zeros(1,N_LETTERS) | |
| tensor[0][letter_to_index(letter)] = 1 | |
| return tensor | |
| def line_to_tensor(line): | |
| tensor = torch.zeros(len(line), 1, N_LETTERS) | |
| for i, letter in enumerate(line): | |
| tensor[i][0][letter_to_index(letter)] = 1 | |
| return tensor | |
| def random_training_example(category_lines, all_categories): | |
| def random_choice(a): | |
| random_idx = random.randint(0, len(a) - 1) | |
| return a[random_idx] | |
| category = random_choice(all_categories) | |
| line = random_choice(category_lines[category]) | |
| category_tensor = torch.tensor([all_categories.index(category)], dtype=torch.long) | |
| line_tensor = line_to_tensor(line) | |
| return category, line, category_tensor, line_tensor | |
| # Turn unicode string to plain ASCII | |
| def unicode_to_ascii(s): | |
| return ''.join( | |
| c for c in unicodedata.normalize('NFD', s) | |
| if unicodedata.category(c) != 'Mn' | |
| and c in ALL_LETTERS | |
| ) | |
| if __name__ == '__main__': | |
| print(unicode_to_ascii("O'Néàl")) |