fasttext-language-identification / custom_fasttext.py
Hiveurban's picture
Upload model
93b5b4d verified
raw
history blame
1.49 kB
import fasttext
from typing import List
import torch
from torch import nn
from transformers import PretrainedConfig
from transformers import PreTrainedModel
from huggingface_hub import hf_hub_download
class FastTextConfig(PretrainedConfig):
    """Configuration holding the settings of the fastText language-identification wrapper.

    Attributes:
        repo_id: Hugging Face Hub repository id recorded on the config.
        top_k: Integer recorded on the config (presumably the number of
            top predictions to return -- confirm against the consumer).
    """

    model_type = "fasttext-language-identification"

    def __init__(self, repo_id: str = "facebook/fasttext-language-identification", top_k: int = 1, **kwargs):
        """Record ``repo_id`` and ``top_k``, then defer to ``PretrainedConfig``.

        Args:
            repo_id: Hub repository id to store on the config.
            top_k: Value to store as ``self.top_k``.
            **kwargs: Forwarded untouched to ``PretrainedConfig.__init__``.
        """
        # Own fields are assigned before the parent initializer runs,
        # in the same order as before.
        self.repo_id, self.top_k = repo_id, top_k
        super().__init__(**kwargs)
class FastTextModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper that delegates prediction to a ``FastText`` module.

    The wrapped module is built from ``config.repo_id`` and stored as
    ``self.model``.
    """

    config_class = FastTextConfig

    def __init__(self, config):
        """Initialise the base model and build the wrapped ``FastText`` module.

        Args:
            config: Configuration object; its ``repo_id`` selects the
                repository the fastText binary is loaded from.
        """
        super().__init__(config)
        self.model = FastText(config.repo_id)

    def forward(self, words: List[str], k=None) -> tuple:
        """Run fastText prediction on ``words``.

        Args:
            words: Texts to classify; passed unchanged to the wrapped module.
            k: Number of predictions requested per text. When left as
                ``None`` the value stored in ``config.top_k`` is used
                (1 with the default configuration), so the configured
                setting is honoured instead of being silently ignored.

        Returns:
            The result of the wrapped module's ``forward`` call, i.e. what
            fastText's ``predict`` returns (per the fastText Python API, a
            ``(labels, probabilities)`` pair).
        """
        if k is None:
            # Fall back to the configured value rather than a hard-coded 1.
            k = self.config.top_k
        return self.model(words, k=k)
class FastText(nn.Module):
    """``nn.Module`` wrapper around a fastText model fetched from the Hugging Face Hub.

    Attributes:
        ft: The loaded fastText model object.
        embeddings: ``nn.Embedding`` whose weight is the fastText input matrix.
    """

    def __init__(self, repo_id: str, filename: str = "model.bin", *args, **kwargs) -> None:
        """Download and load the fastText binary, then mirror its input matrix.

        Args:
            repo_id: Hub repository to download the model file from.
            filename: Name of the fastText binary inside the repository.
            *args: Accepted for call compatibility; not used.
            **kwargs: Accepted for call compatibility; not used.
        """
        super().__init__()
        self.ft = fasttext.load_model(
            hf_hub_download(repo_id=repo_id, filename=filename)
        )
        word_vectors = torch.from_numpy(self.ft.get_input_matrix())
        # Seed the embedding table with the fastText input matrix instead of
        # leaving it randomly initialised; the resulting shape is
        # (word_vectors.size(0), word_vectors.size(1)) exactly as before.
        # freeze=False keeps the weight trainable, matching a plain
        # nn.Embedding.
        self.embeddings = nn.Embedding.from_pretrained(word_vectors, freeze=False)

    def forward(self, text, k=1) -> tuple:
        """Predict labels for ``text`` with the underlying fastText model.

        Args:
            text: Input handed straight to fastText's ``predict``; per the
                fastText Python API this may be one string or a list of
                strings.
            k: Number of predictions requested per input.

        Returns:
            Whatever ``self.ft.predict`` returns (per the fastText Python
            API, a ``(labels, probabilities)`` pair).
        """
        return self.ft.predict(text, k=k)