update submit
src/submission/submit.py CHANGED (+8 -1)
@@ -72,6 +72,8 @@ def evaluate_model_accuracy(model_name, num_examples):
         )
         if torch.cuda.is_available():
             model = model.cuda()  # Move model to GPU if available
+        else:
+            model = model.cpu()
 
         # Load your dataset
         dataset = load_dataset("Omartificial-Intelligence-Space/Arabic_Openai_MMMLU")
@@ -125,11 +127,15 @@ Answer:"""
 
                 # Get the top prediction
                 top_prediction = get_top_prediction(text, tokenizer, model)
+                if top_prediction is None:
+                    print(f"Skipping question due to tokenization issues: {data['Question']}")
+                    continue  # Skip this question if no valid options are found
+
                 is_correct = (top_prediction == data['Answer'])
                 correct_predictions += int(is_correct)
                 total_questions += 1
                 overall_correct_predictions += int(is_correct)
-                overall_total_questions +=1
+                overall_total_questions += 1
 
                 detailed_results.append({
                     'Subject': subject,
@@ -163,6 +169,7 @@ Answer:"""
     except Exception as e:
         return f"Error: {str(e)}", pd.DataFrame(), pd.DataFrame()
 
+
 def add_new_eval(
     model: str,
     base_model: str,
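For readers outside the repo, a standalone sketch of the device-placement pattern the first hunk completes. The wrapper function name and the use of AutoModelForCausalLM are assumptions for illustration; the commit itself only adds the else/model.cpu() branch.

# Sketch (assumptions labeled): device-placement fallback as added in the first hunk.
# load_model_on_best_device and AutoModelForCausalLM are illustrative, not part of
# this commit.
import torch
from transformers import AutoModelForCausalLM

def load_model_on_best_device(model_name: str):
    model = AutoModelForCausalLM.from_pretrained(model_name)
    if torch.cuda.is_available():
        model = model.cuda()  # Move model to GPU if available
    else:
        model = model.cpu()   # Explicit CPU placement when no GPU is present
    return model

The explicit .cpu() branch does not change behaviour for a freshly loaded model, which already lives on the CPU; it only makes the intended device explicit when CUDA is unavailable.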
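The second hunk guards the scoring loop against get_top_prediction returning None. Below is a minimal, self-contained sketch of that guard; the stubbed get_top_prediction and the toy examples are assumptions, since the real helper in submit.py is not shown in this commit.

# Sketch of the None-guard added in the second hunk. get_top_prediction is stubbed
# so the loop runs on its own; the real helper is assumed to return None when no
# answer option survives tokenization.
def get_top_prediction(text):
    # Stub: pretend questions containing "???" cannot be scored.
    return None if "???" in text else "A"

examples = [
    {"Question": "2 + 2 = 4, true or false?", "Answer": "A"},
    {"Question": "???", "Answer": "B"},
]

correct_predictions = 0
total_questions = 0
for data in examples:
    top_prediction = get_top_prediction(data["Question"])
    if top_prediction is None:
        print(f"Skipping question due to tokenization issues: {data['Question']}")
        continue  # Skip this question if no valid options are found

    correct_predictions += int(top_prediction == data["Answer"])
    total_questions += 1

print(f"Scored {correct_predictions}/{total_questions} questions")

Because the continue fires before any counters are incremented, skipped questions are excluded from both the per-subject and the overall totals, so accuracy is computed only over questions that could actually be scored.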