# idxstockBERT / app.py
# Committed by plipustel, commit c1076c4 ("Create app.py"), 2.82 kB.
import torch
from transformers import AutoTokenizer, AutoModel
from model import IndoBERTMultitask
import gradio as gr
# Load the tokenizer published with the fine-tuned checkpoint.
tokenizer = AutoTokenizer.from_pretrained("plipustel/idxstockBERT")

# Build the multitask model on top of the pretrained IndoBERT encoder, then
# overwrite its weights with the fine-tuned checkpoint "pytorch_model.bin"
# (a relative path, resolved against the current working directory).
base_model = AutoModel.from_pretrained("indobenchmark/indobert-base-p1")
# num_event_labels has to agree with the number of entries in event_map (43).
model = IndoBERTMultitask(base_model=base_model, num_event_labels=43)
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint cannot execute arbitrary code while loading (PyTorch
# defaulted to weights_only=False before 2.6). The state dict is passed inline
# so no extra module-level reference keeps a second copy of the weights alive.
model.load_state_dict(
    torch.load("pytorch_model.bin", map_location="cpu", weights_only=True)
)
# Switch to evaluation mode for inference (e.g. disables dropout).
model.eval()
# Sentiment head: class index -> label ("Negatif" / "Positif" / "Netral" are
# Indonesian for negative / positive / neutral).
sentiment_map = dict(enumerate(["Negatif", "Positif", "Netral"]))

# Event-type head: class index -> label. The position of each name is its class
# index, so this order must match the label encoding used for the training
# dataset; update it if that mapping changes.
event_map = dict(enumerate([
    "corporate action", "divestment", "ipo", "regulation", "insider",
    "market outlook", "forex", "commodity", "finance report", "debt",
    "macroeconomic", "acquisition", "merger", "product launch",
    "litigation", "rating", "investment", "employment", "governance",
    "funding", "competition", "environment", "infrastructure",
    "trade", "tax", "technology", "sector update", "bankruptcy",
    "security", "restructuring", "sustainability", "ownership",
    "inflation", "interest rate", "monetary policy", "digital economy",
    "consumer behavior", "political", "climate", "energy",
    "supply chain", "geopolitical", "misc",
]))

# IDX sector labels for the multi-label sector head. The prediction code indexes
# probabilities by position, so entry i must correspond to output unit i.
idx_labels = [
    "idx_energy",
    "idx_basic",
    "idx_indust",
    "idx_noncyc",
    "idx_cyclic",
    "idx_health",
    "idx_finance",
    "idx_propert",
    "idx_techno",
    "idx_infra",
    "idx_trans",
]
# Inference function
def predict(text: str, *, threshold: float = 0.5, max_length: int = 64):
    """Classify a single stock-news headline with the multitask model.

    Args:
        text: The headline to classify. ``None``, empty, or whitespace-only
            input returns three empty strings without running the model
            (otherwise the model would score a bare [CLS][SEP] sequence).
        threshold: A sector is reported when its sigmoid probability is
            strictly greater than this value.
        max_length: Token limit passed to the tokenizer; longer input is
            truncated.

    Returns:
        A ``(sentiment, event_type, sectors)`` tuple of strings. ``sectors``
        is a comma-separated list of IDX sector labels and is empty when no
        sector passes ``threshold``. A predicted class index missing from
        ``sentiment_map`` / ``event_map`` is reported as ``"unknown"``.
    """
    if text is None or not text.strip():
        return "", "", ""

    # Inference only: no autograd graph is needed. Tensors produced here do
    # not require grad, so the post-processing below is untracked as well.
    with torch.no_grad():
        encoded = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=max_length,
        )
        out_sent, out_event, out_idx = model(
            encoded["input_ids"], encoded["attention_mask"]
        )

    # Single-label heads: arg-max over classes. A single headline means a
    # batch of one, so .item() yields a Python int.
    sent_label = torch.argmax(out_sent, dim=1).item()
    event_label = torch.argmax(out_event, dim=1).item()
    # Both lookups use .get so an unmapped index cannot raise KeyError.
    sent_result = sentiment_map.get(sent_label, "unknown")
    event_result = event_map.get(event_label, "unknown")

    # Multi-label sector head: an independent sigmoid per sector; drop the
    # batch dimension and compare each probability against the threshold.
    idx_probs = torch.sigmoid(out_idx).squeeze(0).tolist()
    idx_result = [
        label for i, label in enumerate(idx_labels) if idx_probs[i] > threshold
    ]

    return sent_result, event_result, ", ".join(idx_result)
# Gradio UI: one headline textbox in, three textboxes out, in the same order as
# predict's (sentiment, event_type, sectors) return tuple. The user-facing
# strings are Indonesian (placeholder: "Enter a stock news headline...").
iface = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=2, placeholder="Masukkan headline berita saham..."),
    outputs=[
        gr.Textbox(label="Sentiment"),
        gr.Textbox(label="Event Type"),
        gr.Textbox(label="IDX Sectors"),
    ],
    title="IndoBERT Multitask: Berita Saham",
    description="Masukkan 1 headline berita saham. Model akan memprediksi sentimen, tipe event, dan sektor IDX terkait.",
)

# Start the server only when this file is executed as a script, so importing
# the module (e.g. from tests or another app) does not launch a blocking server.
if __name__ == "__main__":
    iface.launch()