Tonic committed on
Commit
a036e74
·
unverified ·
1 Parent(s): 822db29

update imports, textclassifier

Browse files
Files changed (1) hide show
  1. tasks/text.py +16 -42
tasks/text.py CHANGED
@@ -8,7 +8,7 @@ from concurrent.futures import ThreadPoolExecutor
8
  from typing import List, Dict, Tuple
9
  import torch
10
  import torch.nn as nn
11
- from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig
12
  from huggingface_hub import login
13
  from dotenv import load_dotenv
14
 
@@ -42,34 +42,22 @@ class TextClassifier:
42
  # Initialize tokenizer first
43
  self.tokenizer = AutoTokenizer.from_pretrained(
44
  model_name,
45
- model_max_length=8192,
46
  padding_side='right',
47
  truncation_side='right'
48
  )
49
 
50
- # Load base config
51
- self.config = AutoConfig.from_pretrained(
52
- model_name,
53
- num_labels=8,
54
- problem_type="single_label_classification"
55
- )
56
-
57
- # Set required attributes
58
- self.config.hidden_size = 768
59
- self.config.num_attention_heads = 12
60
- self.config.num_hidden_layers = 12
61
- self.config.norm_eps = 1e-5
62
-
63
- # Initialize model with basic config
64
- self.model = AutoModelForSequenceClassification.from_pretrained(
65
- model_name,
66
- config=self.config,
67
- ignore_mismatched_sizes=True
68
  )
69
 
70
- # Move model to appropriate device
71
- self.model = self.model.to(self.device)
72
- self.model.eval()
73
  print("Model initialized successfully")
74
  break
75
 
@@ -84,22 +72,9 @@ class TextClassifier:
84
  try:
85
  print(f"Processing batch {batch_idx} with {len(batch)} items")
86
 
87
- # Tokenize with padding and truncation
88
- inputs = self.tokenizer(
89
- batch,
90
- return_tensors="pt",
91
- truncation=True,
92
- max_length=512,
93
- padding=True
94
- )
95
-
96
- # Move inputs to device
97
- inputs = {k: v.to(self.device) for k, v in inputs.items()}
98
-
99
- # Get predictions
100
- with torch.no_grad():
101
- outputs = self.model(**inputs)
102
- predictions = torch.argmax(outputs.logits, dim=-1).cpu().tolist()
103
 
104
  print(f"Completed batch {batch_idx} with {len(predictions)} predictions")
105
  return predictions, batch_idx
@@ -110,11 +85,10 @@ class TextClassifier:
110
 
111
  def __del__(self):
112
  # Clean up CUDA memory
113
- if hasattr(self, 'model'):
114
- del self.model
115
  if torch.cuda.is_available():
116
  torch.cuda.empty_cache()
117
-
118
 
119
  @router.post(ROUTE, tags=["Text Task"], description=DESCRIPTION)
120
  async def evaluate_text(request: TextEvaluationRequest):
 
8
  from typing import List, Dict, Tuple
9
  import torch
10
  import torch.nn as nn
11
+ from transformers import AutoTokenizer, pipeline
12
  from huggingface_hub import login
13
  from dotenv import load_dotenv
14
 
 
42
  # Initialize tokenizer first
43
  self.tokenizer = AutoTokenizer.from_pretrained(
44
  model_name,
45
+ model_max_length=512, # Reduced from 8192
46
  padding_side='right',
47
  truncation_side='right'
48
  )
49
 
50
+ # Use pipeline for simpler initialization
51
+ self.classifier = pipeline(
52
+ "text-classification",
53
+ model=model_name,
54
+ tokenizer=self.tokenizer,
55
+ device=self.device,
56
+ max_length=512,
57
+ truncation=True,
58
+ batch_size=32
 
 
 
 
 
 
 
 
 
59
  )
60
 
 
 
 
61
  print("Model initialized successfully")
62
  break
63
 
 
72
  try:
73
  print(f"Processing batch {batch_idx} with {len(batch)} items")
74
 
75
+ # Use pipeline for prediction
76
+ outputs = self.classifier(batch)
77
+ predictions = [int(output['label'].split('_')[0]) for output in outputs]
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
  print(f"Completed batch {batch_idx} with {len(predictions)} predictions")
80
  return predictions, batch_idx
 
85
 
86
  def __del__(self):
87
  # Clean up CUDA memory
88
+ if hasattr(self, 'classifier'):
89
+ del self.classifier
90
  if torch.cuda.is_available():
91
  torch.cuda.empty_cache()
 
92
 
93
  @router.post(ROUTE, tags=["Text Task"], description=DESCRIPTION)
94
  async def evaluate_text(request: TextEvaluationRequest):