# langident
# custom_code
# impresso-langident / modeling_lang.py
# Gleb Vinarskis
# added emas pipeline
# ae8276c
import torch
import torch.nn as nn
from transformers import PreTrainedModel
import logging
import floret
import os
from huggingface_hub import hf_hub_download
from .configuration_lang import ImpressoConfig
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(name=__name__)
class LangDetectorModel(PreTrainedModel):
    """Language-identification model backed by a floret binary.

    The class subclasses ``PreTrainedModel`` so it can be loaded through the
    ``transformers`` machinery. Predictions come entirely from the floret
    model loaded in ``__init__``; no torch weights are used.
    """

    config_class = ImpressoConfig

    def __init__(self, config):
        """Locate (or download) the floret binary and load it.

        Args:
            config: An ``ImpressoConfig``. The model file name and the Hub
                repository id are read from ``config.config.filename`` and
                ``config.config._name_or_path``. NOTE(review): this nesting
                appears to come from ``from_pretrained`` building an
                ``ImpressoConfig`` whose ``config`` kwarg holds the loaded
                config; confirm against ``ImpressoConfig``.
        """
        super().__init__(config)
        self.config = config
        # Placeholder parameter so ``self.parameters()`` is non-empty and the
        # ``device`` property has something to inspect.
        self.dummy_param = nn.Parameter(torch.zeros(1))

        bin_filename = self.config.config.filename
        # Prefer a local copy of the binary. Otherwise fetch it from the Hub
        # repository the config was loaded from. ``isfile`` (rather than
        # ``exists``) keeps a same-named directory from being passed to floret.
        if not os.path.isfile(bin_filename):
            logger.info(
                "%s not found locally, downloading from Hugging Face hub...",
                bin_filename,
            )
            # hf_hub_download returns the local path of the cached file.
            bin_filename = hf_hub_download(
                repo_id=self.config.config._name_or_path,
                filename=bin_filename,
            )
        self.model_floret = floret.load_model(bin_filename)

    def forward(self, input_ids, *, k=1, **kwargs):
        """Predict the language of one or more texts.

        Args:
            input_ids: A single string, or a list/tuple of strings. The name is
                kept for compatibility with the ``transformers`` API. These are
                raw texts, not token ids.
            k: Number of top labels floret returns per text (default 1).
            **kwargs: Accepted and ignored.

        Returns:
            A ``(predictions, probabilities)`` tuple exactly as returned by
            floret's ``predict`` for a list of texts.

        Raises:
            ValueError: If ``input_ids`` is neither a string nor a list/tuple
                of strings.
        """
        if isinstance(input_ids, str):
            # A single string is wrapped so floret always receives a list.
            texts = [input_ids]
        elif isinstance(input_ids, (list, tuple)) and all(
            isinstance(t, str) for t in input_ids
        ):
            texts = list(input_ids)
        else:
            raise ValueError(f"Unexpected input type: {type(input_ids)}")

        # floret's predict (like fastText's) raises ValueError on any text
        # containing "\n" because it processes one line at a time. Collapse
        # line breaks into spaces so multi-line documents can be classified.
        texts = [t.replace("\n", " ") for t in texts]

        predictions, probabilities = self.model_floret.predict(texts, k=k)
        return (
            predictions,
            probabilities,
        )

    @property
    def device(self):
        """Return the device of the first parameter (the placeholder)."""
        return next(self.parameters()).device

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Build the model from keyword arguments without loading weights.

        Unlike ``PreTrainedModel.from_pretrained``, no checkpoint is read.
        All ``kwargs`` are forwarded to ``ImpressoConfig``, and the floret
        binary is loaded by ``__init__``. Positional arguments (normally the
        model name or path) are ignored.
        """
        logger.debug("Ignoring weights and using custom initialization.")
        config = ImpressoConfig(**kwargs)
        return cls(config)