# Spaces: Sleeping Sleeping  (apparently page-header residue captured along with this code; not part of the program)
import torch | |
import torch.nn as nn | |
import fasttext | |
class SimpleMultilingualClassifier(nn.Module):
    """Linear classifier over per-language fastText sentence embeddings.

    One fastText model is loaded per language code. A text is embedded with
    the model for its language, then passed through a single shared
    ``nn.Linear`` layer that produces one logit per class.
    """

    def __init__(self, embedding_files: dict, num_classes: int, embedding_dim: int = 100):
        """Load one fastText model per language and build the linear head.

        Args:
            embedding_files: Mapping of language code -> path of a fastText
                model file, passed to ``fasttext.load_model``.
            num_classes: Number of output logits.
            embedding_dim: Input size of the linear layer; must equal the
                dimension of every loaded fastText model.

        Raises:
            ValueError: If a loaded model's dimension differs from
                ``embedding_dim`` (otherwise this would only fail later,
                inside ``forward``, as a shape-mismatch error).
        """
        super().__init__()
        self.embedding_files = embedding_files
        self.embedding_dim = embedding_dim
        self.linear = nn.Linear(embedding_dim, num_classes)
        # Plain dict, not nn.ModuleDict: fastText models are not torch modules,
        # so they are not part of state_dict() and are not moved by .to().
        self.language_models = {}
        for lang, path in embedding_files.items():
            ft_model = fasttext.load_model(path)
            model_dim = ft_model.get_dimension()
            if model_dim != embedding_dim:
                raise ValueError(
                    f"fastText model for '{lang}' ({path}) has dimension "
                    f"{model_dim}, expected embedding_dim={embedding_dim}."
                )
            self.language_models[lang] = ft_model

    def get_embedding(self, text: str, lang: str) -> torch.Tensor:
        """Return the fastText sentence vector of *text* as a 1-D tensor.

        The tensor is created with the dtype and on the device of the linear
        layer's weights, so ``forward`` keeps working after the module is
        moved or cast with ``.to(...)``.

        Raises:
            ValueError: If no model was loaded for *lang*.
        """
        try:
            ft_model = self.language_models[lang]
        except KeyError:
            raise ValueError(f"Language '{lang}' not supported.") from None
        # fastText's get_sentence_vector raises on text containing '\n';
        # treat newlines as ordinary whitespace instead.
        vector = ft_model.get_sentence_vector(text.replace("\n", " "))
        weight = self.linear.weight
        return torch.as_tensor(vector, dtype=weight.dtype, device=weight.device)

    def forward(self, text: str, lang: str) -> torch.Tensor:
        """Return the unnormalized class logits for a single *text*."""
        return self.linear(self.get_embedding(text, lang))

    def predict(self, text: str, lang: str, class_labels):
        """Return the entry of *class_labels* with the highest logit for *text*.

        The module runs in eval mode without gradient tracking. Its previous
        train/eval mode is restored afterwards.

        Args:
            text: Input text; newlines are treated as spaces.
            lang: Language code; must be a key of ``embedding_files``.
            class_labels: Indexable sequence of labels, one per output class,
                in logit order.

        Raises:
            ValueError: If ``len(class_labels)`` differs from the number of
                output classes, or if *lang* is not supported.
        """
        num_classes = self.linear.out_features
        if len(class_labels) != num_classes:
            raise ValueError(
                f"Expected {num_classes} class labels, got {len(class_labels)}."
            )
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                logits = self.forward(text, lang)
        finally:
            # Restore the caller's mode instead of silently leaving eval mode on.
            self.train(was_training)
        # Softmax is monotonic, so the argmax of the raw logits picks the same class.
        predicted_class_index = int(torch.argmax(logits, dim=-1))
        return class_labels[predicted_class_index]
# Example usage. The linear head below is freshly initialized (never trained),
# so the printed labels are placeholders; supply your own classes, supported
# languages, and trained weights for meaningful predictions.
if __name__ == '__main__':
    embedding_files = {
        'en': 'fasttext_embeddings/cc.en.100.bin',
        'fr': 'fasttext_embeddings/cc.fr.100.bin',
    }
    class_labels = ["positive", "negative", "neutral"]
    # Derive the class count from the labels so the two cannot drift apart.
    num_classes = len(class_labels)

    # Seed so the randomly initialized head (and thus the dummy output) is reproducible.
    torch.manual_seed(0)
    model = SimpleMultilingualClassifier(embedding_files, num_classes)

    # Dummy predictions: (display name, language code, text).
    examples = [
        ("English", 'en', "This is a great movie."),
        ("French", 'fr', "C'est un film incroyable."),
    ]
    for language_name, lang, text in examples:
        prediction = model.predict(text, lang, class_labels)
        print(f"{language_name} Prediction: {prediction}")