# Source captured from a Hugging Face Space page (Space status: Running).
import torch | |
import torch.nn as nn | |
from transformers import AutoModel | |
from config import ( | |
HIDDEN_SIZE, | |
DROPOUT_PROB, | |
LAST_NUM_NEURON, | |
HF_REPO_NAME, | |
WEIGHTS_FILE_NAME, | |
PRETRAINED_MODEL, | |
) | |
from huggingface_hub import hf_hub_download | |
class EnergySmellsDetector(nn.Module):
    """Pretrained transformer encoder topped with a dropout + linear scoring head.

    The encoder's pooled output is regularised with dropout, projected from
    ``HIDDEN_SIZE`` to ``LAST_NUM_NEURON`` features, and squashed element-wise
    with a sigmoid, so each output neuron yields an independent score in (0, 1).
    """

    def __init__(self, model_name):
        """Create the backbone encoder and the scoring head.

        Args:
            model_name: Model identifier or path handed to
                ``AutoModel.from_pretrained``.
        """
        super().__init__()
        # Backbone encoder whose pooled output feeds the head below.
        self.model = AutoModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(DROPOUT_PROB)
        # NOTE(review): HIDDEN_SIZE is assumed to equal the encoder's pooled
        # output width — confirm against config / PRETRAINED_MODEL.
        self.fc = nn.Linear(HIDDEN_SIZE, LAST_NUM_NEURON)

    def forward(self, input_ids, attention_mask):
        """Score a batch of tokenised inputs.

        Args:
            input_ids: Token ids forwarded to the encoder.
            attention_mask: Attention mask forwarded to the encoder together
                with ``input_ids``.

        Returns:
            A ``torch.float64`` tensor of sigmoid scores, one per output neuron
            of the linear head.
        """
        encoded = self.model(input_ids=input_ids, attention_mask=attention_mask)
        # Relies on the encoder exposing a pooled output; for backbones without
        # a pooling layer this attribute may be None — TODO confirm for the
        # configured PRETRAINED_MODEL.
        pooled = encoded.pooler_output
        scores = torch.sigmoid(self.fc(self.dropout(pooled)))
        # Python ``float`` as a torch dtype means float64; spelled out here.
        return scores.to(torch.float64)
def load_model_from_hf():
    """Download the fine-tuned weights from the Hugging Face Hub and build the model.

    Fetches ``WEIGHTS_FILE_NAME`` from the ``HF_REPO_NAME`` repository,
    instantiates ``EnergySmellsDetector`` on top of ``PRETRAINED_MODEL`` and
    loads the downloaded state dict into it.

    Returns:
        The ``EnergySmellsDetector`` carrying the downloaded weights. No
        ``.eval()`` call is made here, so the wrapper (including its own dropout
        layer) is returned in PyTorch's default training mode; callers running
        inference should call ``model.eval()``.
    """
    model_path = hf_hub_download(repo_id=HF_REPO_NAME, filename=WEIGHTS_FILE_NAME)

    model = EnergySmellsDetector(PRETRAINED_MODEL)
    # The checkpoint comes from a remote repository, and torch.load is
    # pickle-based. weights_only=True restricts unpickling to tensors and plain
    # containers, which is all a state dict needs, so a tampered file cannot run
    # arbitrary code. map_location puts the tensors on CPU, so loading works on
    # hosts without a GPU.
    state_dict = torch.load(
        model_path,
        map_location=torch.device("cpu"),
        weights_only=True,
    )
    model.load_state_dict(state_dict)
    return model