Spaces:
Runtime error
Runtime error
Commit
·
0b5f433
1
Parent(s):
f20afb2
Upload folder using huggingface_hub
Browse files- main.py +2 -2
- src/myNLI.py +3 -3
main.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
#uvicorn main:app --reload
|
2 |
import os
|
3 |
os.environ['HF_HOME'] = 'src/cache'
|
4 |
-
os.environ['TRANSFORMERS_CACHE'] = 'src/cache'
|
5 |
|
6 |
from fastapi import FastAPI, status
|
7 |
from fastapi.responses import Response, JSONResponse
|
@@ -10,7 +10,7 @@ from pydantic import BaseModel
|
|
10 |
from typing import List
|
11 |
|
12 |
import os
|
13 |
-
import json
|
14 |
import time
|
15 |
|
16 |
from src.myNLI import FactChecker
|
|
|
1 |
#uvicorn main:app --reload
|
2 |
import os
|
3 |
os.environ['HF_HOME'] = 'src/cache'
|
4 |
+
# os.environ['TRANSFORMERS_CACHE'] = 'src/cache'
|
5 |
|
6 |
from fastapi import FastAPI, status
|
7 |
from fastapi.responses import Response, JSONResponse
|
|
|
10 |
from typing import List
|
11 |
|
12 |
import os
|
13 |
+
# import json
|
14 |
import time
|
15 |
|
16 |
from src.myNLI import FactChecker
|
src/myNLI.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
import torch
|
2 |
-
from transformers import AutoModel, AutoTokenizer
|
3 |
from sentence_transformers import SentenceTransformer, util
|
4 |
import nltk
|
5 |
|
@@ -26,8 +26,8 @@ class FactChecker:
|
|
26 |
self.envir = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
27 |
|
28 |
# Load LLM
|
29 |
-
self.tokenizer = AutoTokenizer.from_pretrained("MoritzLaurer/mDeBERTa-v3-base-mnli-xnli",
|
30 |
-
self.mDeBertaModel = AutoModel.from_pretrained(f"src/mDeBERTa (ft) V6/mDeBERTa-v3-base-mnli-xnli-{self.INPUT_TYPE}",
|
31 |
# Load classifier model
|
32 |
self.checkpoints = torch.load(f"src/mDeBERTa (ft) V6/{self.INPUT_TYPE}.pt", map_location=self.envir)
|
33 |
|
|
|
1 |
import torch
|
2 |
+
from transformers import AutoModel, AutoTokenizer
|
3 |
from sentence_transformers import SentenceTransformer, util
|
4 |
import nltk
|
5 |
|
|
|
26 |
self.envir = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
27 |
|
28 |
# Load LLM
|
29 |
+
self.tokenizer = AutoTokenizer.from_pretrained("MoritzLaurer/mDeBERTa-v3-base-mnli-xnli", token=False) # LOAD mDEBERTa TOKENIZER
|
30 |
+
self.mDeBertaModel = AutoModel.from_pretrained(f"src/mDeBERTa (ft) V6/mDeBERTa-v3-base-mnli-xnli-{self.INPUT_TYPE}", token=False) # LOAD FINETUNED MODEL
|
31 |
# Load classifier model
|
32 |
self.checkpoints = torch.load(f"src/mDeBERTa (ft) V6/{self.INPUT_TYPE}.pt", map_location=self.envir)
|
33 |
|