poltextlab committed
Commit 54f853e · verified · Parent: a55b33f

Create cap_media2.py

Files changed (1):
  interfaces/cap_media2.py  +86 -0
interfaces/cap_media2.py ADDED
@@ -0,0 +1,86 @@
+ import gradio as gr
+
+ import os
+ import torch
+ import numpy as np
+ from transformers import AutoTokenizer
+ from huggingface_hub import HfApi
+
+ from label_dicts import CAP_MEDIA2_NUM_DICT, CAP_MEDIA2_LABEL_NAMES
+
+ from .utils import is_disk_full, release_model
+
+ HF_TOKEN = os.environ["hf_read"]
+
+ languages = [
+     "Multilingual",
+ ]
+
+ domains = {
+     "media": "media"
+ }
+
+ def check_huggingface_path(checkpoint_path: str):
+     # Return True if the checkpoint exists on the Hugging Face Hub.
+     try:
+         hf_api = HfApi(token=HF_TOKEN)
+         hf_api.model_info(checkpoint_path)
+         return True
+     except Exception:
+         return False
+
+ def build_huggingface_path(language: str, domain: str):
+     # A single multilingual model serves every language/domain combination.
+     return "poltextlab/xlm-roberta-large-pooled-cap-media2"
+
+ def predict(text, model_id, tokenizer_id):
+     device = torch.device("cpu")
+
+     # Load the JIT-traced model
+     jit_model_path = f"/data/jit_models/{model_id.replace('/', '_')}.pt"
+     model = torch.jit.load(jit_model_path).to(device)
+     model.eval()
+
+     # Load the tokenizer (still a regular Hugging Face tokenizer)
+     tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
+
+     # Tokenize the input
+     inputs = tokenizer(
+         text,
+         max_length=256,
+         truncation=True,
+         padding="do_not_pad",
+         return_tensors="pt"
+     )
+     inputs = {k: v.to(device) for k, v in inputs.items()}
+
+     with torch.no_grad():
+         output = model(inputs["input_ids"], inputs["attention_mask"])
+         logits = output["logits"]
+
+     release_model(model, model_id)
+
+     probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy().flatten()
+     output_pred = {f"[{CAP_MEDIA2_NUM_DICT[i]}] {CAP_MEDIA2_LABEL_NAMES[CAP_MEDIA2_NUM_DICT[i]]}": probs[i] for i in np.argsort(probs)[::-1]}
+     output_info = f'<p style="text-align: center; display: block">Prediction was made using the <a href="https://huggingface.co/{model_id}">{model_id}</a> model.</p>'
+     return output_pred, output_info
+
+ def predict_cap(text, language, domain):
+     domain = domains[domain]
+     model_id = build_huggingface_path(language, domain)
+     tokenizer_id = "xlm-roberta-large"
+
+     # If the persistent volume is full, clear cached models before loading.
+     if is_disk_full():
+         os.system('rm -rf /data/models*')
+         os.system('rm -rf ~/.cache/huggingface/hub')
+
+     return predict(text, model_id, tokenizer_id)
+
+ demo = gr.Interface(
+     title="CAP Media2 Topics Babel Demo",
+     fn=predict_cap,
+     inputs=[gr.Textbox(lines=6, label="Input"),
+             gr.Dropdown(languages, label="Language", value=languages[0]),
+             gr.Dropdown(list(domains.keys()), label="Domain", value=list(domains.keys())[0])],
+     outputs=[gr.Label(num_top_classes=5, label="Output"), gr.Markdown()])
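
A note on the JIT model path: predict() loads a pre-traced TorchScript file from /data/jit_models/, but the tracing step itself is not part of this commit. A minimal sketch of how such a file could be produced, as an assumption rather than the repository's actual export script (strict=False lets torch.jit.trace return the model's dict-style output, which matches the output["logits"] lookup above):

    import os
    import torch
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    model_id = "poltextlab/xlm-roberta-large-pooled-cap-media2"

    # Assumption: the same read token the app uses grants access to the model.
    model = AutoModelForSequenceClassification.from_pretrained(
        model_id, token=os.environ["hf_read"]
    )
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-large")
    example = tokenizer("example sentence", return_tensors="pt")

    # strict=False permits a dict return value ({"logits": ...}) in the trace.
    traced = torch.jit.trace(
        model,
        (example["input_ids"], example["attention_mask"]),
        strict=False,
    )
    traced.save(f"/data/jit_models/{model_id.replace('/', '_')}.pt")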
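
The label dictionaries imported from label_dicts are defined elsewhere in the repository. Judging from how predict() builds its output keys, CAP_MEDIA2_NUM_DICT maps the model's class indices to CAP codes and CAP_MEDIA2_LABEL_NAMES maps those codes to display names; a hypothetical fragment for illustration:

    # Hypothetical shapes, inferred from predict(); the real values
    # live in label_dicts.py.
    CAP_MEDIA2_NUM_DICT = {
        0: 1,  # model class index -> CAP major topic code
        1: 2,
        2: 3,
    }
    CAP_MEDIA2_LABEL_NAMES = {
        1: "Macroeconomics",  # CAP code -> human-readable label
        2: "Civil Rights",
        3: "Health",
    }

Under this assumption, a prediction for class 0 would be rendered as "[1] Macroeconomics" in the Gradio label output.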
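
Likewise, is_disk_full() and release_model() come from the shared .utils module, which this commit does not touch. A plausible sketch of what they might do, purely as an assumption for readers of this file:

    import gc
    import shutil

    def is_disk_full(path="/data", threshold=0.95):
        # Hypothetical: True once the persistent volume is nearly full.
        usage = shutil.disk_usage(path)
        return usage.used / usage.total >= threshold

    def release_model(model, model_id):
        # Hypothetical: drop the reference so the weights can be reclaimed.
        del model
        gc.collect()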
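
Finally, the module only defines demo and never calls launch(); the relative .utils import suggests it is mounted by a top-level app that aggregates the interfaces/ modules. A hypothetical standalone run would look like:

    # Must be executed from the package root so the relative import resolves.
    from interfaces.cap_media2 import demo

    demo.launch()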