Add script for fixing params number
Browse files
backend/app/services/models.py
CHANGED
|
@@ -454,11 +454,11 @@ class ModelService(HuggingFaceService):
|
|
| 454 |
if model_size is None:
|
| 455 |
logger.error(LogFormatter.error("Model size validation failed", error))
|
| 456 |
raise Exception(error)
|
| 457 |
-
logger.info(LogFormatter.success(f"Model size validation passed: {model_size:.1f}
|
| 458 |
|
| 459 |
# Size limits based on precision
|
| 460 |
if model_data["precision"] in ["float16", "bfloat16"] and model_size > 100:
|
| 461 |
-
error_msg = f"Model too large for {model_data['precision']} (limit:
|
| 462 |
logger.error(LogFormatter.error("Size limit exceeded", error_msg))
|
| 463 |
raise Exception(error_msg)
|
| 464 |
|
|
|
|
| 454 |
if model_size is None:
|
| 455 |
logger.error(LogFormatter.error("Model size validation failed", error))
|
| 456 |
raise Exception(error)
|
| 457 |
+
logger.info(LogFormatter.success(f"Model size validation passed: {model_size:.1f}B"))
|
| 458 |
|
| 459 |
# Size limits based on precision
|
| 460 |
if model_data["precision"] in ["float16", "bfloat16"] and model_size > 100:
|
| 461 |
+
error_msg = f"Model too large for {model_data['precision']} (limit: 100B)"
|
| 462 |
logger.error(LogFormatter.error("Size limit exceeded", error_msg))
|
| 463 |
raise Exception(error_msg)
|
| 464 |
|
backend/utils/fix_wrong_model_size.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import pytz
|
| 4 |
+
import logging
|
| 5 |
+
import asyncio
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
import huggingface_hub
|
| 9 |
+
from huggingface_hub.errors import RepositoryNotFoundError, RevisionNotFoundError
|
| 10 |
+
from dotenv import load_dotenv
|
| 11 |
+
from git import Repo
|
| 12 |
+
from datetime import datetime
|
| 13 |
+
from tqdm.auto import tqdm
|
| 14 |
+
from tqdm.contrib.logging import logging_redirect_tqdm
|
| 15 |
+
|
| 16 |
+
from app.config.hf_config import HF_TOKEN, QUEUE_REPO, API, EVAL_REQUESTS_PATH
|
| 17 |
+
|
| 18 |
+
from app.utils.model_validation import ModelValidator
|
| 19 |
+
|
| 20 |
+
huggingface_hub.logging.set_verbosity_error()
|
| 21 |
+
huggingface_hub.utils.disable_progress_bars()
|
| 22 |
+
|
| 23 |
+
logging.basicConfig(
|
| 24 |
+
level=logging.ERROR,
|
| 25 |
+
format='%(message)s'
|
| 26 |
+
)
|
| 27 |
+
logger = logging.getLogger(__name__)
|
| 28 |
+
load_dotenv()
|
| 29 |
+
|
| 30 |
+
validator = ModelValidator()
|
| 31 |
+
|
| 32 |
+
def get_changed_files(repo_path, start_date, end_date):
|
| 33 |
+
repo = Repo(repo_path)
|
| 34 |
+
start = datetime.strptime(start_date, '%Y-%m-%d')
|
| 35 |
+
end = datetime.strptime(end_date, '%Y-%m-%d')
|
| 36 |
+
|
| 37 |
+
changed_files = set()
|
| 38 |
+
pbar = tqdm(repo.iter_commits(), desc=f"Reading commits from {end_date} to {start_date}")
|
| 39 |
+
for commit in pbar:
|
| 40 |
+
commit_date = datetime.fromtimestamp(commit.committed_date)
|
| 41 |
+
pbar.set_postfix_str(f"Commit date: {commit_date}")
|
| 42 |
+
if start <= commit_date <= end:
|
| 43 |
+
changed_files.update(item.a_path for item in commit.diff(commit.parents[0]))
|
| 44 |
+
|
| 45 |
+
if commit_date < start:
|
| 46 |
+
break
|
| 47 |
+
|
| 48 |
+
return changed_files
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def read_json(repo_path, file):
|
| 52 |
+
with open(f"{repo_path}/{file}") as file:
|
| 53 |
+
return json.load(file)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def write_json(repo_path, file, content):
|
| 57 |
+
with open(f"{repo_path}/{file}", "w") as file:
|
| 58 |
+
json.dump(content, file, indent=2)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def main():
|
| 62 |
+
requests_path = "/Users/lozowski/Developer/requests"
|
| 63 |
+
start_date = "2024-12-09"
|
| 64 |
+
end_date = "2025-01-07"
|
| 65 |
+
|
| 66 |
+
changed_files = get_changed_files(requests_path, start_date, end_date)
|
| 67 |
+
|
| 68 |
+
for file in tqdm(changed_files):
|
| 69 |
+
try:
|
| 70 |
+
request_data = read_json(requests_path, file)
|
| 71 |
+
except FileNotFoundError as e:
|
| 72 |
+
tqdm.write(f"File {file} not found")
|
| 73 |
+
continue
|
| 74 |
+
|
| 75 |
+
try:
|
| 76 |
+
model_info = API.model_info(
|
| 77 |
+
repo_id=request_data["model"],
|
| 78 |
+
revision=request_data["revision"],
|
| 79 |
+
token=HF_TOKEN
|
| 80 |
+
)
|
| 81 |
+
except (RepositoryNotFoundError, RevisionNotFoundError) as e:
|
| 82 |
+
tqdm.write(f"Model info for {request_data["model"]} not found")
|
| 83 |
+
continue
|
| 84 |
+
|
| 85 |
+
with logging_redirect_tqdm():
|
| 86 |
+
new_model_size, error = asyncio.run(validator.get_model_size(
|
| 87 |
+
model_info=model_info,
|
| 88 |
+
precision=request_data["precision"],
|
| 89 |
+
base_model=request_data["base_model"],
|
| 90 |
+
revision=request_data["revision"]
|
| 91 |
+
))
|
| 92 |
+
|
| 93 |
+
if error:
|
| 94 |
+
tqdm.write(f"Error getting model size info for {request_data["model"]}, {error}")
|
| 95 |
+
continue
|
| 96 |
+
|
| 97 |
+
old_model_size = request_data["params"]
|
| 98 |
+
if old_model_size != new_model_size:
|
| 99 |
+
if new_model_size > 100:
|
| 100 |
+
tqdm.write(f"Model: {request_data["model"]}, size is more 100B: {new_model_size}")
|
| 101 |
+
|
| 102 |
+
tqdm.write(f"Model: {request_data["model"]}, old size: {request_data["params"]} new size: {new_model_size}")
|
| 103 |
+
tqdm.write(f"Updating request file {file}")
|
| 104 |
+
|
| 105 |
+
request_data["params"] = new_model_size
|
| 106 |
+
write_json(requests_path, file, content=request_data)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
if __name__ == "__main__":
|
| 110 |
+
main()
|