File size: 672 Bytes
8ccfa53
9fbf14c
38d4932
 
 
 
8ccfa53
38d4932
 
 
9fbf14c
 
 
38d4932
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from functools import lru_cache
from typing import Any, Dict

import gradio as gr
from transformers import pipeline, LongformerForSequenceClassification, LongformerTokenizer, Trainer


@lru_cache(maxsize=1)
def _get_pipeline():
    """Build the classification pipeline once and reuse it for every request.

    Loading the fine-tuned weights from ``"model"`` and the
    ``allenai/longformer-base-4096`` tokenizer is expensive. The pipeline is
    therefore constructed on first use and cached. If construction raises,
    nothing is cached and the next call retries.
    """
    model = LongformerForSequenceClassification.from_pretrained("model")
    tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
    return pipeline("sentiment-analysis", model=model, tokenizer=tokenizer)


def predict_fn(text: str) -> Dict[str, Any]:
    """Classify *text* and return its top label with a signed hawkishness score.

    Args:
        text: The passage to classify.

    Returns:
        A dict with two keys:

        - ``"label"``: the model's top predicted label.
        - ``"hawkishness_score"``: that label's score multiplied by 100,
          rounded to a whole number (returned as a float). It is positive
          when the label is ``'Hawkish'`` and negative for any other label.
    """
    # truncation=True keeps over-long passages within the tokenizer's
    # maximum length instead of failing inside the model.
    results = _get_pipeline()(text, truncation=True)
    top = results[0]  # a single input string yields a one-element list
    factor = 100 if top["label"] == "Hawkish" else -100
    return {"label": top["label"], "hawkishness_score": round(top["score"] * factor, 0)}


# Wire the classifier into a simple text-in / label-out web UI and start it.
# NOTE(review): gr.Label usually renders a dict as {label: confidence};
# predict_fn returns a string under "label" - confirm this displays as intended.
demo = gr.Interface(fn=predict_fn, inputs="textbox", outputs="label")
demo.launch()