Omartificial-Intelligence-Space commited on
Commit
91e6eee
·
verified ·
1 Parent(s): 560f753

update submit

Browse files
Files changed (1) hide show
  1. src/submission/submit.py +8 -1
src/submission/submit.py CHANGED
@@ -72,6 +72,8 @@ def evaluate_model_accuracy(model_name, num_examples):
72
  )
73
  if torch.cuda.is_available():
74
  model = model.cuda() # Move model to GPU if available
 
 
75
 
76
  # Load your dataset
77
  dataset = load_dataset("Omartificial-Intelligence-Space/Arabic_Openai_MMMLU")
@@ -125,11 +127,15 @@ Answer:"""
125
 
126
  # Get the top prediction
127
  top_prediction = get_top_prediction(text, tokenizer, model)
 
 
 
 
128
  is_correct = (top_prediction == data['Answer'])
129
  correct_predictions += int(is_correct)
130
  total_questions += 1
131
  overall_correct_predictions += int(is_correct)
132
- overall_total_questions +=1
133
 
134
  detailed_results.append({
135
  'Subject': subject,
@@ -163,6 +169,7 @@ Answer:"""
163
  except Exception as e:
164
  return f"Error: {str(e)}", pd.DataFrame(), pd.DataFrame()
165
 
 
166
  def add_new_eval(
167
  model: str,
168
  base_model: str,
 
72
  )
73
  if torch.cuda.is_available():
74
  model = model.cuda() # Move model to GPU if available
75
+ else:
76
+ model = model.cpu()
77
 
78
  # Load your dataset
79
  dataset = load_dataset("Omartificial-Intelligence-Space/Arabic_Openai_MMMLU")
 
127
 
128
  # Get the top prediction
129
  top_prediction = get_top_prediction(text, tokenizer, model)
130
+ if top_prediction is None:
131
+ print(f"Skipping question due to tokenization issues: {data['Question']}")
132
+ continue # Skip this question if no valid options are found
133
+
134
  is_correct = (top_prediction == data['Answer'])
135
  correct_predictions += int(is_correct)
136
  total_questions += 1
137
  overall_correct_predictions += int(is_correct)
138
+ overall_total_questions += 1
139
 
140
  detailed_results.append({
141
  'Subject': subject,
 
169
  except Exception as e:
170
  return f"Error: {str(e)}", pd.DataFrame(), pd.DataFrame()
171
 
172
+
173
  def add_new_eval(
174
  model: str,
175
  base_model: str,