# Hugging Face Spaces app (page status when this source was captured: "Sleeping").
import gradio as gr
import torch
from transformers import AutoModel, AutoTokenizer

from model import IndoBERTMultitask

# Load tokenizer from the Hub repo "plipustel/idxstockBERT".
# NOTE(review): the tokenizer and the base encoder below come from different
# repos; presumably they share the IndoBERT vocabulary — confirm.
tokenizer = AutoTokenizer.from_pretrained("plipustel/idxstockBERT")

# Load model: the IndoBERT encoder wrapped in the project's multitask module,
# then overwritten with the fine-tuned weights stored in pytorch_model.bin.
base_model = AutoModel.from_pretrained("indobenchmark/indobert-base-p1")
# 43 must match the number of entries in event_map.
model = IndoBERTMultitask(base_model=base_model, num_event_labels=43)
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint cannot execute arbitrary code while loading
# (requires torch >= 1.13).
state_dict = torch.load("pytorch_model.bin", map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
model.eval()  # evaluation mode for inference (e.g. dropout disabled)
# Label mapping
# Sentiment head output index -> label. Indonesian: "Negatif" = negative,
# "Positif" = positive, "Netral" = neutral. Note the index order: 1 is
# positive and 2 is neutral.
sentiment_map = dict(enumerate(["Negatif", "Positif", "Netral"]))

# Event mapping: event head output index -> event type, taken in list order (0..42).
# Update this list to match the label mapping used in the training dataset;
# its length (43) must equal the num_event_labels the model was built with.
event_map = dict(
    enumerate(
        [
            "corporate action", "divestment", "ipo", "regulation", "insider",
            "market outlook", "forex", "commodity", "finance report", "debt",
            "macroeconomic", "acquisition", "merger", "product launch",
            "litigation", "rating", "investment", "employment", "governance",
            "funding", "competition", "environment", "infrastructure",
            "trade", "tax", "technology", "sector update", "bankruptcy",
            "security", "restructuring", "sustainability", "ownership",
            "inflation", "interest rate", "monetary policy", "digital economy",
            "consumer behavior", "political", "climate", "energy",
            "supply chain", "geopolitical", "misc",
        ]
    )
)

# IDX sector labels, in the order of the model's multi-label IDX output.
idx_labels = [
    "idx_energy",
    "idx_basic",
    "idx_indust",
    "idx_noncyc",
    "idx_cyclic",
    "idx_health",
    "idx_finance",
    "idx_propert",
    "idx_techno",
    "idx_infra",
    "idx_trans",
]
# Inference function
def predict(text):
    """Classify one stock-news headline with the multitask model.

    Args:
        text: Headline string from the Gradio textbox.

    Returns:
        Tuple ``(sentiment, event_type, idx_sectors)`` of display strings.
        ``sentiment`` and ``event_type`` fall back to ``"unknown"`` for an
        unmapped class index. ``idx_sectors`` is a comma-separated list of
        every IDX sector whose sigmoid probability exceeds 0.5, or an empty
        string when none does. Blank or missing input returns
        ``("", "", "")`` without running the model.
    """
    # Nothing to classify: skip the tokenizer/model entirely.
    if not text or not text.strip():
        return "", "", ""

    with torch.no_grad():
        encoded = tokenizer(
            text, return_tensors="pt", truncation=True, padding=True, max_length=64
        )
        out_sent, out_event, out_idx = model(
            encoded["input_ids"], encoded["attention_mask"]
        )

    # Sentiment: argmax over the class dimension of the single-item batch.
    # .get() degrades to "unknown" instead of raising KeyError, matching the
    # event lookup below.
    sent_label = torch.argmax(out_sent, dim=1).item()
    sent_result = sentiment_map.get(sent_label, "unknown")

    # Event type
    event_label = torch.argmax(out_event, dim=1).item()
    event_result = event_map.get(event_label, "unknown")

    # IDX sectors are multi-label: an independent sigmoid per sector,
    # thresholded at 0.5. zip pairs each label with its probability without
    # indexing past the end of a shorter output.
    idx_probs = torch.sigmoid(out_idx).squeeze(0)
    idx_result = [label for label, prob in zip(idx_labels, idx_probs) if prob > 0.5]
    return sent_result, event_result, ", ".join(idx_result)
# Gradio UI: one headline textbox in; sentiment, event type and IDX sectors out.
# (Indonesian UI text: the placeholder reads "Enter a stock news headline...";
# the description says "Enter 1 stock news headline. The model will predict
# the sentiment, event type, and related IDX sectors.")
iface = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=2, placeholder="Masukkan headline berita saham..."),
    outputs=[
        gr.Textbox(label="Sentiment"),
        gr.Textbox(label="Event Type"),
        gr.Textbox(label="IDX Sectors"),
    ],
    title="IndoBERT Multitask: Berita Saham",
    description="Masukkan 1 headline berita saham. Model akan memprediksi sentimen, tipe event, dan sektor IDX terkait.",
)

# Start the web server only when run as a script (e.g. `python app.py`), so
# importing this module for tests or notebooks does not block on launch().
if __name__ == "__main__":
    iface.launch()