Tonic commited on
Commit
21262c6
·
unverified ·
1 Parent(s): ada5a12

fix dataset loading

Browse files
Files changed (1) hide show
  1. tasks/text.py +12 -24
tasks/text.py CHANGED
@@ -8,10 +8,20 @@ from concurrent.futures import ThreadPoolExecutor
8
  from typing import List, Dict, Tuple
9
  import torch
10
  from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
 
 
11
 
12
  from .utils.evaluation import TextEvaluationRequest
13
  from .utils.emissions import tracker, clean_emissions_data, get_space_info, start_tracking, stop_tracking
14
 
 
 
 
 
 
 
 
 
15
  # Disable torch compile
16
  os.environ["TORCH_COMPILE_DISABLE"] = "1"
17
 
@@ -57,28 +67,6 @@ class TextClassifier:
57
  print(f"Attempt {attempt + 1} failed, retrying... Error: {str(e)}")
58
  time.sleep(1)
59
 
60
- def predict_single(self, text: str) -> int:
61
- """Predict single text instance"""
62
- try:
63
- # Tokenize with explicit padding and truncation
64
- inputs = self.tokenizer(
65
- text,
66
- return_tensors="pt",
67
- truncation=True,
68
- max_length=512,
69
- padding='max_length'
70
- ).to(self.device)
71
-
72
- # Get prediction
73
- with torch.no_grad():
74
- outputs = self.model(**inputs)
75
- predictions = torch.argmax(outputs.logits, dim=-1)
76
- return predictions.item()
77
-
78
- except Exception as e:
79
- print(f"Error in single prediction: {str(e)}")
80
- return 0 # Return default prediction on error
81
-
82
  def process_batch(self, batch: List[str], batch_idx: int) -> Tuple[List[int], int]:
83
  """Process a batch of texts and return their predictions"""
84
  try:
@@ -124,8 +112,8 @@ async def evaluate_text(request: TextEvaluationRequest):
124
  }
125
 
126
  try:
127
- # Load and prepare the dataset using the dataset name from the request
128
- dataset = load_dataset(request.dataset_name)
129
 
130
  # Convert string labels to integers
131
  dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})
 
8
  from typing import List, Dict, Tuple
9
  import torch
10
  from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
11
+ from huggingface_hub import login
12
+ from dotenv import load_dotenv
13
 
14
  from .utils.evaluation import TextEvaluationRequest
15
  from .utils.emissions import tracker, clean_emissions_data, get_space_info, start_tracking, stop_tracking
16
 
17
+ # Load environment variables
18
+ load_dotenv()
19
+
20
+ # Authenticate with Hugging Face
21
+ HF_TOKEN = os.getenv('HF_TOKEN')
22
+ if HF_TOKEN:
23
+ login(token=HF_TOKEN)
24
+
25
  # Disable torch compile
26
  os.environ["TORCH_COMPILE_DISABLE"] = "1"
27
 
 
67
  print(f"Attempt {attempt + 1} failed, retrying... Error: {str(e)}")
68
  time.sleep(1)
69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  def process_batch(self, batch: List[str], batch_idx: int) -> Tuple[List[int], int]:
71
  """Process a batch of texts and return their predictions"""
72
  try:
 
112
  }
113
 
114
  try:
115
+ # Load and prepare the dataset using the correct dataset name
116
+ dataset = load_dataset("QuotaClimat/frugalaichallenge-text-train", use_auth_token=HF_TOKEN)
117
 
118
  # Convert string labels to integers
119
  dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})