# Hugging Face Spaces -- status shown at capture time: Sleeping
import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoModel, AutoTokenizer

from model import IndoBERTMultitask
# Hugging Face Hub identifiers, named once so the tokenizer and the weight
# download cannot drift apart.
FINETUNED_REPO_ID = "plipustel/idxstockBERT"
BASE_MODEL_ID = "indobenchmark/indobert-base-p1"

# Load the tokenizer from the Hugging Face Hub.
tokenizer = AutoTokenizer.from_pretrained(FINETUNED_REPO_ID)

# Load the base IndoBERT encoder.
base_model = AutoModel.from_pretrained(BASE_MODEL_ID)

# Wrap the encoder in the custom multitask model. 43 is the number of entries
# in event_map (ids 0-42); keep the two in sync.
model = IndoBERTMultitask(base_model=base_model, num_event_labels=43)

# Download the fine-tuned weights from the Hugging Face Hub.
model_path = hf_hub_download(repo_id=FINETUNED_REPO_ID, filename="pytorch_model.bin")

# SECURITY: a ".bin" checkpoint is a pickle file fetched over the network.
# weights_only=True restricts unpickling to tensors and plain containers so a
# tampered file cannot execute arbitrary code on load.
# NOTE(review): requires a torch release that supports weights_only -- confirm
# against the pinned torch version.
state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)

# The app only runs inference, so put the model in evaluation mode.
model.eval()
# Label mappings: class id -> display string. predict() looks these up with
# the argmax of the corresponding model output.

# Sentiment classes ("Netral" is the Indonesian spelling shown to users).
sentiment_map = {0: "Negative", 1: "Netral", 2: "Positive"}

# Event-type classes (43 entries, ids 0-42).
# NOTE(review): several ids look like spelling variants of the same class
# (0 "IPO" / 25 "ipo", 27 "macro economy" / 28 "macroeconomy",
# 1 "banking" / 2 "banking sector"). Presumably they mirror the label
# encoding used at training time, so they are left untouched -- confirm
# before merging or renumbering anything here.
event_map = {
    0: "IPO",
    1: "banking",
    2: "banking sector",
    3: "commodity",
    4: "commodity + stock recommendation",
    5: "company scandal",
    6: "corporate action",
    7: "corporate appointment",
    8: "corporate expansion",
    9: "corporate finance",
    10: "corporate outlook",
    11: "corporate outlook + stock recommendation",
    12: "corporate performance",
    13: "corporate rebranding",
    14: "corporate strategy",
    15: "crypto",
    16: "currency",
    17: "delisting",
    18: "foreign activity",
    19: "global market",
    20: "government policy",
    21: "insider activity",
    22: "interest rate",
    23: "investment product",
    24: "investment strategy",
    25: "ipo",
    26: "logistics expansion",
    27: "macro economy",
    28: "macroeconomy",
    29: "market index",
    30: "market index + foreign activity",
    31: "market index + stock recommendation",
    32: "market outlook",
    33: "market outlook + stock recommendation",
    34: "merger & acquisition",
    35: "politics",
    36: "private placement",
    37: "property",
    38: "regulation",
    39: "sector outlook",
    40: "sector rotation",
    41: "stock recommendation",
    42: "trading activity",
}

# IDX sector labels. predict() pairs these positionally with the entries of
# the model's third output, so the order here is significant.
idx_labels = [
    "idx_energy", "idx_basic", "idx_indust", "idx_noncyc", "idx_cyclic",
    "idx_health", "idx_finance", "idx_propert", "idx_techno", "idx_infra", "idx_trans",
]
def predict(text: str):
    """Classify one stock-news headline with the multitask model.

    Args:
        text: The headline typed into the UI textbox.

    Returns:
        A tuple of three strings:
        the sentiment label from ``sentiment_map``; the event label from
        ``event_map`` (``"unknown"`` if the predicted id is not in the map);
        and one line per IDX label of the form
        ``"<label>: <Bullish|Bearish> (<probability>)"``, joined by newlines.
        Empty or whitespace-only input yields three empty strings.
    """
    # Nothing to classify: return blanks rather than running the model on an
    # empty sequence and showing a confident-looking but meaningless result.
    if not text or not text.strip():
        return "", "", ""

    # Inference only -- no autograd bookkeeping needed.
    with torch.no_grad():
        encoded = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=64)
        # Assumes the model returns (sentiment, event, IDX) outputs, each with
        # a leading batch dimension of 1 -- TODO confirm against IndoBERTMultitask.
        out_sent, out_event, out_idx = model(encoded["input_ids"], encoded["attention_mask"])

    # Sentiment: highest-scoring class id -> label.
    sent_result = sentiment_map[torch.argmax(out_sent, dim=1).item()]

    # Event: highest-scoring class id -> label, with a fallback for unmapped ids.
    event_result = event_map.get(torch.argmax(out_event, dim=1).item(), "unknown")

    # IDX (multi-label): independent sigmoid per label; > 0.5 is reported as
    # "Bullish", otherwise "Bearish". Converted to plain floats once.
    idx_probs = torch.sigmoid(out_idx).squeeze(0).tolist()
    idx_result = [
        f"{label}: {'Bullish' if prob > 0.5 else 'Bearish'} ({prob:.2f})"
        for label, prob in zip(idx_labels, idx_probs)
    ]

    return sent_result, event_result, "\n".join(idx_result)
# Gradio UI: one headline textbox in, three text outputs that receive the
# three strings returned by predict(), in the same order.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=2, placeholder="Masukkan headline berita saham..."),
    outputs=[
        gr.Textbox(label="Sentiment"),
        gr.Textbox(label="Event Type"),
        gr.Textbox(label="IDX Sectors"),
    ],
    title="IndoBERT Multitask: Berita Saham",
    description="Masukkan 1 headline berita saham. Model akan memprediksi sentimen, tipe event, dan sektor IDX terkait.",
)

# Start the web server when the script runs.
iface.launch()