cta2106 commited on
Commit
094951f
·
1 Parent(s): 8ccfa53

added hawkishness score output

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -1,16 +1,16 @@
1
- from typing import Any, Dict
2
 
3
  from transformers import pipeline, LongformerForSequenceClassification, LongformerTokenizer, Trainer
4
  import gradio as gr
5
 
6
 
7
- def predict_fn(text: str) -> Dict[str, Any]:
8
  model = LongformerForSequenceClassification.from_pretrained("model")
9
  tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
10
  p = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer)
11
  results = p(text)
12
  factor = 100 if results[0]['label'] == 'Hawkish' else -100
13
- return {"label": results[0]['label'], "hawkishness_score": round(results[0]['score'] * factor, 0)}
14
 
15
 
16
- gr.Interface(predict_fn, "textbox", "label").launch()
 
1
+ from typing import Any
2
 
3
  from transformers import pipeline, LongformerForSequenceClassification, LongformerTokenizer, Trainer
4
  import gradio as gr
5
 
6
 
7
+ def predict_fn(text: str) -> tuple[Any, Any]:
8
  model = LongformerForSequenceClassification.from_pretrained("model")
9
  tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
10
  p = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer)
11
  results = p(text)
12
  factor = 100 if results[0]['label'] == 'Hawkish' else -100
13
+ return results[0]['label'], round(results[0]['score'] * factor, 0)
14
 
15
 
16
+ gr.Interface(predict_fn, "textbox", ["label", "label"]).launch()