Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from transformers import AutoTokenizer, AutoModel
|
3 |
+
from model import IndoBERTMultitask
|
4 |
+
import gradio as gr
|
5 |
+
|
6 |
+
# Load the tokenizer from the "plipustel/idxstockBERT" repository.
tokenizer = AutoTokenizer.from_pretrained("plipustel/idxstockBERT")

# Build the multitask model around the "indobenchmark/indobert-base-p1"
# encoder, then restore its weights from the local checkpoint file.
base_model = AutoModel.from_pretrained("indobenchmark/indobert-base-p1")
# num_event_labels must equal the number of entries in event_map (43).
model = IndoBERTMultitask(base_model=base_model, num_event_labels=43)
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint cannot execute arbitrary code while it is being loaded.
state_dict = torch.load("pytorch_model.bin", map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
# Put the model in evaluation mode; this script only runs inference.
model.eval()
|
14 |
+
|
15 |
+
# Sentiment class index -> label text returned to the user by predict().
sentiment_map = dict(enumerate(("Negatif", "Positif", "Netral")))
|
17 |
+
|
18 |
+
# Event class index -> event type name.
# NOTE: replace this with the mapping used by your own dataset if it differs;
# a name's position in the tuple is its class index.
_EVENT_NAMES = (
    "corporate action",
    "divestment",
    "ipo",
    "regulation",
    "insider",
    "market outlook",
    "forex",
    "commodity",
    "finance report",
    "debt",
    "macroeconomic",
    "acquisition",
    "merger",
    "product launch",
    "litigation",
    "rating",
    "investment",
    "employment",
    "governance",
    "funding",
    "competition",
    "environment",
    "infrastructure",
    "trade",
    "tax",
    "technology",
    "sector update",
    "bankruptcy",
    "security",
    "restructuring",
    "sustainability",
    "ownership",
    "inflation",
    "interest rate",
    "monetary policy",
    "digital economy",
    "consumer behavior",
    "political",
    "climate",
    "energy",
    "supply chain",
    "geopolitical",
    "misc",
)
event_map = dict(enumerate(_EVENT_NAMES))
|
31 |
+
|
32 |
+
# IDX sector labels. predict() pairs each label with the sector probability
# at the same index, so the order here is significant.
idx_labels = [
    "idx_energy",
    "idx_basic",
    "idx_indust",
    "idx_noncyc",
    "idx_cyclic",
    "idx_health",
    "idx_finance",
    "idx_propert",
    "idx_techno",
    "idx_infra",
    "idx_trans",
]
|
37 |
+
|
38 |
+
# Inference function
|
39 |
+
def predict(text):
    """Classify a single stock-news headline.

    Args:
        text: Headline string entered by the user.

    Returns:
        A 3-tuple of strings: the sentiment label, the event type
        ("unknown" when the predicted index is not in ``event_map``), and a
        comma-separated list of the IDX sector labels whose sigmoid
        probability exceeds 0.5 (an empty string when none do). Blank or
        whitespace-only input returns three empty strings without running
        the model.
    """
    # A blank textbox carries no signal; skip the model instead of presenting
    # a prediction for an empty sequence as if it were meaningful.
    if not text or not text.strip():
        return "", "", ""

    # Inference only, so gradient tracking is disabled for the forward pass.
    with torch.no_grad():
        encoded = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=64,
        )
        out_sent, out_event, out_idx = model(
            encoded["input_ids"], encoded["attention_mask"]
        )

    # Sentiment: single-label, take the highest-scoring class.
    sent_label = torch.argmax(out_sent, dim=1).item()
    sent_result = sentiment_map[sent_label]

    # Event: single-label, with a fallback for indices missing from the map.
    event_label = torch.argmax(out_event, dim=1).item()
    event_result = event_map.get(event_label, "unknown")

    # IDX sectors: multi-label, each sector is thresholded independently.
    idx_probs = torch.sigmoid(out_idx).squeeze(0)
    idx_result = [label for i, label in enumerate(idx_labels) if idx_probs[i] > 0.5]

    return sent_result, event_result, ", ".join(idx_result)
|
57 |
+
|
58 |
+
# Gradio UI: one free-text input and three text outputs, listed in the same
# order as the values returned by predict().
headline_input = gr.Textbox(lines=2, placeholder="Masukkan headline berita saham...")
result_outputs = [
    gr.Textbox(label=caption)
    for caption in ("Sentiment", "Event Type", "IDX Sectors")
]

iface = gr.Interface(
    fn=predict,
    inputs=headline_input,
    outputs=result_outputs,
    title="IndoBERT Multitask: Berita Saham",
    description="Masukkan 1 headline berita saham. Model akan memprediksi sentimen, tipe event, dan sektor IDX terkait.",
)

# Start the web app.
iface.launch()
|