# babel_machine / utils.py
# Commit b1b87fb by kovacsvi: "hf http cache cleanup"
# (file viewer: raw | history | blame · 5.78 kB)
import os
import shutil
import glob
import subprocess
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from interfaces.cap import languages as languages_cap
from interfaces.cap import domains as domains_cap
from interfaces.emotion9 import languages as languages_emotion9
from interfaces.illframes import domains as domains_illframes
from interfaces.cap import build_huggingface_path as hf_cap_path
from interfaces.cap_minor import build_huggingface_path as hf_cap_minor_path
from interfaces.cap_minor_media import build_huggingface_path as hf_cap_minor_media_path
from interfaces.cap_media_demo import build_huggingface_path as hf_cap_media_path # why... just follow the name template the next time pls
from interfaces.cap_media2 import build_huggingface_path as hf_cap_media2_path
from interfaces.manifesto import build_huggingface_path as hf_manifesto_path
from interfaces.sentiment import build_huggingface_path as hf_sentiment_path
from interfaces.emotion import build_huggingface_path as hf_emotion_path
from interfaces.emotion9 import build_huggingface_path as hf_emotion9_path
from interfaces.ontolisst import build_huggingface_path as hf_ontlisst_path
from interfaces.illframes import build_huggingface_path as hf_illframes_path
from interfaces.ontolisst import build_huggingface_path as hf_ontolisst_path
from huggingface_hub import scan_cache_dir
JIT_DIR = "/data/jit_models"
HF_TOKEN = os.environ["hf_read"]

# Hard-coded registry of every Hugging Face model id to download and trace.
# NOTE: meant to be a temporary solution.
models = [
    hf_manifesto_path(""),
    hf_sentiment_path(""),
    hf_emotion_path(""),
    hf_cap_minor_path("", ""),
    hf_ontolisst_path(""),
]

# CAP has one model per (language, domain) pair. Note that `domains_cap` is
# rebound here from the imported dict to a list of its values.
domains_cap = list(domains_cap.values())
for language in languages_cap:
    for domain in domains_cap:
        models.append(hf_cap_path(language, domain))

# Single-model CAP variants: media, media2 and minor media.
models += [
    hf_cap_media_path("", ""),
    hf_cap_media2_path("", ""),
    hf_cap_minor_media_path("", "", False),
]

# emotion9: one model per language.
for language in languages_emotion9:
    models.append(hf_emotion9_path(language))

# illframes: one model per domain; the imported `domains_illframes` is a
# dict, so its values are iterated.
for domain in domains_illframes.values():
    models.append(hf_illframes_path(domain))

# Tokenizer checkpoint id(s).
tokenizers = ["xlm-roberta-large"]
def _trace_and_save(model, tokenizer, traced_model_path):
    """Trace *model* on a dummy input and save the resulting TorchScript module.

    Args:
        model: A loaded sequence-classification model.
        tokenizer: Tokenizer used to build the dummy tracing input.
        traced_model_path: Destination path for the ``.pt`` file.
    """
    model.eval()
    # Dummy input for tracing. With device_map="auto" the model may have been
    # placed on a GPU, so the input tensors must be moved to the model's
    # device; otherwise tracing fails with a device-mismatch error.
    dummy_input = tokenizer(
        "Hello, world!",
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=256,
    ).to(model.device)
    # Recording the graph needs no autograd state; disabling it saves memory.
    with torch.no_grad():
        traced_model = torch.jit.trace(
            model,
            (dummy_input["input_ids"], dummy_input["attention_mask"]),
            strict=False,
        )
    traced_model.save(traced_model_path)


def download_hf_models():
    """Download every model in ``models`` and cache a TorchScript trace of each.

    Each model id is loaded with ``AutoModelForSequenceClassification`` using
    ``HF_TOKEN``. The trace is written to ``JIT_DIR/<model id with "/"
    replaced by "_">.pt``; if that file already exists, tracing is skipped
    (the model is still loaded, which populates the Hugging Face cache).
    """
    # Ensure the JIT model directory exists
    os.makedirs(JIT_DIR, exist_ok=True)

    # The same tokenizer is used for every model, so load it only once.
    tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-large")

    for model_id in models:
        print(f"Downloading + JIT tracing model: {model_id}")
        model = AutoModelForSequenceClassification.from_pretrained(
            model_id,
            token=HF_TOKEN,
            device_map="auto",
        )

        safe_model_name = model_id.replace("/", "_")
        traced_model_path = os.path.join(JIT_DIR, f"{safe_model_name}.pt")
        if os.path.exists(traced_model_path):
            print(f"⏩ Skipping JIT — already exists: {traced_model_path}")
        else:
            print(f"⚙️ Tracing and saving: {traced_model_path}")
            _trace_and_save(model, tokenizer, traced_model_path)
            print(f"✔️ Saved JIT model to: {traced_model_path}")

        # Drop the reference so the previous model can be freed before the
        # next one is loaded, and release cached GPU blocks.
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
def df_h():
    """Print filesystem free space (``df -H``) and usage of ``/data/`` (``du``).

    Each command's stdout is printed beneath a header line; stderr is captured
    and discarded.
    """
    reports = (
        (["df", "-H"], "=== Disk Free Space (df -H) ==="),
        (
            ["du", "-h", "--max-depth=2", "/data/"],
            "=== Disk Usage for /data/ (du -h --max-depth=2) ===",
        ),
    )
    for command, header in reports:
        completed = subprocess.run(command, capture_output=True, text=True)
        print(header)
        print(completed.stdout)
def hf_cleanup():
    """Recursively delete every directory matching ``/data/http*``.

    Non-directory matches are left alone. Presumably these are leftover HTTP
    cache folders — confirm before widening the pattern.
    """
    for stale_dir in filter(os.path.isdir, glob.glob("/data/http*")):
        print(f"Deleting: {stale_dir}")
        shutil.rmtree(stale_dir)
def scan_cache():
    """Print a size report for the Hugging Face cache and the TorchScript JIT cache.

    The Hugging Face cache location is read from ``TRANSFORMERS_CACHE``. When
    that is unset, ``scan_cache_dir`` falls back to huggingface_hub's own
    default hub cache. A missing cache directory is reported rather than
    raised, so the JIT section is always printed.
    """
    # Local import keeps the dependency next to its only use.
    from huggingface_hub.utils import CacheNotFound

    print("=== 🤗 Hugging Face Model Cache ===")
    # None lets huggingface_hub use its default hub cache. The previous
    # hard-coded "~/.cache/huggingface/transformers" fallback is not where
    # current transformers versions store hub downloads.
    cache_dir = os.environ.get("TRANSFORMERS_CACHE")
    try:
        scan_result = scan_cache_dir(cache_dir)
    except CacheNotFound:
        print(f"(Cache directory not found: {cache_dir or 'default hub cache'})")
    else:
        print(f"Cache size: {scan_result.size_on_disk / 1e6:.2f} MB")
        print(f"Number of repos: {len(scan_result.repos)}")
        for repo in scan_result.repos:
            print(f"- {repo.repo_id} ({repo.repo_type}) — {repo.size_on_disk / 1e6:.2f} MB")

    print("\n=== 🧊 TorchScript JIT Cache ===")
    if not os.path.exists(JIT_DIR):
        print(f"(Directory does not exist: {JIT_DIR})")
        return

    total_size = 0
    # Sorted so the report is stable across runs (os.listdir order is arbitrary).
    for filename in sorted(os.listdir(JIT_DIR)):
        if not filename.endswith(".pt"):
            continue
        size = os.path.getsize(os.path.join(JIT_DIR, filename))
        total_size += size
        print(f"- {filename}: {size / 1e6:.2f} MB")
    print(f"Total JIT cache size: {total_size / 1e6:.2f} MB")
def set_hf_cache_dir(path: str):
    """Point the Hugging Face and torch cache environment variables at *path*.

    Sets TRANSFORMERS_CACHE, HF_HOME, HF_DATASETS_CACHE and TORCH_HOME, in
    that order, all to the same directory.

    NOTE(review): transformers / huggingface_hub are imported at the top of
    this module and presumably read some of these variables at import time,
    so calling this afterwards may not affect already-imported libraries —
    confirm.
    """
    for env_var in ("TRANSFORMERS_CACHE", "HF_HOME", "HF_DATASETS_CACHE", "TORCH_HOME"):
        os.environ[env_var] = path
def is_disk_full(min_free_space_in_GB=10, *, path="/"):
    """Return True if the filesystem holding *path* has too little free space.

    Args:
        min_free_space_in_GB: Threshold in GiB (1024**3 bytes). The disk is
            considered full when free space is strictly below this value.
        path: Any path on the filesystem to check. Defaults to "/" (the
            original behavior). Models and caches in this module live under
            /data, which may be a separate mount; pass it explicitly to check
            that filesystem.

    Returns:
        bool: True when free space < threshold, otherwise False.
    """
    free_gb = shutil.disk_usage(path).free / (1024 ** 3)
    return free_gb < min_free_space_in_GB