import torch
from torch import nn, Tensor
from typing import Union, Tuple, List, Iterable, Dict
import logging
import gzip
from tqdm import tqdm
import numpy as np
import os
import json

from ..util import import_from_string, fullname, http_get
from .tokenizer import WordTokenizer, WhitespaceTokenizer


logger = logging.getLogger(__name__)


class WordEmbeddings(nn.Module):
    def __init__(self, tokenizer: WordTokenizer, embedding_weights, update_embeddings: bool = False, max_seq_length: int = 1000000):
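        """Maps token ids to dense vectors via a fixed ``nn.Embedding`` lookup table.

        ``embedding_weights`` may be a Python list, a numpy array, or a torch tensor of
        shape (num_embeddings, embedding_dim). ``tokenizer`` converts raw text into token
        ids in ``tokenize()``, and ``update_embeddings`` controls whether the embedding
        matrix is fine-tuned during training.
        """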
        nn.Module.__init__(self)
        if isinstance(embedding_weights, list):
            embedding_weights = np.asarray(embedding_weights)

        if isinstance(embedding_weights, np.ndarray):
            embedding_weights = torch.from_numpy(embedding_weights)

        num_embeddings, embeddings_dimension = embedding_weights.size()
        self.embeddings_dimension = embeddings_dimension
        self.emb_layer = nn.Embedding(num_embeddings, embeddings_dimension)
        self.emb_layer.load_state_dict({'weight': embedding_weights})
        self.emb_layer.weight.requires_grad = update_embeddings
        self.tokenizer = tokenizer
        self.update_embeddings = update_embeddings
        self.max_seq_length = max_seq_length

    def forward(self, features):
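        # Look up the embedding vector for every token id; the attention mask is passed
        # through unchanged so downstream pooling layers can ignore padding positions.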
        token_embeddings = self.emb_layer(features['input_ids'])
        cls_tokens = None
        features.update({'token_embeddings': token_embeddings, 'cls_token_embeddings': cls_tokens, 'attention_mask': features['attention_mask']})
        return features

    def tokenize(self, texts: List[str]):
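        # Tokenize each text, then right-pad every sequence with id 0 to the length of the
        # longest sequence in the batch and build a matching 0/1 attention mask.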
        tokenized_texts = [self.tokenizer.tokenize(text) for text in texts]
        sentence_lengths = [len(tokens) for tokens in tokenized_texts]
        max_len = max(sentence_lengths)

        input_ids = []
        attention_masks = []
        for tokens in tokenized_texts:
            padding = [0] * (max_len - len(tokens))
            input_ids.append(tokens + padding)
            attention_masks.append([1] * len(tokens) + padding)

        output = {'input_ids': torch.tensor(input_ids, dtype=torch.long),
                  'attention_mask': torch.tensor(attention_masks, dtype=torch.long),
                  'sentence_lengths': torch.tensor(sentence_lengths, dtype=torch.long)}
        return output

    def get_word_embedding_dimension(self) -> int:
        return self.embeddings_dimension

    def save(self, output_path: str):
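        # Persist the config, the embedding weights (as a PyTorch state dict), and the
        # tokenizer vocabulary into output_path so the model can be restored via load().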
        with open(os.path.join(output_path, 'wordembedding_config.json'), 'w') as fOut:
            json.dump(self.get_config_dict(), fOut, indent=2)

        torch.save(self.state_dict(), os.path.join(output_path, 'pytorch_model.bin'))
        self.tokenizer.save(output_path)

    def get_config_dict(self):
        return {'tokenizer_class': fullname(self.tokenizer), 'update_embeddings': self.update_embeddings, 'max_seq_length': self.max_seq_length}

    @staticmethod
    def load(input_path: str):
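        # Rebuild the model from a directory written by save(): restore the tokenizer class
        # named in the config, load the embedding matrix from pytorch_model.bin on CPU, and
        # re-apply the saved settings.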
        with open(os.path.join(input_path, 'wordembedding_config.json'), 'r') as fIn:
            config = json.load(fIn)

        tokenizer_class = import_from_string(config['tokenizer_class'])
        tokenizer = tokenizer_class.load(input_path)
        weights = torch.load(os.path.join(input_path, 'pytorch_model.bin'), map_location=torch.device('cpu'))
        embedding_weights = weights['emb_layer.weight']
        model = WordEmbeddings(tokenizer=tokenizer, embedding_weights=embedding_weights, update_embeddings=config['update_embeddings'])
        return model

    @staticmethod
    def from_text_file(embeddings_file_path: str, update_embeddings: bool = False, item_separator: str = " ", tokenizer=WhitespaceTokenizer(), max_vocab_size: int = None):
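        # Build a WordEmbeddings model from a (possibly gzipped) text file in the common
        # GloVe-style format: one token per line followed by its vector components, joined
        # by item_separator. If the file is missing and the path is a bare filename, it is
        # downloaded from the UKP server. A zero vector is inserted at index 0 as the
        # padding token, and reading stops once max_vocab_size entries have been collected.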
        logger.info("Read in embeddings file {}".format(embeddings_file_path))

        if not os.path.exists(embeddings_file_path):
            logger.info("{} does not exist, trying to download from server".format(embeddings_file_path))

            if '/' in embeddings_file_path or '\\' in embeddings_file_path:
                raise ValueError("Embeddings file not found: {}".format(embeddings_file_path))

            url = "https://public.ukp.informatik.tu-darmstadt.de/reimers/embeddings/" + embeddings_file_path
            http_get(url, embeddings_file_path)

        embeddings_dimension = None
        vocab = []
        embeddings = []

        with gzip.open(embeddings_file_path, "rt", encoding="utf8") if embeddings_file_path.endswith('.gz') else open(embeddings_file_path, encoding="utf8") as fIn:
            iterator = tqdm(fIn, desc="Load Word Embeddings", unit="Embeddings")
            for line in iterator:
                split = line.rstrip().split(item_separator)
                word = split[0]

                if embeddings_dimension is None:
                    embeddings_dimension = len(split) - 1
                    vocab.append("PADDING_TOKEN")
                    embeddings.append(np.zeros(embeddings_dimension))

                if (len(split) - 1) != embeddings_dimension:
                    logger.error("ERROR: A line in the embeddings file had a different number of dimensions than expected. Skipping token.")
                    continue

                vector = np.array([float(num) for num in split[1:]])
                embeddings.append(vector)
                vocab.append(word)

                if max_vocab_size is not None and max_vocab_size > 0 and len(vocab) > max_vocab_size:
                    break

        embeddings = np.asarray(embeddings)

        tokenizer.set_vocab(vocab)
        return WordEmbeddings(tokenizer=tokenizer, embedding_weights=embeddings, update_embeddings=update_embeddings)
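

# Illustrative usage sketch (not part of the original module; assumes this file lives in the
# sentence_transformers.models package). A WordEmbeddings layer is typically combined with a
# pooling layer that averages the token vectors into a fixed-size sentence embedding:
#
#     from sentence_transformers import SentenceTransformer, models
#
#     word_model = models.WordEmbeddings.from_text_file('glove.6B.300d.txt.gz')
#     pooling = models.Pooling(word_model.get_word_embedding_dimension())
#     model = SentenceTransformer(modules=[word_model, pooling])
#     embeddings = model.encode(['Hello world', 'Word embeddings are simple and fast'])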