Tonic commited on
Commit
08e3356
·
unverified ·
1 Parent(s): 68ff849

revert to direct model loading

Browse files
Files changed (2) hide show
  1. requirements.txt +2 -1
  2. tasks/text.py +20 -12
requirements.txt CHANGED
@@ -9,4 +9,5 @@ python-dotenv==1.0.0
9
  requests==2.31.0
10
  numpy==1.24.3
11
  pydantic==2.4.2
12
- accelerate
 
 
9
  requests==2.31.0
10
  numpy==1.24.3
11
  pydantic==2.4.2
12
+ accelerate
13
+ huggingface-hub
tasks/text.py CHANGED
@@ -7,7 +7,7 @@ import os
7
  from concurrent.futures import ThreadPoolExecutor
8
  from typing import List, Dict, Tuple
9
  import torch
10
- from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
11
 
12
  from .utils.evaluation import TextEvaluationRequest
13
  from .utils.emissions import tracker, clean_emissions_data, get_space_info, start_tracking, stop_tracking
@@ -26,12 +26,11 @@ class TextClassifier:
26
  max_retries = 3
27
  for attempt in range(max_retries):
28
  try:
29
- # Initialize using pipeline instead
30
- self.classifier = pipeline(
31
- "text-classification",
32
- model="Tonic/climate-guard-toxic-agent",
33
- device=self.device
34
- )
35
  print("Model initialized successfully")
36
  break
37
  except Exception as e:
@@ -43,11 +42,20 @@ class TextClassifier:
43
  def predict_single(self, text: str) -> int:
44
  """Predict single text instance"""
45
  try:
46
- result = self.classifier(text)
47
- # Extract the label index from the result
48
- # Assuming the model outputs label indices directly
49
- label = int(result[0]['label'].split('_')[0])
50
- return label
 
 
 
 
 
 
 
 
 
51
  except Exception as e:
52
  print(f"Error in single prediction: {str(e)}")
53
  return 0 # Return default prediction on error
 
7
  from concurrent.futures import ThreadPoolExecutor
8
  from typing import List, Dict, Tuple
9
  import torch
10
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
11
 
12
  from .utils.evaluation import TextEvaluationRequest
13
  from .utils.emissions import tracker, clean_emissions_data, get_space_info, start_tracking, stop_tracking
 
26
  max_retries = 3
27
  for attempt in range(max_retries):
28
  try:
29
+ # Initialize tokenizer and model separately
30
+ self.tokenizer = AutoTokenizer.from_pretrained("Tonic/climate-guard-toxic-agent")
31
+ self.model = AutoModelForSequenceClassification.from_pretrained("Tonic/climate-guard-toxic-agent")
32
+ self.model.to(self.device)
33
+ self.model.eval()
 
34
  print("Model initialized successfully")
35
  break
36
  except Exception as e:
 
42
  def predict_single(self, text: str) -> int:
43
  """Predict single text instance"""
44
  try:
45
+ # Tokenize and prepare input
46
+ inputs = self.tokenizer(
47
+ text,
48
+ return_tensors="pt",
49
+ truncation=True,
50
+ max_length=512,
51
+ padding=True
52
+ ).to(self.device)
53
+
54
+ # Get prediction
55
+ with torch.no_grad():
56
+ outputs = self.model(**inputs)
57
+ predictions = outputs.logits.argmax(-1)
58
+ return predictions.item()
59
  except Exception as e:
60
  print(f"Error in single prediction: {str(e)}")
61
  return 0 # Return default prediction on error