Tonic commited on
Commit
b10572e
·
unverified ·
1 Parent(s): b4aa97d

use modified config

Browse files
Files changed (1) hide show
  1. tasks/text.py +22 -5
tasks/text.py CHANGED
@@ -53,18 +53,36 @@ async def evaluate_text(request: TextEvaluationRequest):
53
  # MODEL INFERENCE CODE
54
  #--------------------------------------------------------------------------------------------
55
 
56
-
57
  try:
58
  # Set device
59
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
60
 
61
  # Model and tokenizer paths
62
- model_name = "Tonic/climate-guard-toxic-agent"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
- # Load tokenizer and model
65
  tokenizer = AutoTokenizer.from_pretrained(model_name)
 
 
66
  model = AutoModelForSequenceClassification.from_pretrained(
67
- model_name,
 
68
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
69
  trust_remote_code=True
70
  ).to(device)
@@ -123,7 +141,6 @@ async def evaluate_text(request: TextEvaluationRequest):
123
  print(f"Error during model inference: {str(e)}")
124
  raise
125
 
126
-
127
  #--------------------------------------------------------------------------------------------
128
  # MODEL INFERENCE ENDS HERE
129
  #--------------------------------------------------------------------------------------------
 
53
  # MODEL INFERENCE CODE
54
  #--------------------------------------------------------------------------------------------
55
 
 
56
  try:
57
  # Set device
58
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
59
 
60
  # Model and tokenizer paths
61
+ model_name = "answerdotai/ModernBERT-base"
62
+
63
+ # Load config and modify it
64
+ config = AutoConfig.from_pretrained(model_name)
65
+ config.num_labels = 8
66
+ config.id2label = {
67
+ "0": "0_not_relevant",
68
+ "1": "1_not_happening",
69
+ "2": "2_not_human",
70
+ "3": "3_not_bad",
71
+ "4": "4_solutions_harmful_unnecessary",
72
+ "5": "5_science_is_unreliable",
73
+ "6": "6_proponents_biased",
74
+ "7": "7_fossil_fuels_needed"
75
+ }
76
+ config.label2id = {v: int(k) for k, v in config.id2label.items()}
77
+ config.problem_type = "single_label_classification"
78
 
79
+ # Load tokenizer
80
  tokenizer = AutoTokenizer.from_pretrained(model_name)
81
+
82
+ # Load model with modified config
83
  model = AutoModelForSequenceClassification.from_pretrained(
84
+ "Tonic/climate-guard-toxic-agent",
85
+ config=config,
86
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
87
  trust_remote_code=True
88
  ).to(device)
 
141
  print(f"Error during model inference: {str(e)}")
142
  raise
143
 
 
144
  #--------------------------------------------------------------------------------------------
145
  # MODEL INFERENCE ENDS HERE
146
  #--------------------------------------------------------------------------------------------