Spaces:
Running
Running
File size: 1,129 Bytes
d407fa8 da5c744 d407fa8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 |
import torch
import torch.nn as nn
from transformers import AutoModel
from config import (
HIDDEN_SIZE,
DROPOUT_PROB,
LAST_NUM_NEURON,
HF_REPO_NAME,
WEIGHTS_FILE_NAME,
PRETRAINED_MODEL,
)
from huggingface_hub import hf_hub_download
class EnergySmellsDetector(nn.Module):
    """Transformer encoder followed by a dropout + linear head and a sigmoid.

    The encoder is created with ``AutoModel.from_pretrained``. Its
    ``pooler_output`` goes through dropout and a linear layer mapping
    ``HIDDEN_SIZE`` features to ``LAST_NUM_NEURON`` outputs. Each output is
    then squashed independently with a sigmoid.
    """

    def __init__(self, model_name):
        """Build the encoder and the classification head.

        Args:
            model_name: Model identifier or local path accepted by
                ``AutoModel.from_pretrained``.
        """
        super().__init__()
        self.model = AutoModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(DROPOUT_PROB)
        # NOTE(review): HIDDEN_SIZE is assumed to equal the encoder's hidden
        # size (the width of pooler_output) -- confirm against config.
        self.fc = nn.Linear(HIDDEN_SIZE, LAST_NUM_NEURON)

    def forward(self, input_ids, attention_mask):
        """Return independent sigmoid scores for a batch of token ids.

        Args:
            input_ids: Token ids passed straight to the encoder
                (presumably shaped ``(batch, seq_len)`` -- confirm with caller).
            attention_mask: Attention mask passed straight to the encoder.

        Returns:
            A ``torch.float64`` tensor with values in ``[0, 1]``. Its last
            dimension has size ``LAST_NUM_NEURON``.
        """
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        # NOTE(review): some encoders do not provide a pooler_output; this
        # assumes PRETRAINED_MODEL does.
        x = self.dropout(outputs.pooler_output)
        logits = self.fc(x)
        # Equivalent to the former `.to(float)`: Python float maps to float64.
        return torch.sigmoid(logits).to(torch.float64)

    @classmethod
    def load_model_from_hf(cls):
        """Download the fine-tuned weights from the Hub and return a ready model.

        Steps:
            1. Fetch ``WEIGHTS_FILE_NAME`` from the ``HF_REPO_NAME`` repository.
            2. Build a model around the ``PRETRAINED_MODEL`` encoder.
            3. Load the downloaded state dict onto the CPU.
            4. Switch the model to evaluation mode.

        Returns:
            An instance of this class in eval mode.
        """
        model_path = hf_hub_download(repo_id=HF_REPO_NAME, filename=WEIGHTS_FILE_NAME)
        model = cls(PRETRAINED_MODEL)
        # weights_only=True restricts unpickling to tensors and plain
        # containers, so a tampered remote checkpoint cannot run arbitrary code.
        state_dict = torch.load(
            model_path, map_location=torch.device("cpu"), weights_only=True
        )
        model.load_state_dict(state_dict)
        # A new nn.Module starts in training mode. Without eval(), the head's
        # dropout stays active and inference results are non-deterministic.
        model.eval()
        return model
|