kovacsvi committed
Commit 7cbaea3 · Parent: 3f77878

time profiling for prediction

Files changed (1):
  interfaces/cap.py (+18 -5)
interfaces/cap.py CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
 import spaces
 
 import os
+import time
 import torch
 import numpy as np
 import pandas as pd
@@ -34,6 +35,7 @@ domains = {
     "local government agenda": "localgovernment"
 }
 
+
 def check_huggingface_path(checkpoint_path: str):
     try:
         hf_api = HfApi(token=HF_TOKEN)
@@ -41,6 +43,7 @@ def check_huggingface_path(checkpoint_path: str):
         return True
     except:
         return False
+
 
 def build_huggingface_path(language: str, domain: str):
     language = language.lower()
@@ -82,19 +85,22 @@ def build_huggingface_path(language: str, domain: str):
         return model_path
     else:
         return "poltextlab/xlm-roberta-large-pooled-cap"
+
 
 def predict(text, model_id, tokenizer_id):
     device = torch.device("cpu")
 
-    # Load JIT-traced model
+    t0 = time.perf_counter()
     jit_model_path = f"/data/jit_models/{model_id.replace('/', '_')}.pt"
     model = torch.jit.load(jit_model_path).to(device)
     model.eval()
+    print(f"Model loading: {time.perf_counter() - t0:.3f}s")
 
-    # Load tokenizer (still regular HF)
+    t1 = time.perf_counter()
     tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
+    print(f"Tokenizer loading: {time.perf_counter() - t1:.3f}s")
 
-    # Tokenize input
+    t2 = time.perf_counter()
     inputs = tokenizer(
         text,
         max_length=256,
@@ -103,18 +109,24 @@ def predict(text, model_id, tokenizer_id):
         return_tensors="pt"
     )
     inputs = {k: v.to(device) for k, v in inputs.items()}
+    print(f"Tokenization: {time.perf_counter() - t2:.3f}s")
 
+    t3 = time.perf_counter()
     with torch.no_grad():
         output = model(inputs["input_ids"], inputs["attention_mask"])
-        print(output) # debug
     logits = output["logits"]
-
+    print(f"Inference: {time.perf_counter() - t3:.3f}s")
+
     release_model(model, model_id)
 
+    t4 = time.perf_counter()
     probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy().flatten()
     output_pred = {f"[{CAP_NUM_DICT[i]}] {CAP_LABEL_NAMES[CAP_NUM_DICT[i]]}": probs[i] for i in np.argsort(probs)[::-1]}
     output_info = f'<p style="text-align: center; display: block">Prediction was made using the <a href="https://huggingface.co/{model_id}">{model_id}</a> model.</p>'
+    print(f"Post-processing: {time.perf_counter() - t4:.3f}s")
+
     return output_pred, output_info
+
 
 def predict_cap(text, language, domain):
     print(domain) # debug statement
@@ -127,6 +139,7 @@ def predict_cap(text, language, domain):
     os.system('rm -r ~/.cache/huggingface/hub')
 
     return predict(text, model_id, tokenizer_id)
+
 
 demo = gr.Interface(
     title="CAP Babel Demo",