# Author: SivaMallikarjun — commit 76aeebf "Create model.py" (verified)
import torch
import torch.nn as nn
import fasttext
class SimpleMultilingualClassifier(nn.Module):
    """Linear classifier over per-language fastText sentence embeddings.

    One fastText model is loaded per language code. ``forward`` embeds a
    single text with the model for its language and applies one shared
    ``nn.Linear`` layer to produce unnormalized class scores (logits).

    Args:
        embedding_files: Mapping of language code -> path of a fastText
            model file, loaded with ``fasttext.load_model``.
        num_classes: Number of outputs of the linear layer.
        embedding_dim: Input size of the linear layer; must equal the
            dimension of every loaded fastText model.

    Raises:
        ValueError: If a loaded fastText model's dimension differs from
            ``embedding_dim``.
    """

    def __init__(self, embedding_files, num_classes, embedding_dim=100):
        super().__init__()
        self.embedding_files = embedding_files
        self.embedding_dim = embedding_dim
        self.linear = nn.Linear(embedding_dim, num_classes)
        # fastText models are not torch modules, so they live in a plain dict
        # rather than an nn.ModuleDict.
        self.language_models = {}
        for lang, path in embedding_files.items():
            ft_model = fasttext.load_model(path)
            # Fail fast with a clear message instead of a shape-mismatch error
            # on the first forward pass.
            model_dim = ft_model.get_dimension()
            if model_dim != embedding_dim:
                raise ValueError(
                    f"fastText model for '{lang}' ({path}) has dimension "
                    f"{model_dim}, expected embedding_dim={embedding_dim}."
                )
            self.language_models[lang] = ft_model

    def get_embedding(self, text, lang):
        """Return the fastText sentence vector of ``text`` as a tensor.

        The tensor is created with the dtype and on the device of the linear
        layer's weight, so it can be fed straight into ``self.linear`` even
        after the module has been moved or cast (``.to(...)``, ``.double()``).

        Raises:
            ValueError: If no fastText model is loaded for ``lang``.
        """
        try:
            ft_model = self.language_models[lang]
        except KeyError:
            raise ValueError(f"Language '{lang}' not supported.") from None
        # fastText's get_sentence_vector rejects text containing '\n' (it
        # processes one line at a time), so fold newlines into spaces.
        vector = ft_model.get_sentence_vector(text.replace("\n", " "))
        weight = self.linear.weight
        return torch.tensor(vector, dtype=weight.dtype, device=weight.device)

    def forward(self, text, lang):
        """Return the unnormalized class scores for a single ``text``."""
        return self.linear(self.get_embedding(text, lang))

    def predict(self, text, lang, class_labels):
        """Return the label of the highest-scoring class for ``text``.

        Args:
            text: Input sentence.
            lang: Language code selecting the fastText model.
            class_labels: Indexable by class index; ``class_labels[i]`` is
                the label returned when output ``i`` scores highest.

        The module is put in eval mode for the prediction and then restored
        to the mode it was in before, so calling ``predict`` in the middle of
        training does not silently leave the model in eval mode.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                logits = self.forward(text, lang)
        finally:
            self.train(was_training)
        # softmax is monotonic, so argmax over the raw logits selects the
        # same class as argmax over the probabilities.
        predicted_class_index = torch.argmax(logits, dim=-1).item()
        return class_labels[predicted_class_index]
# Example usage (you'd need to define your classes and supported languages).
# The linear head below is never trained, so these predictions are placeholders.
if __name__ == '__main__':
    embedding_files = {
        'en': 'fasttext_embeddings/cc.en.100.bin',
        'fr': 'fasttext_embeddings/cc.fr.100.bin',
    }
    class_labels = ["positive", "negative", "neutral"]
    # Derive the class count from the labels so the two cannot drift apart.
    num_classes = len(class_labels)
    # Seed the randomly initialised linear layer so the dummy output is
    # reproducible from run to run.
    torch.manual_seed(0)
    model = SimpleMultilingualClassifier(embedding_files, num_classes)

    # Dummy predictions: (display name, language code, text).
    examples = [
        ("English", 'en', "This is a great movie."),
        ("French", 'fr', "C'est un film incroyable."),
    ]
    for display_name, lang, text in examples:
        prediction = model.predict(text, lang, class_labels)
        print(f"{display_name} Prediction: {prediction}")