# nlpia-rnn / app.py
# Last commit cc9cfea by hobs: "load state_dict" (6.99 kB)
# app.py
import gradio as gr
import os
from pathlib import Path
# import random
# import time
import torch
import torch.nn as nn
import pandas as pd
from nlpia2.init import SRC_DATA_DIR, maybe_download
from nlpia2.string_normalizers import Asciifier, ASCII_NAME_CHARS
# Transcode a Unicode str to ASCII, stripping embellishments and diacritics
# (https://stackoverflow.com/a/518232/2809427); only ASCII_NAME_CHARS are targeted.
asciify = Asciifier(include=ASCII_NAME_CHARS)
# Vocabulary size: one slot per allowed name character, plus one for the EOS marker
name_char_vocab_size = 1 + len(ASCII_NAME_CHARS)
def find_files(path, pattern):
    """Return a lazy iterator of paths inside ``path`` matching the glob ``pattern``."""
    root = Path(path)
    return root.glob(pattern)
# Map each allowed name character to its position in ASCII_NAME_CHARS
char2i = dict(zip(ASCII_NAME_CHARS, range(len(ASCII_NAME_CHARS))))
# The raw names data comes from the PyTorch tutorial archive:
#   curl -O https://download.pytorch.org/tutorial/data.zip; unzip data.zip
print(f'asciify("O’Néàl") => {asciify("O’Néàl")}')
# Build the category_lines dictionary, a list of names per language.
# category_lines / all_categories are filled by the data/names loop further below;
# this loop builds a (name, category) table in ``df``.
category_lines = {}
all_categories = []
labeled_lines = []
categories = []
for filepath in find_files(SRC_DATA_DIR / 'names', '*.txt'):
    filename = Path(filepath).name
    # maybe_download (nlpia2) is used here to obtain a local file object with .open();
    # presumably it fetches the file when missing -- TODO confirm
    filepath = maybe_download(filename=Path('names') / filename)
    # Decode explicitly as UTF-8 (as readLines does): the names contain non-ASCII
    # characters and the platform default encoding is not guaranteed to be UTF-8.
    with filepath.open(encoding='utf-8') as fin:
        lines = [asciify(line.rstrip()) for line in fin]
    # Plain str label (e.g. 'Italian'), consistent with the labels in all_categories
    category = Path(filename).stem
    categories.append(category)
    labeled_lines.extend((line, category) for line in lines)
n_categories = len(categories)
df = pd.DataFrame(labeled_lines, columns=('name', 'category'))
def readLines(filename):
    """Read a UTF-8 names file and return one asciified entry per line.

    Whitespace around the whole file is stripped first, so a trailing newline
    does not add an empty final entry. An empty (or whitespace-only) file
    returns an empty list instead of ``['']``.
    """
    # ``with`` guarantees the file handle is closed
    with open(filename, encoding='utf-8') as fin:
        text = fin.read().strip()
    if not text:
        return []
    return [asciify(line) for line in text.split('\n')]
# Read every data/names/<Language>.txt file (path is relative to the working
# directory) into category_lines, keyed by the file's base name without extension.
# NOTE(review): the position of each language in all_categories becomes the RNN's
# output index, and glob order is not sorted. The saved weights presumably assume
# the order seen when they were trained -- TODO confirm the order is stable on the host.
for filename in find_files(path='data/names', pattern='*.txt'):
    base_name = os.path.basename(filename)
    category = os.path.splitext(base_name)[0]
    all_categories.append(category)
    lines = readLines(filename)
    category_lines[category] = lines
n_categories = len(all_categories)

######################################################################
# ``category_lines`` now maps each category (language) to its list of
# names, ``all_categories`` lists the languages in load order, and
# ``n_categories`` counts them. Print a small sample as a sanity check:
#
print(category_lines['Italian'][:5])
######################################################################
# Turning Names into Tensors
# --------------------------
#
# Now that we have all the names organized, we need to turn them into
# Tensors to make any use of them.
#
# To represent a single letter, we use a "one-hot vector" of size
# ``<1 x n_letters>``, where ``n_letters`` is ``len(ASCII_NAME_CHARS)``.
# A one-hot vector is filled with 0s except for a 1 at the index of the
# current letter, e.g. ``"b" = <0 1 0 0 0 ...>`` (illustrative only; the
# actual index follows the order of ASCII_NAME_CHARS).
#
# To make a word we stack those vectors into a 3D tensor of shape
# ``<line_length x 1 x n_letters>``.
#
# That extra 1 dimension is because PyTorch assumes everything is in
# batches - we're just using a batch size of 1 here.
#
# Look up a character's index in char2i (its position in ASCII_NAME_CHARS)
def letterToIndex(c):
    """Return the integer index of character ``c``.

    Raises KeyError when ``c`` is not one of ASCII_NAME_CHARS.
    """
    index = char2i[c]
    return index
# Demonstration helper: a single character as a <1 x n_letters> one-hot tensor
def encode_one_hot_vec(letter):
    """Return a ``(1, len(ASCII_NAME_CHARS))`` float tensor with a 1 at ``letter``'s index.

    Raises KeyError (from ``letterToIndex``) for characters outside ASCII_NAME_CHARS.
    """
    one_hot = torch.zeros(1, len(ASCII_NAME_CHARS))
    one_hot[0, letterToIndex(letter)] = 1
    return one_hot
# Encode a whole line as a <line_length x 1 x n_letters> stack of one-hot vectors
def encode_one_hot_seq(line):
    """Return a ``(len(line), 1, len(ASCII_NAME_CHARS))`` one-hot float tensor.

    Row ``pos`` holds the one-hot vector of ``line[pos]``; the middle axis is a
    batch of size 1. Raises KeyError for characters outside ASCII_NAME_CHARS.
    """
    seq = torch.zeros(len(line), 1, len(ASCII_NAME_CHARS))
    for pos, char in enumerate(line):
        seq[pos, 0, letterToIndex(char)] = 1
    return seq


# Sanity checks for the two encoders
print(encode_one_hot_vec('A'))
print(encode_one_hot_seq('Abe').size())
######################################################################
# Creating the Network
# ====================
#
# Before autograd, creating a recurrent neural network in Torch involved
# cloning the parameters of a layer over several timesteps. The layers
# held hidden state and gradients which are now entirely handled by the
# graph itself. This means you can implement a RNN in a very "pure" way,
# as regular feed-forward layers.
#
# This RNN module (mostly copied from `the PyTorch for Torch users
# tutorial <https://pytorch.org/tutorials/beginner/former_torchies/
# nn_tutorial.html#example-2-recurrent-net>`__)
# is just 2 linear layers which operate on an input and hidden state, with
# a LogSoftmax layer after the output.
#
# .. figure:: https://i.imgur.com/Z2xbySO.png
# :alt:
#
#
class RNN(nn.Module):
    """Recurrent classifier built from two linear layers.

    At every step the input vector and the previous hidden state are
    concatenated; ``i2h`` maps that to the next hidden state and ``i2o`` maps
    it to category scores, which LogSoftmax turns into log-probabilities.

    The attribute names ``i2h`` and ``i2o`` are the keys of saved state dicts,
    so they must not be renamed.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        combined_size = input_size + hidden_size
        self.i2h = nn.Linear(combined_size, hidden_size)
        self.i2o = nn.Linear(combined_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, char_tens, hidden):
        """Advance one step; return ``(log_probs, next_hidden)``."""
        combined = torch.cat([char_tens, hidden], dim=1)
        next_hidden = self.i2h(combined)
        log_probs = self.softmax(self.i2o(combined))
        return log_probs, next_hidden

    def initHidden(self):
        """Return an all-zero hidden state of shape ``(1, hidden_size)``."""
        return torch.zeros(1, self.hidden_size)
n_hidden = 128
# Output size must match the category count the saved weights were trained with
# (the weights are loaded further below); n_categories here comes from data/names.
rnn = RNN(len(ASCII_NAME_CHARS), n_hidden, n_categories)
# Single-step smoke test: feed the letter 'A' with a zero hidden state.
# Named demo_input so the builtin ``input`` is not shadowed at module level.
demo_input = encode_one_hot_vec('A')
hidden = torch.zeros(1, n_hidden)
output, next_hidden = rnn(demo_input, hidden)
def categoryFromOutput(output):
    """Return ``(category_name, category_index)`` for the highest-scoring column of ``output``.

    Reads the first row of ``output``; the index is looked up in ``all_categories``.
    """
    best_index = output.topk(1)[1][0].item()
    return all_categories[best_index], best_index
def output_from_str(s):
    """Classify the string ``s`` with ``rnn`` and return ``(category, category_index)``.

    Every character of ``s`` is fed through the network in order, starting from
    a zero hidden state; the output after the last character is printed and
    decoded with ``categoryFromOutput``.

    Raises ValueError if ``s`` is empty, and KeyError for characters outside
    ASCII_NAME_CHARS (from the encoder).
    """
    line_tensor = encode_one_hot_seq(s)
    if line_tensor.size(0) == 0:
        raise ValueError('output_from_str() needs a non-empty string')
    hidden = torch.zeros(1, n_hidden)
    for char_vec in line_tensor:
        output, hidden = rnn(char_vec, hidden)
    print(output)
    return categoryFromOutput(output)
########################################
# load/save test for use on the huggingface spaces server
# The weights file was produced with:
#   torch.save(rnn.state_dict(), 'rnn_from_scratch_name_nationality.state_dict.pickle')
# map_location='cpu' lets weights saved during a GPU session load on a CPU-only
# host; the model above is built on the CPU.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code --
# only load trusted weight files (newer torch versions accept weights_only=True).
state_dict = torch.load(
    'rnn_from_scratch_name_nationality.state_dict.pickle',
    map_location='cpu',
)
rnn.load_state_dict(state_dict)
def evaluate(line_tensor):
    """Run ``rnn`` over an encoded name and return the output after its last character.

    ``line_tensor`` is expected to be indexed by character position along its
    first axis, as produced by ``encode_one_hot_seq``. The hidden state starts
    from ``rnn.initHidden()``. The returned tensor holds the LogSoftmax output
    of the final step.

    Raises ValueError if ``line_tensor`` has no characters (previously this
    surfaced as an UnboundLocalError).
    """
    if line_tensor.size(0) == 0:
        raise ValueError('Cannot evaluate an empty name: line_tensor has no characters')
    hidden = rnn.initHidden()
    for char_vec in line_tensor:
        output, hidden = rnn(char_vec, hidden)
    return output
def predict(input_line, n_predictions=3):
    """Print and return the most likely categories for a name.

    The name is passed through ``asciify`` (as the training names were) and any
    remaining characters that ``char2i`` cannot encode are dropped, so arbitrary
    user input does not raise KeyError.

    Returns a list of ``[log_probability, category]`` pairs, best first, with at
    most ``n_predictions`` items (capped at the number of categories). Returns an
    empty list when no encodable characters remain.
    """
    print('\n> %s' % input_line)
    # Normalize like the training data, then keep only characters the encoder knows
    clean_line = ''.join(c for c in asciify(input_line) if c in char2i)
    if not clean_line:
        return []
    # topk raises if asked for more entries than there are categories
    n_predictions = min(n_predictions, len(all_categories))
    predictions = []
    with torch.no_grad():
        output = evaluate(encode_one_hot_seq(clean_line))
        # Get top N categories (largest log-probabilities first)
        topv, topi = output.topk(n_predictions, 1, True)
        for value, category_index in zip(topv[0].tolist(), topi[0].tolist()):
            print('(%.2f) %s' % (value, all_categories[category_index]))
            predictions.append([value, all_categories[category_index]])
    return predictions


predict('Dovesky')
predict('Jackson')
predict('Satoshi')
# load/save test for use on the huggingface spaces server
########################################


def greet_nationality(name):
    """Return a greeting that guesses the nationality of ``name``.

    Uses the single most likely category returned by ``predict``; if the name
    contains no characters the model can encode, says so instead of guessing.
    """
    predictions = predict(name)
    if not predictions:
        return f"Hello {name}!!\n I couldn't find any letters I recognize in your name."
    # predictions is best-first; each item is [log_probability, category]
    nationality = predictions[0][1]
    return f"Hello {name}!!\n Your name seems to be from {nationality}. Am I right?"
# Serve the greeter as a simple text-in / text-out Gradio web UI
iface = gr.Interface(
    fn=greet_nationality,
    inputs="text",
    outputs="text",
)
iface.launch()