File size: 3,473 Bytes
4bba8df
 
 
 
 
 
 
 
 
 
89d4ec8
4bba8df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a115642
89d4ec8
61e094a
89d4ec8
b6f6b98
5c48f48
 
 
 
a552e92
89d4ec8
 
20681cd
 
 
 
89d4ec8
 
 
 
4bba8df
 
 
 
 
 
 
 
 
1849f87
4bba8df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89d4ec8
a115642
4bba8df
555f614
4bba8df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bee266b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import gradio as gr

import os
import torch
import numpy as np
import pandas as pd
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
from huggingface_hub import HfApi

from label_dicts import CAP_MIN_NUM_DICT, CAP_MIN_LABEL_NAMES, CAP_LABEL_NAMES 

from .utils import is_disk_full

# Read-only Hugging Face token; KeyError at import time if the "hf_read"
# env var is unset — the app cannot run without Hub access.
HF_TOKEN = os.environ["hf_read"]

# Language choices shown in the UI dropdown (a single multilingual model
# serves all of them — see build_huggingface_path).
languages = [
    "Multilingual",
]

# Maps the human-readable domain label shown in the UI to the short
# internal code passed on to build_huggingface_path.
domains = {
    "media": "media",
    "social media": "social",
    "parliamentary speech": "parlspeech",
    "legislative documents": "legislative",
    "executive speech": "execspeech",
    "executive order": "execorder",
    "party programs": "party",
    "judiciary": "judiciary",
    "budget": "budget",
    "public opinion": "publicopinion",
    "local government agenda": "localgovernment"
}

def convert_minor_to_major(results, probs):
    """Aggregate minor-topic probabilities into major-topic scores.

    Args:
        results: iterable of class indices into ``probs`` (typically all
            classes, in descending-probability order).
        probs: array-like of per-class probabilities, indexable by ``results``.

    Returns:
        dict mapping ``"[<major_code>] <label>"`` to the summed probability
        of every minor topic belonging to that major topic.
    """
    results_as_text = dict()
    for i in results:
        # Minor codes embed the major code in all but their last two digits
        # (e.g. a minor code 1401 belongs to major 14).
        major_code = str(CAP_MIN_NUM_DICT[i])[:-2]

        # Minor codes starting with "99" belong to major 999 — presumably
        # the CAP catch-all category; confirm against the codebook.
        if major_code == "99":
            major_code = "999"

        label = CAP_LABEL_NAMES[int(major_code)]
        key = f"[{major_code}] {label}"
        # Accumulate with a 0 default (replaces the original if/else branch;
        # also drops an unused `prob` local the original assigned).
        results_as_text[key] = results_as_text.get(key, 0) + probs[i]

    return results_as_text

def check_huggingface_path(checkpoint_path: str) -> bool:
    """Return True if ``checkpoint_path`` is an accessible model repo on the Hub.

    Best-effort probe: any failure (missing repo, auth error, network error)
    is reported as False rather than raised.
    """
    try:
        hf_api = HfApi(token=HF_TOKEN)
        hf_api.model_info(checkpoint_path, token=HF_TOKEN)
        return True
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit are
        # no longer swallowed; the best-effort contract is otherwise kept.
        return False

def build_huggingface_path(language: str, domain: str):
    """Return the Hub id of the classification model to use.

    ``language`` and ``domain`` are currently ignored: one multilingual
    minor-topics model serves every UI combination.
    """
    model_id = "poltextlab/xlm-roberta-large-pooled-cap-minor-v3"
    return model_id

def predict(text, model_id, tokenizer_id):
    """Classify ``text`` into CAP minor topics and derive major-topic scores.

    Args:
        text: input passage to classify.
        model_id: Hub id of the sequence-classification model.
        tokenizer_id: Hub id of the tokenizer.

    Returns:
        Tuple of (minor-topic label->probability dict, major-topic
        label->probability dict, HTML snippet naming the model used).
    """
    device = torch.device("cpu")
    # NOTE(review): device_map="auto" may place weights off-CPU while the
    # inputs below are explicitly moved to CPU — confirm these agree in the
    # deployment environment.
    model = AutoModelForSequenceClassification.from_pretrained(model_id, low_cpu_mem_usage=True, device_map="auto", offload_folder="offload", token=HF_TOKEN)
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)

    inputs = tokenizer(text,
                       max_length=256,
                       truncation=True,
                       padding="do_not_pad",
                       return_tensors="pt").to(device)
    model.eval()

    with torch.no_grad():
        logits = model(**inputs).logits

    probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy().flatten()
    # Rank classes once, most probable first (the original computed this
    # argsort twice).
    ranked = np.argsort(probs)[::-1]
    output_pred_minor = {f"[{CAP_MIN_NUM_DICT[i]}] {CAP_MIN_LABEL_NAMES[CAP_MIN_NUM_DICT[i]]}": probs[i] for i in ranked}
    output_pred_major = convert_minor_to_major(ranked, probs)
    output_info = f'<p style="text-align: center; display: block">Prediction was made using the <a href="https://huggingface.co/{model_id}">{model_id}</a> model.</p>'
    return output_pred_minor, output_pred_major, output_info

def predict_cap(text, language, domain):
    """Gradio entry point: resolve the model for the UI selection and classify.

    Args:
        text: input passage to classify.
        language: UI language selection (currently unused downstream).
        domain: human-readable domain name; must be a key of ``domains``.

    Returns:
        The (minor, major, info) triple produced by ``predict``.
    """
    domain = domains[domain]
    model_id = build_huggingface_path(language, domain)
    tokenizer_id = "xlm-roberta-large"

    # Best-effort cache eviction when the disk is full. `-f` added to the
    # second command so a missing cache dir no longer prints an error,
    # matching the line above (os.system return codes are ignored anyway).
    if is_disk_full():
        os.system('rm -rf /data/models*')
        os.system('rm -rf ~/.cache/huggingface/hub')

    return predict(text, model_id, tokenizer_id)

# Wire the classifier into a simple Gradio UI: a free-text box plus
# language/domain dropdowns in; the top-5 minor topics, top-5 major topics,
# and a markdown note naming the model out. (demo.launch() is not visible
# in this chunk — presumably called elsewhere in the file.)
demo = gr.Interface(
    title="CAP Minor Topics Babel Demo",
    fn=predict_cap,
    inputs=[gr.Textbox(lines=6, label="Input"),
            gr.Dropdown(languages, label="Language"),
            gr.Dropdown(domains.keys(), label="Domain")],
    outputs=[gr.Label(num_top_classes=5, label="Output minor"), gr.Label(num_top_classes=5, label="Output major"), gr.Markdown()])