File size: 3,457 Bytes
eb173d2
 
 
 
 
 
 
 
 
 
ae818da
eb173d2
9925a18
eb173d2
 
 
 
 
 
 
ae818da
 
eb173d2
 
261d20f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eb173d2
 
 
 
 
 
 
 
 
 
 
 
 
261d20f
61e3c30
261d20f
fa1a4e5
261d20f
 
 
95e10b6
c97614f
64f51f4
261d20f
eb173d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9c6d143
eb173d2
 
 
 
 
 
ae818da
eb173d2
ae818da
 
eb173d2
 
 
 
f7e1e22
40ba46f
eb173d2
 
 
f7e1e22
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import gradio as gr

import os
import torch
import numpy as np
import pandas as pd
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
from huggingface_hub import HfApi

from label_dicts import ILLFRAMES_MIGRATION_LABEL_NAMES, ILLFRAMES_COVID_LABEL_NAMES, ILLFRAMES_WAR_LABEL_NAMES

# Read-only Hub token taken from the ``hf_read`` environment variable
# (presumably configured as a Space secret — confirm). If it is missing,
# importing this module fails with KeyError.
HF_TOKEN = os.environ["hf_read"]

# Languages offered in the UI. Only English is listed, matching the
# "english" segment hard-coded in the model repository names.
languages = ["English"]

# UI domain label -> lowercase slug used in the model repository name.
domains = dict(Covid="covid", Migration="migration", War="war")


# --- DEBUG ---
import shutil

def convert_size(size):
    """Format a byte count as a human-readable string with two decimals.

    Units step by factors of 1024 from ``B`` up to ``PB``. Values of 1024 PB
    or more are still reported in ``PB``. Previously the loop ran out of
    units and the function implicitly returned ``None``.

    Args:
        size: Number of bytes (int or float).

    Returns:
        A string such as ``"1.50 KB"``.
    """
    units = ['B', 'KB', 'MB', 'GB', 'TB', 'PB']
    # Try every unit except the last; whatever remains is expressed in the
    # largest unit, so a string is always returned.
    for unit in units[:-1]:
        if size < 1024:
            return f"{size:.2f} {unit}"
        size /= 1024
    return f"{size:.2f} {units[-1]}"

def get_disk_space(path="/"):
    """Return human-readable total/used/free space of the filesystem at *path*.

    Each value is formatted with ``convert_size``. The keys are ``"Total"``,
    ``"Used"`` and ``"Free"``, in that order. Errors from
    ``shutil.disk_usage`` (e.g. a non-existent *path*) propagate.
    """
    usage = shutil.disk_usage(path)
    field_labels = ("Total", "Used", "Free")
    return {label: convert_size(amount) for label, amount in zip(field_labels, usage)}

# ---

def check_huggingface_path(checkpoint_path: str):
    """Return whether *checkpoint_path* can be queried as a model on the Hub.

    Any exception raised while fetching the model info (missing repo, no
    access, network failure, ...) is reported as ``False``. ``KeyboardInterrupt``
    and ``SystemExit`` are no longer swallowed.

    Args:
        checkpoint_path: Hub repository id, e.g. ``"org/model-name"``.

    Returns:
        ``True`` if ``HfApi.model_info`` succeeded, ``False`` otherwise.
    """
    try:
        hf_api = HfApi(token=HF_TOKEN)
        hf_api.model_info(checkpoint_path, token=HF_TOKEN)
    except Exception:
        return False
    return True

def build_huggingface_path(domain: str):
    """Return the Hub repository id of the English ILLFRAMES model for *domain*.

    *domain* is the lowercase slug (e.g. ``"covid"``), not the UI label.
    """
    organization = "poltextlab"
    repo_name = f"xlm-roberta-large-english-ILLFRAMES-{domain}"
    return f"{organization}/{repo_name}"

def _load_model(model_id):
    """Load the sequence-classification model *model_id*, retrying once.

    If the first attempt fails for any reason:
    - the error and the disk usage of ``/data/`` are printed;
    - the entire ``/data`` directory is deleted (presumably a stale or full
      persistent HF cache — TODO confirm);
    - the model is downloaded again with ``force_download=True``.

    An error from the retry propagates to the caller.
    """
    load_kwargs = {
        "low_cpu_mem_usage": True,
        "offload_folder": "offload",
        "device_map": "auto",
        "token": HF_TOKEN,
    }
    try:
        return AutoModelForSequenceClassification.from_pretrained(model_id, **load_kwargs)
    except Exception as err:
        # Report the actual cause: not every load failure is a disk problem.
        print(f"Failed to load {model_id}: {err!r}")
        try:
            disk_space = get_disk_space('/data/')
        except OSError as disk_err:
            # /data may not exist; don't let this mask the load error or
            # prevent the retry below.
            print(f"Could not read disk usage of /data/: {disk_err!r}")
        else:
            print("Disk Space Error:")
            for key, value in disk_space.items():
                print(f"{key}: {value}")

        shutil.rmtree("/data", ignore_errors=True)
        # Retry with the same options as the first attempt (including
        # offload_folder) so the retry cannot fail for a different reason.
        return AutoModelForSequenceClassification.from_pretrained(
            model_id, force_download=True, **load_kwargs
        )


def _rank_labels(probs, label_names):
    """Map class probabilities to ``"[code] name"`` labels, highest first.

    Class index ``i`` is assumed to correspond to the ``i``-th key of
    *label_names* in sorted order. NOTE(review): confirm this matches the
    model's ``id2label`` ordering.
    """
    index_to_code = dict(enumerate(sorted(label_names)))
    # float(): plain Python floats rather than numpy scalars for the UI.
    return {
        f"[{index_to_code[i]}] {label_names[index_to_code[i]]}": float(probs[i])
        for i in np.argsort(probs)[::-1]
    }


def predict(text, model_id, tokenizer_id, label_names):
    """Classify *text* and return ranked label probabilities plus an info note.

    Args:
        text: Text to classify; truncated to 256 tokens.
        model_id: Hugging Face Hub id of the sequence-classification model.
        tokenizer_id: Hub id of the tokenizer to use.
        label_names: Mapping from label code to human-readable name.

    Returns:
        ``(output_pred, output_info)``:
        - ``output_pred`` maps ``"[code] name"`` to its softmax probability,
          ordered from highest to lowest.
        - ``output_info`` is an HTML paragraph linking to the model page.
    """
    device = torch.device("cpu")
    model = _load_model(model_id)
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)

    # NOTE(review): the model is loaded with device_map="auto" while inputs
    # are pinned to CPU; on a GPU host these could end up on different devices.
    inputs = tokenizer(text,
                       max_length=256,
                       truncation=True,
                       padding="do_not_pad",
                       return_tensors="pt").to(device)
    model.eval()

    with torch.no_grad():
        logits = model(**inputs).logits

    # Softmax over the class axis; a single text was tokenized, so flatten
    # to a 1-D array of per-class probabilities.
    probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy().flatten()

    output_pred = _rank_labels(probs, label_names)
    output_info = f'<p style="text-align: center; display: block">Prediction was made using the <a href="https://huggingface.co/{model_id}">{model_id}</a> model.</p>'
    return output_pred, output_info

def predict_illframes(text, language, domain):
    """Classify *text* with the ILLFRAMES model for the selected domain.

    Args:
        text: Input text from the UI textbox.
        language: Selected language. Currently unused; the model path is
            always built by ``build_huggingface_path``, which hard-codes English.
        domain: UI domain label, expected to be a key of ``domains``.

    Returns:
        The ``(label -> probability dict, HTML info string)`` pair from ``predict``.

    Raises:
        gr.Error: If no domain or an unknown domain was selected. Previously
            this surfaced as a raw ``KeyError``.
    """
    if domain not in domains:
        raise gr.Error("Please select a valid domain.")

    domain_slug = domains[domain]
    # Dispatch table instead of an if/elif chain, so label_names can never be
    # left unbound.
    label_names = {
        "migration": ILLFRAMES_MIGRATION_LABEL_NAMES,
        "covid": ILLFRAMES_COVID_LABEL_NAMES,
        "war": ILLFRAMES_WAR_LABEL_NAMES,
    }[domain_slug]

    model_id = build_huggingface_path(domain_slug)
    tokenizer_id = "xlm-roberta-large"
    return predict(text, model_id, tokenizer_id, label_names)

# Gradio UI with three inputs (text, language, domain) and two outputs:
# the top-5 frame probabilities and a note naming the model used.
demo = gr.Interface(
    title="ILLFRAMES Babel Demo",
    fn=predict_illframes,
    inputs=[
        gr.Textbox(lines=6, label="Input"),
        # Pre-select the first (currently only) language so users don't have to.
        gr.Dropdown(languages, value=languages[0], label="Language"),
        # Pass a real list of choices rather than a dict_keys view.
        gr.Dropdown(list(domains), label="Domain"),
    ],
    outputs=[gr.Label(num_top_classes=5, label="Output"), gr.Markdown()],
)