Update tasks/text.py
Browse files- tasks/text.py +7 -2
tasks/text.py
CHANGED
|
@@ -65,9 +65,14 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
| 65 |
|
| 66 |
# Load model and tokenizer from Hugging Face Hub
|
| 67 |
MODEL_REPO = "ClimateDebunk/FineTunedDistilBert4SeqClass"
|
|
|
|
| 68 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
|
| 69 |
-
|
| 70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
|
| 72 |
def preprocess(texts):
|
| 73 |
""" Tokenize text inputs for DistilBERT """
|
|
|
|
| 65 |
|
| 66 |
# Load model and tokenizer from Hugging Face Hub
|
| 67 |
MODEL_REPO = "ClimateDebunk/FineTunedDistilBert4SeqClass"
|
| 68 |
+
MODEL_PATH = "./distilbert_trained.pth"
|
| 69 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
|
| 70 |
+
|
| 71 |
+
config = DistilBertConfig.from_pretrained("distilbert-base-uncased", num_labels=8)
|
| 72 |
+
model = DistilBertForSequenceClassification(config)
|
| 73 |
+
model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device("cpu")))
|
| 74 |
+
#model = AutoModelForSequenceClassification.from_pretrained(MODEL_REPO)
|
| 75 |
+
model.eval()
|
| 76 |
|
| 77 |
def preprocess(texts):
|
| 78 |
""" Tokenize text inputs for DistilBERT """
|