Tonic committed on
Commit
bc4f464
·
unverified ·
1 Parent(s): 21262c6

use pipeline

Browse files
Files changed (1) hide show
  1. tasks/text.py +14 -36
tasks/text.py CHANGED
@@ -7,7 +7,7 @@ import os
7
  from concurrent.futures import ThreadPoolExecutor
8
  from typing import List, Dict, Tuple
9
  import torch
10
- from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
11
  from huggingface_hub import login
12
  from dotenv import load_dotenv
13
 
@@ -18,7 +18,7 @@ from .utils.emissions import tracker, clean_emissions_data, get_space_info, star
18
  load_dotenv()
19
 
20
  # Authenticate with Hugging Face
21
- HF_TOKEN = os.getenv('HF_TOKEN')
22
  if HF_TOKEN:
23
  login(token=HF_TOKEN)
24
 
@@ -38,26 +38,13 @@ class TextClassifier:
38
 
39
  for attempt in range(max_retries):
40
  try:
41
- # Load config first
42
- config = AutoConfig.from_pretrained(model_name)
43
-
44
- # Initialize tokenizer with specific model type
45
- self.tokenizer = AutoTokenizer.from_pretrained(
46
- model_name,
47
- model_max_length=512,
48
- padding_side='right',
49
- truncation_side='right'
50
- )
51
-
52
- # Initialize model with config
53
- self.model = AutoModelForSequenceClassification.from_pretrained(
54
- model_name,
55
- config=config,
56
- torch_dtype=torch.float32
57
  )
58
-
59
- self.model.to(self.device)
60
- self.model.eval()
61
  print("Model initialized successfully")
62
  break
63
 
@@ -72,18 +59,9 @@ class TextClassifier:
72
  try:
73
  print(f"Processing batch {batch_idx} with {len(batch)} items")
74
 
75
- # Process entire batch at once
76
- inputs = self.tokenizer(
77
- batch,
78
- return_tensors="pt",
79
- truncation=True,
80
- max_length=512,
81
- padding='max_length'
82
- ).to(self.device)
83
-
84
- with torch.no_grad():
85
- outputs = self.model(**inputs)
86
- predictions = torch.argmax(outputs.logits, dim=-1).tolist()
87
 
88
  print(f"Completed batch {batch_idx} with {len(predictions)} predictions")
89
  return predictions, batch_idx
@@ -112,13 +90,13 @@ async def evaluate_text(request: TextEvaluationRequest):
112
  }
113
 
114
  try:
115
- # Load and prepare the dataset using the correct dataset name
116
- dataset = load_dataset("QuotaClimat/frugalaichallenge-text-train", use_auth_token=HF_TOKEN)
117
 
118
  # Convert string labels to integers
119
  dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})
120
 
121
- # Split dataset according to request parameters
122
  test_dataset = dataset["test"]
123
 
124
  # Start tracking emissions
 
7
  from concurrent.futures import ThreadPoolExecutor
8
  from typing import List, Dict, Tuple
9
  import torch
10
+ from transformers import pipeline
11
  from huggingface_hub import login
12
  from dotenv import load_dotenv
13
 
 
18
  load_dotenv()
19
 
20
  # Authenticate with Hugging Face
21
+ HF_TOKEN = os.getenv('HUGGINGFACE_TOKEN')
22
  if HF_TOKEN:
23
  login(token=HF_TOKEN)
24
 
 
38
 
39
  for attempt in range(max_retries):
40
  try:
41
+ # Initialize pipeline
42
+ self.classifier = pipeline(
43
+ "text-classification",
44
+ model=model_name,
45
+ device=self.device,
46
+ batch_size=32
 
 
 
 
 
 
 
 
 
 
47
  )
 
 
 
48
  print("Model initialized successfully")
49
  break
50
 
 
59
  try:
60
  print(f"Processing batch {batch_idx} with {len(batch)} items")
61
 
62
+ # Use pipeline for prediction
63
+ results = self.classifier(batch)
64
+ predictions = [int(result['label'].split('_')[0]) for result in results]
 
 
 
 
 
 
 
 
 
65
 
66
  print(f"Completed batch {batch_idx} with {len(predictions)} predictions")
67
  return predictions, batch_idx
 
90
  }
91
 
92
  try:
93
+ # Load and prepare the dataset
94
+ dataset = load_dataset("QuotaClimat/frugalaichallenge-text-train", token=HF_TOKEN)
95
 
96
  # Convert string labels to integers
97
  dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})
98
 
99
+ # Split dataset
100
  test_dataset = dataset["test"]
101
 
102
  # Start tracking emissions