Spaces:
Sleeping
Sleeping
use modified config
Browse files- tasks/text.py +22 -5
tasks/text.py
CHANGED
@@ -53,18 +53,36 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
53 |
# MODEL INFERENCE CODE
|
54 |
#--------------------------------------------------------------------------------------------
|
55 |
|
56 |
-
|
57 |
try:
|
58 |
# Set device
|
59 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
60 |
|
61 |
# Model and tokenizer paths
|
62 |
-
model_name = "
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
|
64 |
-
# Load tokenizer
|
65 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
|
|
|
|
66 |
model = AutoModelForSequenceClassification.from_pretrained(
|
67 |
-
|
|
|
68 |
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
|
69 |
trust_remote_code=True
|
70 |
).to(device)
|
@@ -123,7 +141,6 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
123 |
print(f"Error during model inference: {str(e)}")
|
124 |
raise
|
125 |
|
126 |
-
|
127 |
#--------------------------------------------------------------------------------------------
|
128 |
# MODEL INFERENCE ENDS HERE
|
129 |
#--------------------------------------------------------------------------------------------
|
|
|
53 |
# MODEL INFERENCE CODE
|
54 |
#--------------------------------------------------------------------------------------------
|
55 |
|
|
|
56 |
try:
|
57 |
# Set device
|
58 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
59 |
|
60 |
# Model and tokenizer paths
|
61 |
+
model_name = "answerdotai/ModernBERT-base"
|
62 |
+
|
63 |
+
# Load config and modify it
|
64 |
+
config = AutoConfig.from_pretrained(model_name)
|
65 |
+
config.num_labels = 8
|
66 |
+
config.id2label = {
|
67 |
+
"0": "0_not_relevant",
|
68 |
+
"1": "1_not_happening",
|
69 |
+
"2": "2_not_human",
|
70 |
+
"3": "3_not_bad",
|
71 |
+
"4": "4_solutions_harmful_unnecessary",
|
72 |
+
"5": "5_science_is_unreliable",
|
73 |
+
"6": "6_proponents_biased",
|
74 |
+
"7": "7_fossil_fuels_needed"
|
75 |
+
}
|
76 |
+
config.label2id = {v: int(k) for k, v in config.id2label.items()}
|
77 |
+
config.problem_type = "single_label_classification"
|
78 |
|
79 |
+
# Load tokenizer
|
80 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
81 |
+
|
82 |
+
# Load model with modified config
|
83 |
model = AutoModelForSequenceClassification.from_pretrained(
|
84 |
+
"Tonic/climate-guard-toxic-agent",
|
85 |
+
config=config,
|
86 |
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
|
87 |
trust_remote_code=True
|
88 |
).to(device)
|
|
|
141 |
print(f"Error during model inference: {str(e)}")
|
142 |
raise
|
143 |
|
|
|
144 |
#--------------------------------------------------------------------------------------------
|
145 |
# MODEL INFERENCE ENDS HERE
|
146 |
#--------------------------------------------------------------------------------------------
|