kovacsvi committed on
Commit
caa0374
·
1 Parent(s): afcc557
Files changed (1) hide show
  1. utils.py +46 -6
utils.py CHANGED
@@ -25,8 +25,7 @@ from interfaces.ontolisst import build_huggingface_path as hf_ontolisst_path
25
 
26
  from huggingface_hub import scan_cache_dir
27
 
28
- MODELS_PRELOADED = []
29
- TOKENIZERS_PRELOADED = []
30
 
31
  HF_TOKEN = os.environ["hf_read"]
32
 
@@ -55,12 +54,53 @@ for domain in domains_illframes.values():
55
 
56
  tokenizers = ["xlm-roberta-large"]
57
 
58
- def download_hf_models():
 
 
 
59
  for model_id in models:
60
- MODELS_PRELOADED[model_id] = AutoModelForSequenceClassification.from_pretrained(model_id, device_map="auto", token=HF_TOKEN)
61
- for tokenizer_id in tokenizers:
62
- TOKENIZERS_PRELOADED[tokenizer_id] = AutoTokenizer.from_pretrained(tokenizer_id)
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
  def df_h():
66
  result = subprocess.run(["df", "-H"], capture_output=True, text=True)
 
25
 
26
  from huggingface_hub import scan_cache_dir
27
 
28
# Directory where JIT-traced TorchScript model artifacts are written.
JIT_DIR = "/data/jit_models"
 
29
 
30
  HF_TOKEN = os.environ["hf_read"]
31
 
 
54
 
55
  tokenizers = ["xlm-roberta-large"]
56
 
57
def download_hf_models(models=None, tokenizers=None, hf_token=None):
    """Download, JIT-trace, and cache the given Hugging Face models.

    Each id in *models* is loaded (model + its tokenizer), traced with
    ``torch.jit.trace`` on a dummy input, and the traced artifact is saved
    under ``JIT_DIR`` as ``"<model_id with '/' replaced by '_'>.pt"``.
    Each id in *tokenizers* is downloaded only, to populate the local
    Hugging Face cache.

    Args:
        models: iterable of Hugging Face model repo ids (default: none).
        tokenizers: iterable of tokenizer repo ids (default: none).
        hf_token: Hugging Face access token forwarded to ``from_pretrained``.
    """
    # Avoid mutable default arguments: a shared [] default would leak
    # state between calls if any caller ever mutated it.
    models = [] if models is None else models
    tokenizers = [] if tokenizers is None else tokenizers

    # Ensure the JIT model directory exists.
    os.makedirs(JIT_DIR, exist_ok=True)

    for model_id in models:
        print(f"Downloading + JIT tracing model: {model_id}")

        # return_dict=False makes the forward pass return a plain tuple of
        # tensors; torch.jit.trace cannot handle the ModelOutput dict that
        # transformers returns by default under strict tracing.
        model = AutoModelForSequenceClassification.from_pretrained(
            model_id,
            token=hf_token,
            device_map="auto",
            return_dict=False,
        )
        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            token=hf_token,
        )

        model.eval()

        # Dummy input for tracing.
        dummy_input = tokenizer(
            "Hello, world!",
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=256,
        )
        # Move inputs to the model's device: under device_map="auto" the
        # weights may be on GPU while the tokenizer output is on CPU, and
        # tracing with mismatched devices raises.
        device = next(model.parameters()).device
        input_ids = dummy_input["input_ids"].to(device)
        attention_mask = dummy_input["attention_mask"].to(device)

        # JIT trace; no_grad avoids building a useless autograd graph.
        with torch.no_grad():
            traced_model = torch.jit.trace(
                model,
                (input_ids, attention_mask),
            )

        # Save traced model; '/' in repo ids is not filesystem-safe.
        safe_model_name = model_id.replace("/", "_")
        traced_model_path = os.path.join(JIT_DIR, f"{safe_model_name}.pt")
        traced_model.save(traced_model_path)
        print(f"✔️ Saved JIT model to: {traced_model_path}")

    for tokenizer_id in tokenizers:
        print(f"Downloading tokenizer: {tokenizer_id}")
        # Download only — the return value is discarded; this just
        # populates the local Hugging Face cache.
        AutoTokenizer.from_pretrained(
            tokenizer_id,
            token=hf_token,
        )
104
 
105
  def df_h():
106
  result = subprocess.run(["df", "-H"], capture_output=True, text=True)