kovacsvi committed · Commit caa0374 · 1 Parent: afcc557

JIT...
utils.py CHANGED
@@ -25,8 +25,7 @@ from interfaces.ontolisst import build_huggingface_path as hf_ontolisst_path
 
 from huggingface_hub import scan_cache_dir
 
-
-TOKENIZERS_PRELOADED = []
+JIT_DIR = "/data/jit_models"
 
 HF_TOKEN = os.environ["hf_read"]
 
@@ -55,12 +54,53 @@ for domain in domains_illframes.values():
 
 tokenizers = ["xlm-roberta-large"]
 
-def download_hf_models():
+def download_hf_models(models=[], tokenizers=[], hf_token=None):
+    # Ensure the JIT model directory exists
+    os.makedirs(JIT_DIR, exist_ok=True)
+
     for model_id in models:
-
-    for tokenizer_id in tokenizers:
-        TOKENIZERS_PRELOADED[tokenizer_id] = AutoTokenizer.from_pretrained(tokenizer_id)
+        print(f"Downloading + JIT tracing model: {model_id}")
 
+        # Load model and tokenizer
+        model = AutoModelForSequenceClassification.from_pretrained(
+            model_id,
+            token=hf_token,
+            device_map="auto"
+        )
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_id,
+            token=hf_token
+        )
+
+        model.eval()
+
+        # Dummy input for tracing
+        dummy_input = tokenizer(
+            "Hello, world!",
+            return_tensors="pt",
+            padding=True,
+            truncation=True,
+            max_length=256
+        )
+
+        # JIT trace
+        traced_model = torch.jit.trace(
+            model,
+            (dummy_input["input_ids"], dummy_input["attention_mask"])
+        )
+
+        # Save traced model
+        safe_model_name = model_id.replace("/", "_")
+        traced_model_path = os.path.join(JIT_DIR, f"{safe_model_name}.pt")
+        traced_model.save(traced_model_path)
+        print(f"✔️ Saved JIT model to: {traced_model_path}")
+
+    for tokenizer_id in tokenizers:
+        print(f"Downloading tokenizer: {tokenizer_id}")
+        AutoTokenizer.from_pretrained(
+            tokenizer_id,
+            token=hf_token
+        )
 
 def df_h():
     result = subprocess.run(["df", "-H"], capture_output=True, text=True)
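Note on the consumer side of this change: the traced artifacts written to JIT_DIR can be reloaded with torch.jit.load, with no transformers model class needed at inference time. Below is a minimal sketch under the assumptions of this diff (the same /data/jit_models naming scheme, a model traced on (input_ids, attention_mask), and tuple rather than dict outputs); the model id and the load_traced_model helper are illustrative, not part of the commit.

import os

import torch
from transformers import AutoTokenizer

JIT_DIR = "/data/jit_models"  # same persistent-storage path as in utils.py

def load_traced_model(model_id):
    # Mirror the save-side naming: "/" in the repo id becomes "_"
    safe_model_name = model_id.replace("/", "_")
    return torch.jit.load(os.path.join(JIT_DIR, f"{safe_model_name}.pt"))

model_id = "xlm-roberta-large"  # illustrative; any repo traced by download_hf_models()
tokenizer = AutoTokenizer.from_pretrained(model_id)
traced = load_traced_model(model_id)

enc = tokenizer("Hello, world!", return_tensors="pt",
                padding=True, truncation=True, max_length=256)
with torch.no_grad():
    # Traced sequence-classification models return logits first
    logits = traced(enc["input_ids"], enc["attention_mask"])[0]
print(logits.shape)  # (batch, num_labels)

With the new signature, the Space itself would presumably call something like download_hf_models(models=models, tokenizers=tokenizers, hf_token=HF_TOKEN) at startup, reusing the module-level lists defined earlier in utils.py.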
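One caveat on the tracing step as committed: transformers models return a ModelOutput (a dict subclass) by default, and torch.jit.trace in its default strict mode raises an error on dict outputs. If the trace fails for that reason, the usual fix is to load the model with torchscript=True so it returns plain tuples (or to pass strict=False to torch.jit.trace). A sketch of the tuple-output variant, reusing the diff's dummy-input settings; the model id is again an illustrative stand-in:

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_id = "xlm-roberta-large"  # illustrative stand-in for the Space's model list

# torchscript=True switches the model to tuple outputs, which trace() accepts
model = AutoModelForSequenceClassification.from_pretrained(model_id, torchscript=True)
model.eval()

tokenizer = AutoTokenizer.from_pretrained(model_id)
dummy = tokenizer("Hello, world!", return_tensors="pt",
                  padding=True, truncation=True, max_length=256)

with torch.no_grad():
    traced = torch.jit.trace(model, (dummy["input_ids"], dummy["attention_mask"]))

traced.save("xlm-roberta-large.pt")

A related wrinkle: device_map="auto" can dispatch weights across devices via accelerate, which does not always compose with tracing; loading and tracing on a single device, as above, sidesteps that.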