Spaces:
Sleeping
Sleeping
File size: 3,510 Bytes
c1076c4 1df3627 c1076c4 1df3627 c1076c4 1df3627 c1076c4 1df3627 c1076c4 1df3627 c1076c4 b74cb14 c1076c4 b74cb14 c1076c4 b74cb14 c1076c4 b314282 c1076c4 b314282 c1076c4 b314282 c1076c4 b314282 c1076c4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
import torch
from transformers import AutoTokenizer, AutoModel
from huggingface_hub import hf_hub_download
from model import IndoBERTMultitask
import gradio as gr
# Load the tokenizer published with the fine-tuned checkpoint on the Hugging Face Hub.
tokenizer = AutoTokenizer.from_pretrained("plipustel/idxstockBERT")

# Load the base IndoBERT model that is handed to the multitask wrapper below.
base_model = AutoModel.from_pretrained("indobenchmark/indobert-base-p1")

# Build the custom multitask model; 43 event labels matches event_map (indices 0-42).
model = IndoBERTMultitask(base_model=base_model, num_event_labels=43)

# Download the fine-tuned weights from the Hugging Face Hub.
model_path = hf_hub_download(repo_id="plipustel/idxstockBERT", filename="pytorch_model.bin")

# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint cannot execute arbitrary code while it is being loaded.
# A state dict needs nothing beyond that, so loading is otherwise unchanged.
model.load_state_dict(
    torch.load(model_path, map_location="cpu", weights_only=True)
)

# Switch to evaluation mode; this app only runs inference.
model.eval()
# Lookup tables that turn model output indices into display strings.

# Sentiment class index -> label shown in the UI.
sentiment_map = dict(enumerate(("Negative", "Netral", "Positive")))

# Event class index -> event type. The position in this sequence IS the class
# index, so entries must never be reordered, merged or removed.
# NOTE(review): near-duplicate names ("IPO"/"ipo", "macro economy"/
# "macroeconomy", "banking"/"banking sector") are kept verbatim; presumably
# they mirror the label encoding used at training time - confirm before
# cleaning them up.
event_map = dict(enumerate((
    "IPO",                                       # 0
    "banking",
    "banking sector",
    "commodity",
    "commodity + stock recommendation",
    "company scandal",                           # 5
    "corporate action",
    "corporate appointment",
    "corporate expansion",
    "corporate finance",
    "corporate outlook",                         # 10
    "corporate outlook + stock recommendation",
    "corporate performance",
    "corporate rebranding",
    "corporate strategy",
    "crypto",                                    # 15
    "currency",
    "delisting",
    "foreign activity",
    "global market",
    "government policy",                         # 20
    "insider activity",
    "interest rate",
    "investment product",
    "investment strategy",
    "ipo",                                       # 25
    "logistics expansion",
    "macro economy",
    "macroeconomy",
    "market index",
    "market index + foreign activity",           # 30
    "market index + stock recommendation",
    "market outlook",
    "market outlook + stock recommendation",
    "merger & acquisition",
    "politics",                                  # 35
    "private placement",
    "property",
    "regulation",
    "sector outlook",
    "sector rotation",                           # 40
    "stock recommendation",
    "trading activity",
)))

# Names for the multi-label IDX output; predict() pairs entry i with the i-th
# IDX score, so the order here must match the model's output order.
# NOTE(review): these look like IDX sector index codes - confirm.
idx_labels = [
    "idx_energy",
    "idx_basic",
    "idx_indust",
    "idx_noncyc",
    "idx_cyclic",
    "idx_health",
    "idx_finance",
    "idx_propert",
    "idx_techno",
    "idx_infra",
    "idx_trans",
]
def predict(text):
    """Classify one stock-news headline with the multitask model.

    Args:
        text: Headline string coming from the UI. ``None``, empty or
            whitespace-only input is not sent to the model.

    Returns:
        A 3-tuple of strings ``(sentiment, event_type, idx_report)``.
        ``idx_report`` holds one line per entry of ``idx_labels`` in the form
        ``"<label>: <Bullish|Bearish> (<probability>)"``. For blank input all
        three strings are empty.
    """
    # Blank input carries no signal: skip the forward pass so the UI shows
    # empty outputs instead of an arbitrary prediction (or a tokenizer error
    # when the value is None).
    if text is None or not text.strip():
        return "", "", ""

    # Inference only, so no gradients are needed.
    with torch.no_grad():
        encoded = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=64)
        out_sent, out_event, out_idx = model(encoded["input_ids"], encoded["attention_mask"])

    # Sentiment: highest-scoring class, mapped to its display label.
    sent_label = torch.argmax(out_sent, dim=1).item()
    sent_result = sentiment_map[sent_label]

    # Event: highest-scoring class; indices missing from the map show "unknown".
    event_label = torch.argmax(out_event, dim=1).item()
    event_result = event_map.get(event_label, "unknown")

    # IDX (multi-label): independent sigmoid per label, thresholded at 0.5.
    # squeeze(0) drops the batch axis of the single-headline batch; tolist()
    # converts every score to a Python float in one call.
    idx_probs = torch.sigmoid(out_idx).squeeze(0).tolist()
    idx_result = []
    for label, prob in zip(idx_labels, idx_probs):
        status = "Bullish" if prob > 0.5 else "Bearish"
        idx_result.append(f"{label}: {status} ({prob:.2f})")

    return sent_result, event_result, "\n".join(idx_result)
# Gradio UI: one headline text box in, three text boxes out (sentiment,
# event type, per-label IDX report), all produced by predict().
iface = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=2, placeholder="Masukkan headline berita saham..."),
    outputs=[
        gr.Textbox(label="Sentiment"),
        gr.Textbox(label="Event Type"),
        gr.Textbox(label="IDX Sectors"),
    ],
    title="IndoBERT Multitask: Berita Saham",
    description="Masukkan 1 headline berita saham. Model akan memprediksi sentimen, tipe event, dan sektor IDX terkait.",
)

# Start the web server only when this file is executed as a script, so that
# importing the module (e.g. to reuse or test predict) does not launch a
# blocking server as a side effect.
if __name__ == "__main__":
    iface.launch()
|