plipustel commited on
Commit
c1076c4
·
verified ·
1 Parent(s): 29cd4c5

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -0
app.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoTokenizer, AutoModel
3
+ from model import IndoBERTMultitask
4
+ import gradio as gr
5
+
6
+ # Load tokenizer
7
+ tokenizer = AutoTokenizer.from_pretrained("plipustel/idxstockBERT")
8
+
9
+ # Load model
10
+ base_model = AutoModel.from_pretrained("indobenchmark/indobert-base-p1")
11
+ model = IndoBERTMultitask(base_model=base_model, num_event_labels=43)
12
+ model.load_state_dict(torch.load("pytorch_model.bin", map_location="cpu"))
13
+ model.eval()
14
+
15
+ # Label mapping
16
+ sentiment_map = {0: "Negatif", 1: "Positif", 2: "Netral"}
17
+
18
+ # Event mapping (ganti sesuai mapping di dataset Tuan)
19
+ event_map = {
20
+ 0: "corporate action", 1: "divestment", 2: "ipo", 3: "regulation", 4: "insider",
21
+ 5: "market outlook", 6: "forex", 7: "commodity", 8: "finance report", 9: "debt",
22
+ 10: "macroeconomic", 11: "acquisition", 12: "merger", 13: "product launch",
23
+ 14: "litigation", 15: "rating", 16: "investment", 17: "employment", 18: "governance",
24
+ 19: "funding", 20: "competition", 21: "environment", 22: "infrastructure",
25
+ 23: "trade", 24: "tax", 25: "technology", 26: "sector update", 27: "bankruptcy",
26
+ 28: "security", 29: "restructuring", 30: "sustainability", 31: "ownership",
27
+ 32: "inflation", 33: "interest rate", 34: "monetary policy", 35: "digital economy",
28
+ 36: "consumer behavior", 37: "political", 38: "climate", 39: "energy",
29
+ 40: "supply chain", 41: "geopolitical", 42: "misc"
30
+ }
31
+
32
+ # IDX sector labels
33
+ idx_labels = [
34
+ "idx_energy", "idx_basic", "idx_indust", "idx_noncyc", "idx_cyclic",
35
+ "idx_health", "idx_finance", "idx_propert", "idx_techno", "idx_infra", "idx_trans"
36
+ ]
37
+
38
+ # Inference function
39
+ def predict(text):
40
+ with torch.no_grad():
41
+ encoded = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=64)
42
+ out_sent, out_event, out_idx = model(encoded["input_ids"], encoded["attention_mask"])
43
+
44
+ # Sentiment
45
+ sent_label = torch.argmax(out_sent, dim=1).item()
46
+ sent_result = sentiment_map[sent_label]
47
+
48
+ # Event
49
+ event_label = torch.argmax(out_event, dim=1).item()
50
+ event_result = event_map.get(event_label, "unknown")
51
+
52
+ # IDX (multi-label)
53
+ idx_probs = torch.sigmoid(out_idx).squeeze(0)
54
+ idx_result = [label for i, label in enumerate(idx_labels) if idx_probs[i] > 0.5]
55
+
56
+ return sent_result, event_result, ", ".join(idx_result)
57
+
58
+ # Gradio UI
59
+ iface = gr.Interface(
60
+ fn=predict,
61
+ inputs=gr.Textbox(lines=2, placeholder="Masukkan headline berita saham..."),
62
+ outputs=[
63
+ gr.Textbox(label="Sentiment"),
64
+ gr.Textbox(label="Event Type"),
65
+ gr.Textbox(label="IDX Sectors"),
66
+ ],
67
+ title="IndoBERT Multitask: Berita Saham",
68
+ description="Masukkan 1 headline berita saham. Model akan memprediksi sentimen, tipe event, dan sektor IDX terkait."
69
+ )
70
+
71
+ iface.launch()