plipustel commited on
Commit
b314282
·
verified ·
1 Parent(s): 1d21cf4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -2
app.py CHANGED
@@ -44,16 +44,24 @@ def predict(text):
44
  encoded = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=64)
45
  out_sent, out_event, out_idx = model(encoded["input_ids"], encoded["attention_mask"])
46
 
 
47
  sent_label = torch.argmax(out_sent, dim=1).item()
48
  sent_result = sentiment_map[sent_label]
49
 
 
50
  event_label = torch.argmax(out_event, dim=1).item()
51
  event_result = event_map.get(event_label, "unknown")
52
 
 
53
  idx_probs = torch.sigmoid(out_idx).squeeze(0)
54
- idx_result = [label for i, label in enumerate(idx_labels) if idx_probs[i] > 0.5]
 
 
 
 
 
 
55
 
56
- return sent_result, event_result, ", ".join(idx_result)
57
 
58
  iface = gr.Interface(
59
  fn=predict,
 
44
  encoded = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=64)
45
  out_sent, out_event, out_idx = model(encoded["input_ids"], encoded["attention_mask"])
46
 
47
+ # Sentiment
48
  sent_label = torch.argmax(out_sent, dim=1).item()
49
  sent_result = sentiment_map[sent_label]
50
 
51
+ # Event
52
  event_label = torch.argmax(out_event, dim=1).item()
53
  event_result = event_map.get(event_label, "unknown")
54
 
55
+ # IDX (multi-label + bullish/bearish status)
56
  idx_probs = torch.sigmoid(out_idx).squeeze(0)
57
+ idx_result = []
58
+
59
+ for i, prob in enumerate(idx_probs):
60
+ status = "Bullish" if prob.item() > 0.5 else "Bearish"
61
+ idx_result.append(f"{idx_labels[i]}: {status} ({prob.item():.2f})")
62
+
63
+ return sent_result, event_result, "\n".join(idx_result)
64
 
 
65
 
66
  iface = gr.Interface(
67
  fn=predict,