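# Gradio demo: classify political text into Manifesto Project categories
# using a JIT-traced multilingual XLM-RoBERTa model served from local disk.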
import os

import gradio as gr
import numpy as np
import torch
from transformers import AutoTokenizer

from label_dicts import MANIFESTO_LABEL_NAMES, MANIFESTO_NUM_DICT

from utils import is_disk_full, release_model

HF_TOKEN = os.environ["hf_read"]

languages = [
    "Armenian", "Bulgarian", "Croatian", "Czech", "Danish", "Dutch", "English",
    "Estonian", "Finnish", "French", "Georgian", "German", "Greek", "Hebrew",
    "Hungarian", "Icelandic", "Italian", "Japanese", "Korean", "Latvian",
    "Lithuanian", "Norwegian", "Polish", "Portuguese", "Romanian", "Russian",
    "Serbian", "Slovak", "Slovenian", "Spanish", "Swedish", "Turkish"
]

def build_huggingface_path(language: str):
    # A single multilingual model covers every supported language, so the
    # `language` argument is accepted for interface compatibility but unused.
    return "poltextlab/xlm-roberta-large-manifesto"

def predict(text, model_id, tokenizer_id):
    device = torch.device("cpu")

    # Load the JIT-traced model from local storage
    jit_model_path = f"/data/jit_models/{model_id.replace('/', '_')}.pt"
    model = torch.jit.load(jit_model_path).to(device)
    model.eval()

    # Load the tokenizer (still a regular Hugging Face tokenizer)
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)

    # Tokenize the input text
    inputs = tokenizer(
        text,
        max_length=64,
        truncation=True,
        padding=True,
        return_tensors="pt"
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        output = model(inputs["input_ids"], inputs["attention_mask"])
        logits = output["logits"]

    release_model(model, model_id)

    # Rank labels by probability, highest first
    probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy().flatten()
    output_pred = {
        f"[{MANIFESTO_NUM_DICT[i]}] {MANIFESTO_LABEL_NAMES[int(MANIFESTO_NUM_DICT[i])]}": probs[i]
        for i in np.argsort(probs)[::-1]
    }
    output_info = (
        '<p style="text-align: center; display: block">Prediction was made '
        f'using the <a href="https://huggingface.co/{model_id}">{model_id}</a> model.</p>'
    )
    return output_pred, output_info

def predict_cap(text, language):
    model_id = build_huggingface_path(language)
    tokenizer_id = "xlm-roberta-large"

    # Free up space before loading: drop cached models if the disk is full
    if is_disk_full():
        os.system('rm -rf /data/models*')
        os.system('rm -r ~/.cache/huggingface/hub')

    return predict(text, model_id, tokenizer_id)

demo = gr.Interface(
    title="Manifesto Babel Demo",
    fn=predict_cap,
    inputs=[gr.Textbox(lines=6, label="Input"),
            gr.Dropdown(languages, label="Language", value=languages[6])],  # default: English
    outputs=[gr.Label(num_top_classes=5, label="Output"), gr.Markdown()])
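
# Launch when this file is run directly. Assumption: on Hugging Face Spaces
# the `demo` object may instead be imported and launched by a parent app, in
# which case this guard never fires.
if __name__ == "__main__":
    demo.launch()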