Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -44,16 +44,24 @@ def predict(text):
|
|
44 |
encoded = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=64)
|
45 |
out_sent, out_event, out_idx = model(encoded["input_ids"], encoded["attention_mask"])
|
46 |
|
|
|
47 |
sent_label = torch.argmax(out_sent, dim=1).item()
|
48 |
sent_result = sentiment_map[sent_label]
|
49 |
|
|
|
50 |
event_label = torch.argmax(out_event, dim=1).item()
|
51 |
event_result = event_map.get(event_label, "unknown")
|
52 |
|
|
|
53 |
idx_probs = torch.sigmoid(out_idx).squeeze(0)
|
54 |
-
idx_result = [
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
|
56 |
-
return sent_result, event_result, ", ".join(idx_result)
|
57 |
|
58 |
iface = gr.Interface(
|
59 |
fn=predict,
|
|
|
44 |
encoded = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=64)
|
45 |
out_sent, out_event, out_idx = model(encoded["input_ids"], encoded["attention_mask"])
|
46 |
|
47 |
+
# Sentiment
|
48 |
sent_label = torch.argmax(out_sent, dim=1).item()
|
49 |
sent_result = sentiment_map[sent_label]
|
50 |
|
51 |
+
# Event
|
52 |
event_label = torch.argmax(out_event, dim=1).item()
|
53 |
event_result = event_map.get(event_label, "unknown")
|
54 |
|
55 |
+
# IDX (multi-label + bullish/bearish status)
|
56 |
idx_probs = torch.sigmoid(out_idx).squeeze(0)
|
57 |
+
idx_result = []
|
58 |
+
|
59 |
+
for i, prob in enumerate(idx_probs):
|
60 |
+
status = "Bullish" if prob.item() > 0.5 else "Bearish"
|
61 |
+
idx_result.append(f"{idx_labels[i]}: {status} ({prob.item():.2f})")
|
62 |
+
|
63 |
+
return sent_result, event_result, "\n".join(idx_result)
|
64 |
|
|
|
65 |
|
66 |
iface = gr.Interface(
|
67 |
fn=predict,
|