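# ILLFRAMES Babel Demo: a Gradio app (Hugging Face Space) that serves per-domain
# (Covid / Migration / War) xlm-roberta-large ILLFRAMES classifiers published under
# the poltextlab namespace, loading them as JIT-traced TorchScript checkpoints.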
import gradio as gr
import os
import torch
import numpy as np
from transformers import AutoTokenizer
from huggingface_hub import HfApi

from label_dicts import ILLFRAMES_MIGRATION_LABEL_NAMES, ILLFRAMES_COVID_LABEL_NAMES, ILLFRAMES_WAR_LABEL_NAMES

from .utils import is_disk_full, release_model

HF_TOKEN = os.environ["hf_read"]
languages = [
    "English"
]

domains = {
    "Covid": "covid",
    "Migration": "migration",
    "War": "war"
}
# --- DEBUG ---
import shutil

def convert_size(size):
    """Format a byte count as a human-readable string."""
    for unit in ['B', 'KB', 'MB', 'GB', 'TB', 'PB']:
        if size < 1024:
            return f"{size:.2f} {unit}"
        size /= 1024

def get_disk_space(path="/"):
    """Report total/used/free space for the filesystem at `path`."""
    total, used, free = shutil.disk_usage(path)
    return {
        "Total": convert_size(total),
        "Used": convert_size(used),
        "Free": convert_size(free)
    }
# ---
def check_huggingface_path(checkpoint_path: str):
    """Return True if `checkpoint_path` resolves to a model repo on the Hub."""
    try:
        hf_api = HfApi(token=HF_TOKEN)
        hf_api.model_info(checkpoint_path)
        return True
    except Exception:
        return False
def build_huggingface_path(domain: str):
    return f"poltextlab/xlm-roberta-large-english-ILLFRAMES-{domain}"
def predict(text, model_id, tokenizer_id, label_names):
    device = torch.device("cpu")

    # Load JIT-traced model
    jit_model_path = f"/data/jit_models/{model_id.replace('/', '_')}.pt"
    model = torch.jit.load(jit_model_path).to(device)
    model.eval()

    # Load tokenizer (still regular HF)
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)

    # Tokenize input
    inputs = tokenizer(
        text,
        max_length=64,
        truncation=True,
        padding=True,
        return_tensors="pt"
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        output = model(inputs["input_ids"], inputs["attention_mask"])
        print(output)  # debug
        logits = output["logits"]

    release_model(model, model_id)

    probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy().flatten()

    # Map class indices to sorted label keys, then rank predictions by probability
    NUMS_DICT = {i: key for i, key in enumerate(sorted(label_names.keys()))}
    output_pred = {f"[{NUMS_DICT[i]}] {label_names[NUMS_DICT[i]]}": probs[i] for i in np.argsort(probs)[::-1]}
    output_info = f'<p style="text-align: center; display: block">Prediction was made using the <a href="https://huggingface.co/{model_id}">{model_id}</a> model.</p>'
    return output_pred, output_info
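
# Dispatch: resolve the UI's domain choice to a model repo and label set, and
# free the Space's persistent storage before loading if it is running full.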
def predict_illframes(text, language, domain):
    domain = domains[domain]
    model_id = build_huggingface_path(domain)
    tokenizer_id = "xlm-roberta-large"

    if domain == "migration":
        label_names = ILLFRAMES_MIGRATION_LABEL_NAMES
    elif domain == "covid":
        label_names = ILLFRAMES_COVID_LABEL_NAMES
    elif domain == "war":
        label_names = ILLFRAMES_WAR_LABEL_NAMES

    if is_disk_full():
        os.system('rm -rf /data/models*')
        os.system('rm -r ~/.cache/huggingface/hub')

    return predict(text, model_id, tokenizer_id, label_names)
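
# Gradio UI: free-text input plus language/domain selectors; the Label widget
# shows the five highest-probability frames and the Markdown pane credits the model.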
demo = gr.Interface(
    title="ILLFRAMES Babel Demo",
    fn=predict_illframes,
    inputs=[gr.Textbox(lines=6, label="Input"),
            gr.Dropdown(languages, label="Language", value=languages[0]),
            gr.Dropdown(list(domains.keys()), label="Domain", value=list(domains.keys())[0])],
    outputs=[gr.Label(num_top_classes=5, label="Output"), gr.Markdown()])
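
# Entry point: a standard launch call, assuming this file is the Space's app.py.
if __name__ == "__main__":
    demo.launch()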