Tonic committed
Commit aaabc84 · unverified · 1 Parent(s): 5a0374b

ModernBertConfig class that inherits from PretrainedConfig

Files changed (1):
tasks/text.py (+34 -17)
tasks/text.py CHANGED
@@ -15,6 +15,32 @@ router = APIRouter()
 DESCRIPTION = "Climate Guard Toxic Agent is a ModernBERT for Climate Disinformation Detection"
 ROUTE = "/text"
 
+class ModernBertConfig(PretrainedConfig):
+    model_type = "modernbert"
+
+    def __init__(
+        self,
+        vocab_size=50368,
+        hidden_size=768,
+        num_hidden_layers=22,
+        num_attention_heads=12,
+        intermediate_size=1152,
+        max_position_embeddings=8192,
+        layer_norm_eps=1e-5,
+        classifier_dropout=0.0,
+        **kwargs
+    ):
+        super().__init__(**kwargs)
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.intermediate_size = intermediate_size
+        self.max_position_embeddings = max_position_embeddings
+        self.layer_norm_eps = layer_norm_eps
+        self.classifier_dropout = classifier_dropout
+
+
 @router.post(ROUTE, tags=["Text Task"],
              description=DESCRIPTION)
 async def evaluate_text(request: TextEvaluationRequest):
@@ -52,35 +78,25 @@ async def evaluate_text(request: TextEvaluationRequest):
     #--------------------------------------------------------------------------------------------
     # MODEL INFERENCE CODE
     #--------------------------------------------------------------------------------------------
-
+
     try:
         # Set device
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
         # Model and tokenizer paths
         model_name = "Tonic/climate-guard-toxic-agent"
-        tokenizer_name = "Tonic/climate-guard-toxic-agent"
 
-        # Create config with essential parameters
-        config = AutoConfig.from_pretrained(
-            model_name,
+        # Create custom config
+        config = ModernBertConfig(
             num_labels=8,
-            problem_type="single_label_classification",
-            architectures=["ModernBertForSequenceClassification"],
-            model_type="modernbert",
-            hidden_size=768,
-            num_attention_heads=12,
-            num_hidden_layers=22,
-            intermediate_size=1152,
-            max_position_embeddings=8192,
-            layer_norm_eps=1e-05,
-            classifier_dropout=0.0
+            id2label={str(i): label for i, label in enumerate(LABEL_MAPPING.keys())},
+            label2id=LABEL_MAPPING
         )
 
         # Load tokenizer
-        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
 
-        # Load model with modified config
+        # Load model with custom config
         model = AutoModelForSequenceClassification.from_pretrained(
             model_name,
             config=config,
@@ -138,6 +154,7 @@ async def evaluate_text(request: TextEvaluationRequest):
         print(f"Error during model inference: {str(e)}")
         raise
 
+
     #--------------------------------------------------------------------------------------------
     # MODEL INFERENCE ENDS HERE
     #--------------------------------------------------------------------------------------------
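
For reference, a minimal sketch of the id2label / label2id wiring that the new ModernBertConfig(...) call relies on. LABEL_MAPPING is defined elsewhere in tasks/text.py and is not shown in this diff; the entries below are hypothetical placeholders, not the project's actual label set.

# Sketch only: LABEL_MAPPING's real contents live elsewhere in tasks/text.py.
# These two entries are placeholders for illustration.
LABEL_MAPPING = {"0_not_relevant": 0, "1_not_happening": 1}

# Mirrors the diff: id2label maps stringified indices to label names,
# label2id passes the original mapping through unchanged.
# Note the assumption that LABEL_MAPPING's keys are ordered by their ids.
id2label = {str(i): label for i, label in enumerate(LABEL_MAPPING.keys())}
label2id = LABEL_MAPPING

print(id2label)   # {'0': '0_not_relevant', '1': '1_not_happening'}
print(label2id)   # {'0_not_relevant': 0, '1_not_happening': 1}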