Commit
·
d564ed1
1
Parent(s):
cf9ff03
save via commit
Browse files
utils.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
import base64
|
2 |
from huggingface_hub import hf_hub_download
|
3 |
import fasttext
|
@@ -13,9 +14,8 @@ from sklearn.metrics import (
|
|
13 |
matthews_corrcoef
|
14 |
)
|
15 |
import numpy as np
|
|
|
16 |
from constants import *
|
17 |
-
from huggingface_hub import HfApi, login
|
18 |
-
from pathlib import Path
|
19 |
|
20 |
def predict_label(text, model, language_mapping_dict, use_mapping=False):
|
21 |
"""
|
@@ -183,7 +183,51 @@ def run_eval_one_vs_all(data_test, TARGET_LANG='Morocco'):
|
|
183 |
|
184 |
return out
|
185 |
|
186 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
187 |
def handle_evaluation(model_path, model_path_bin, use_mapping=False):
|
188 |
|
189 |
# download model and get the model path
|
@@ -300,6 +344,60 @@ def process_results_file(file, uploaded_model_name, base_path_save="./atlasia/su
|
|
300 |
|
301 |
return create_leaderboard_display_multilingual(df_multilingual, target_label, default_metrics), status_message
|
302 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
303 |
|
304 |
def load_leaderboard_one_vs_all(DIALECT_CONFUSION_LEADERBOARD_FILE):
|
305 |
current_dir = os.path.dirname(os.path.abspath(__file__))
|
@@ -457,122 +555,23 @@ def render_fixed_columns(df):
|
|
457 |
""" A function to render HTML table with fixed 'model' column for better visibility """
|
458 |
return NotImplementedError
|
459 |
|
460 |
-
def update_repo_file(api, repo_id, filename, data):
    """Serialize ``data`` to JSON in the app directory and upload it to a repo.

    The JSON file is written to ``/home/user/app/<filename>`` and then pushed
    to ``repo_id`` (created first if it does not exist) under the same name.

    Args:
        api: Client object exposing ``create_repo`` and ``upload_file``.
        repo_id: Identifier of the repository that receives the file.
        filename: Name of the file, both locally and inside the repository.
        data: JSON-serializable payload to store.

    Raises:
        Exception: Any error raised by the repository calls is printed and
            then re-raised unchanged.
    """
    # Stage the payload on disk; upload_file is handed a path, not a buffer.
    staged_path = Path("/home/user/app", filename)
    with staged_path.open("w") as handle:
        json.dump(data, handle, indent=4)

    upload_kwargs = {
        "path_or_fileobj": str(staged_path),
        "path_in_repo": filename,
        "repo_id": repo_id,
        "repo_type": "model",
        "commit_message": f"Update {filename}",
    }

    try:
        # exist_ok=True makes repository creation a no-op when it already exists.
        api.create_repo(repo_id, exist_ok=True)
        api.upload_file(**upload_kwargs)
    except Exception as e:
        print(f"Error during repository operation: {str(e)}")
        raise
|
484 |
-
|
485 |
-
def update_darija_one_vs_all_leaderboard(result_df, model_name, target_lang, DIALECT_CONFUSION_LEADERBOARD_FILE="darija_leaderboard_dialect_confusion.json"):
|
486 |
-
# Initialize Hugging Face API
|
487 |
-
api = HfApi()
|
488 |
-
|
489 |
-
try:
|
490 |
-
# Download existing file
|
491 |
-
try:
|
492 |
-
file_content = api.fetch_file_content(
|
493 |
-
repo_id=LEADERBOARD_PATH,
|
494 |
-
filename=DIALECT_CONFUSION_LEADERBOARD_FILE,
|
495 |
-
repo_type="model"
|
496 |
-
)
|
497 |
-
data = json.loads(file_content)
|
498 |
-
except:
|
499 |
-
data = []
|
500 |
-
|
501 |
-
# Process the results
|
502 |
-
for _, row in result_df.iterrows():
|
503 |
-
dialect = row['dialect']
|
504 |
-
if dialect == 'Other':
|
505 |
-
continue
|
506 |
-
|
507 |
-
target_entry = next((item for item in data if target_lang in item), None)
|
508 |
-
if target_entry is None:
|
509 |
-
target_entry = {target_lang: {}}
|
510 |
-
data.append(target_entry)
|
511 |
|
512 |
-
|
513 |
-
|
514 |
-
|
515 |
-
|
516 |
|
517 |
-
|
518 |
-
|
519 |
-
|
520 |
-
|
521 |
-
|
522 |
-
except Exception as e:
|
523 |
-
print(f"Error updating repository: {str(e)}")
|
524 |
-
raise
|
525 |
|
526 |
-
|
527 |
-
# Initialize Hugging Face API
|
528 |
-
api = HfApi()
|
529 |
-
|
530 |
try:
|
531 |
-
|
532 |
-
|
533 |
-
|
534 |
-
|
535 |
-
filename=MULTI_DIALECTS_LEADERBOARD_FILE,
|
536 |
-
repo_type="model"
|
537 |
-
)
|
538 |
-
data = json.loads(file_content)
|
539 |
-
except:
|
540 |
-
data = []
|
541 |
-
|
542 |
-
# Process the results
|
543 |
-
for _, row in result_df.iterrows():
|
544 |
-
country = row['country']
|
545 |
-
if country == 'Other':
|
546 |
-
continue
|
547 |
-
|
548 |
-
metrics = {
|
549 |
-
'f1_score': float(row['f1_score']),
|
550 |
-
'precision': float(row['precision']),
|
551 |
-
'recall': float(row['recall']),
|
552 |
-
'macro_f1_score': float(row['macro_f1_score']),
|
553 |
-
'micro_f1_score': float(row['micro_f1_score']),
|
554 |
-
'weighted_f1_score': float(row['weighted_f1_score']),
|
555 |
-
'specificity': float(row['specificity']),
|
556 |
-
'false_positive_rate': float(row['false_positive_rate']),
|
557 |
-
'false_negative_rate': float(row['false_negative_rate']),
|
558 |
-
'negative_predictive_value': float(row['negative_predictive_value']),
|
559 |
-
'balanced_accuracy': float(row['balanced_accuracy']),
|
560 |
-
'matthews_correlation': float(row['matthews_correlation']),
|
561 |
-
'n_test_samples': int(row['samples'])
|
562 |
-
}
|
563 |
-
|
564 |
-
country_entry = next((item for item in data if country in item), None)
|
565 |
-
if country_entry is None:
|
566 |
-
country_entry = {country: {}}
|
567 |
-
data.append(country_entry)
|
568 |
-
|
569 |
-
if country not in country_entry:
|
570 |
-
country_entry[country] = {}
|
571 |
-
country_entry[country][model_name] = metrics
|
572 |
|
573 |
-
|
574 |
-
|
575 |
-
|
576 |
-
except Exception as e:
|
577 |
-
print(f"Error updating repository: {str(e)}")
|
578 |
-
raise
|
|
|
1 |
+
import subprocess
|
2 |
import base64
|
3 |
from huggingface_hub import hf_hub_download
|
4 |
import fasttext
|
|
|
14 |
matthews_corrcoef
|
15 |
)
|
16 |
import numpy as np
|
17 |
+
|
18 |
from constants import *
|
|
|
|
|
19 |
|
20 |
def predict_label(text, model, language_mapping_dict, use_mapping=False):
|
21 |
"""
|
|
|
183 |
|
184 |
return out
|
185 |
|
186 |
+
def update_darija_one_vs_all_leaderboard(result_df, model_name, target_lang, DIALECT_CONFUSION_LEADERBOARD_FILE="darija_leaderboard_binary.json"):
    """Record a model's per-dialect false-positive rates in the one-vs-all leaderboard.

    The leaderboard JSON file is looked up next to this module and is expected
    to hold a list of dicts; entries created here have the shape
    ``{target_lang: {dialect: {model_name: rate}}}``. Rows of ``result_df``
    whose ``'dialect'`` is ``'Other'`` are ignored; every other row stores
    ``float(row['false_positive_rate'])`` under its dialect and ``model_name``.
    The file is then rewritten and its name passed to ``save_leaderboard_file``.

    Args:
        result_df: Object exposing ``iterrows()`` (e.g. a pandas DataFrame)
            whose rows provide ``'dialect'`` and ``'false_positive_rate'``.
        model_name: Key under which this model's rates are stored.
        target_lang: Target language whose entry is updated (created if absent).
        DIALECT_CONFUSION_LEADERBOARD_FILE: Leaderboard file name, resolved
            relative to this module's directory.
    """
    leaderboard_path = os.path.join(os.path.dirname(__file__), DIALECT_CONFUSION_LEADERBOARD_FILE)

    print(f"[INFO] Loading leaderboard data (json file) from: {leaderboard_path}")

    # A missing file simply means no leaderboard has been written yet.
    try:
        with open(leaderboard_path, "r") as reader:
            leaderboard = json.load(reader)
    except FileNotFoundError:
        leaderboard = []

    for _, record in result_df.iterrows():
        dialect_name = record['dialect']
        # 'Other' is treated as the null class and gets no leaderboard entry.
        if dialect_name == 'Other':
            continue

        # Reuse the first entry keyed by target_lang, or append a fresh one.
        for candidate in leaderboard:
            if target_lang in candidate:
                lang_entry = candidate
                break
        else:
            lang_entry = {target_lang: {}}
            leaderboard.append(lang_entry)

        per_dialect = lang_entry[target_lang]
        if dialect_name not in per_dialect:
            per_dialect[dialect_name] = {}

        # Only the false-positive rate is tracked on this leaderboard.
        per_dialect[dialect_name][model_name] = float(record['false_positive_rate'])

    with open(leaderboard_path, "w") as writer:
        json.dump(leaderboard, writer, indent=4)

    save_leaderboard_file(DIALECT_CONFUSION_LEADERBOARD_FILE)
|
229 |
+
|
230 |
+
|
231 |
def handle_evaluation(model_path, model_path_bin, use_mapping=False):
|
232 |
|
233 |
# download model and get the model path
|
|
|
344 |
|
345 |
return create_leaderboard_display_multilingual(df_multilingual, target_label, default_metrics), status_message
|
346 |
|
347 |
+
def update_darija_multilingual_leaderboard(result_df, model_name, MULTI_DIALECTS_LEADERBOARD_FILE):
    """Record a model's per-country metrics in the multi-dialect leaderboard.

    The leaderboard JSON file is looked up next to this module and is expected
    to hold a list of dicts; entries created here have the shape
    ``{country: {model_name: metrics}}``. Rows of ``result_df`` whose
    ``'country'`` is ``'Other'`` are ignored. The file is then rewritten and
    its name passed to ``save_leaderboard_file``.

    Args:
        result_df: Object exposing ``iterrows()`` (e.g. a pandas DataFrame)
            whose rows provide ``'country'``, ``'samples'`` and the metric
            columns listed in the function body.
        model_name: Key under which this model's metrics are stored.
        MULTI_DIALECTS_LEADERBOARD_FILE: Leaderboard file name, resolved
            relative to this module's directory.
    """
    leaderboard_path = os.path.join(os.path.dirname(__file__), MULTI_DIALECTS_LEADERBOARD_FILE)

    # A missing file simply means no leaderboard has been written yet.
    try:
        with open(leaderboard_path, "r") as reader:
            leaderboard = json.load(reader)
    except FileNotFoundError:
        leaderboard = []

    # Columns copied verbatim (as floats) into the stored metrics, in order.
    float_metric_names = (
        'f1_score',
        'precision',
        'recall',
        'macro_f1_score',
        'micro_f1_score',
        'weighted_f1_score',
        'specificity',
        'false_positive_rate',
        'false_negative_rate',
        'negative_predictive_value',
        'balanced_accuracy',
        'matthews_correlation',
    )

    for _, record in result_df.iterrows():
        country_name = record['country']
        # 'Other' is treated as the null class and gets no leaderboard entry.
        if country_name == 'Other':
            continue

        model_metrics = {name: float(record[name]) for name in float_metric_names}
        # The sample count is stored under a different key than its column name.
        model_metrics['n_test_samples'] = int(record['samples'])

        # Reuse the first entry keyed by this country, or append a fresh one.
        for candidate in leaderboard:
            if country_name in candidate:
                country_entry = candidate
                break
        else:
            country_entry = {country_name: {}}
            leaderboard.append(country_entry)

        country_entry.setdefault(country_name, {})[model_name] = model_metrics

    with open(leaderboard_path, "w") as writer:
        json.dump(leaderboard, writer, indent=4)

    save_leaderboard_file(MULTI_DIALECTS_LEADERBOARD_FILE)
|
400 |
+
|
401 |
|
402 |
def load_leaderboard_one_vs_all(DIALECT_CONFUSION_LEADERBOARD_FILE):
|
403 |
current_dir = os.path.dirname(os.path.abspath(__file__))
|
|
|
555 |
""" A function to render HTML table with fixed 'model' column for better visibility """
|
556 |
return NotImplementedError
|
557 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
558 |
|
559 |
+
def save_leaderboard_file(FILE_PATH):
    """Stage, commit and push an already-written leaderboard file with git.

    The file's contents are deliberately left untouched. An earlier version
    rewrote ``FILE_PATH`` with placeholder data before committing, which
    clobbered the leaderboard the caller had just written.

    Args:
        FILE_PATH: Path of the leaderboard file to stage. A relative path is
            interpreted by git relative to the current working directory.

    Returns:
        None. Failures of any git step (including a missing ``git``
        executable) are reported on stdout and not raised, so a failed push
        never aborts the evaluation that triggered it.
    """
    git_steps = (
        # "--" keeps a path that starts with "-" from being parsed as an option.
        ["git", "add", "--", FILE_PATH],
        ["git", "commit", "-m", "Update leaderboard file"],
        ["git", "push"],
    )

    print(f"[INFO] Committing {FILE_PATH}")

    try:
        # check=True stops at the first failing step instead of pushing a
        # half-finished state.
        for step in git_steps:
            subprocess.run(step, check=True)
    except (subprocess.CalledProcessError, OSError) as e:
        print(f"[ERROR] Failed to commit or push changes: {e}")
    else:
        print("[INFO] Leaderboard file committed and pushed to the repository.")
|
|
|
|
|
|
|
|