import logging
import os

import floret
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from transformers import PreTrainedModel

from .configuration_lang import ImpressoConfig

logger = logging.getLogger(__name__)


class LangDetectorModel(PreTrainedModel):
    """Language detector exposing a floret model through the ``PreTrainedModel`` API.

    The actual classifier is a floret binary loaded in ``__init__``; the only
    torch parameter is a dummy tensor that lets ``device`` report a location.
    """

    config_class = ImpressoConfig

    def __init__(self, config):
        """Load the floret binary named by the configuration.

        Args:
            config: Configuration object. The binary's file name is read from
                ``config.config.filename`` and, when that file is not present
                locally, the Hub repository id is read from
                ``config.config._name_or_path``.
        """
        super().__init__(config)
        self.config = config

        # Dummy parameter: gives `device` (below) something to inspect.
        self.dummy_param = nn.Parameter(torch.zeros(1))

        bin_filename = self.config.config.filename

        # Use the local file when it exists, otherwise fetch it from the Hub.
        if not os.path.exists(bin_filename):
            logger.info(
                "%s not found locally, downloading from Hugging Face hub...",
                bin_filename,
            )
            bin_filename = hf_hub_download(
                repo_id=self.config.config._name_or_path,
                filename=bin_filename,
            )

        # `bin_filename` is now either the local path or the downloaded path.
        self.model_floret = floret.load_model(bin_filename)

    def forward(self, input_ids, **kwargs):
        """Predict the top-1 language label for one or more texts.

        Args:
            input_ids: Despite the name, raw text: either a single ``str`` or a
                list/tuple of ``str``.
            **kwargs: Accepted for call compatibility and ignored.

        Returns:
            The ``(predictions, probabilities)`` pair returned by the floret
            model's ``predict`` called with ``k=1`` on a list of texts (a
            single string is wrapped in a one-element list first).

        Raises:
            ValueError: If ``input_ids`` is neither a string nor a list/tuple
                containing only strings.
        """
        if isinstance(input_ids, str):
            # A single string becomes a one-element batch.
            texts = [input_ids]
        elif isinstance(input_ids, (list, tuple)) and all(
            isinstance(text, str) for text in input_ids
        ):
            texts = input_ids
        else:
            raise ValueError(f"Unexpected input type: {type(input_ids)}")

        # floret (like fastText) predicts one line per entry, so embedded
        # newlines are flattened to spaces to classify multi-line documents as
        # a whole. The comprehension also yields a plain `list`, which is what
        # the batch prediction path expects even when a tuple was passed in.
        texts = [text.replace("\n", " ") for text in texts]

        predictions, probabilities = self.model_floret.predict(texts, k=1)
        return predictions, probabilities

    @property
    def device(self) -> torch.device:
        """Return the device of the model's first registered parameter."""
        return next(self.parameters()).device

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Build the model from keyword arguments without loading any weights.

        Positional arguments are accepted but ignored. All keyword arguments
        are forwarded to ``ImpressoConfig``, and the resulting configuration is
        passed to the constructor.

        NOTE(review): ``__init__`` reads ``config.config.filename`` and
        ``config.config._name_or_path``, so callers presumably pass a
        ``config=`` keyword here -- confirm against ``ImpressoConfig``.
        """
        logger.debug("Ignoring weights and using custom initialization.")
        config = ImpressoConfig(**kwargs)
        return cls(config)