"""Gradio demo for an IndoBERT multitask model on Indonesian stock-news headlines.

Loads the tokenizer and fine-tuned weights from the ``plipustel/idxstockBERT``
Hugging Face Hub repo, builds the custom ``IndoBERTMultitask`` model on top of
the ``indobenchmark/indobert-base-p1`` encoder, and serves a single-textbox
interface that predicts, for one headline:

* sentiment (3 classes),
* event type (one of the classes in ``event_map``), and
* a Bullish/Bearish status with probability for each IDX sector in
  ``idx_labels``.
"""

import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoModel, AutoTokenizer

from model import IndoBERTMultitask  # project-local multitask model definition

# ---------------------------------------------------------------------------
# Label mappings
# ---------------------------------------------------------------------------
# Indices must line up with the label encoding used when the model was trained,
# so entries are written with explicit indices. Near-duplicates such as
# "IPO"/"ipo" and "macro economy"/"macroeconomy" are distinct entries here and
# are deliberately left as-is.
sentiment_map = {0: "Negative", 1: "Netral", 2: "Positive"}

event_map = {
    0: "IPO",
    1: "banking",
    2: "banking sector",
    3: "commodity",
    4: "commodity + stock recommendation",
    5: "company scandal",
    6: "corporate action",
    7: "corporate appointment",
    8: "corporate expansion",
    9: "corporate finance",
    10: "corporate outlook",
    11: "corporate outlook + stock recommendation",
    12: "corporate performance",
    13: "corporate rebranding",
    14: "corporate strategy",
    15: "crypto",
    16: "currency",
    17: "delisting",
    18: "foreign activity",
    19: "global market",
    20: "government policy",
    21: "insider activity",
    22: "interest rate",
    23: "investment product",
    24: "investment strategy",
    25: "ipo",
    26: "logistics expansion",
    27: "macro economy",
    28: "macroeconomy",
    29: "market index",
    30: "market index + foreign activity",
    31: "market index + stock recommendation",
    32: "market outlook",
    33: "market outlook + stock recommendation",
    34: "merger & acquisition",
    35: "politics",
    36: "private placement",
    37: "property",
    38: "regulation",
    39: "sector outlook",
    40: "sector rotation",
    41: "stock recommendation",
    42: "trading activity",
}

# IDX sector names; presumably in the same order as the model's IDX output
# columns (NOTE(review): confirm against the training code).
idx_labels = [
    "idx_energy",
    "idx_basic",
    "idx_indust",
    "idx_noncyc",
    "idx_cyclic",
    "idx_health",
    "idx_finance",
    "idx_propert",
    "idx_techno",
    "idx_infra",
    "idx_trans",
]

# A sector is reported "Bullish" when its sigmoid probability exceeds this.
_IDX_BULLISH_THRESHOLD = 0.5

# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------
# Tokenizer published alongside the fine-tuned weights on the Hugging Face Hub.
tokenizer = AutoTokenizer.from_pretrained("plipustel/idxstockBERT")

# Base IndoBERT encoder handed to the custom multitask model. The event head
# size is derived from event_map so the two cannot drift apart (43 classes).
base_model = AutoModel.from_pretrained("indobenchmark/indobert-base-p1")
model = IndoBERTMultitask(base_model=base_model, num_event_labels=len(event_map))

# Download the fine-tuned state dict and load it onto the CPU.
model_path = hf_hub_download(repo_id="plipustel/idxstockBERT", filename="pytorch_model.bin")
# NOTE(security): torch.load unpickles the file, and this checkpoint comes from
# a remote repository. Consider torch.load(..., weights_only=True) (available
# since torch 1.13) so only tensors and plain containers can be deserialized.
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model.eval()


def predict(text):
    """Classify a single stock-news headline.

    Args:
        text: The headline entered in the Gradio textbox.

    Returns:
        A ``(sentiment, event_type, idx_sectors)`` tuple of strings, one per
        output textbox. ``idx_sectors`` contains one
        ``"<sector>: <Bullish|Bearish> (<probability>)"`` line per IDX sector.
        A blank (empty, whitespace-only, or None) headline returns three empty
        strings without running the model.
    """
    # Nothing to classify: skip the model rather than report labels for an
    # empty input (and avoid passing None to the tokenizer).
    if not text or not text.strip():
        return "", "", ""

    with torch.no_grad():
        # One headline per call, truncated to 64 tokens.
        encoded = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=64,
        )
        out_sent, out_event, out_idx = model(encoded["input_ids"], encoded["attention_mask"])

    # Sentiment: single-label; take the highest-scoring class.
    sent_label = torch.argmax(out_sent, dim=1).item()
    sent_result = sentiment_map[sent_label]

    # Event type: single-label; fall back to "unknown" for unmapped indices.
    event_label = torch.argmax(out_event, dim=1).item()
    event_result = event_map.get(event_label, "unknown")

    # IDX sectors: multi-label; each sector gets an independent sigmoid
    # probability, thresholded into a Bullish/Bearish status. squeeze(0) drops
    # the batch dimension of the single-headline batch.
    # NOTE(review): "Bearish" here only means probability <= threshold; confirm
    # that this is the intended reading of a low score.
    idx_probs = torch.sigmoid(out_idx).squeeze(0).tolist()
    idx_lines = []
    for label, prob in zip(idx_labels, idx_probs):
        status = "Bullish" if prob > _IDX_BULLISH_THRESHOLD else "Bearish"
        idx_lines.append(f"{label}: {status} ({prob:.2f})")

    return sent_result, event_result, "\n".join(idx_lines)


# ---------------------------------------------------------------------------
# Gradio UI (user-facing text is Indonesian: "Enter a stock news headline...")
# ---------------------------------------------------------------------------
iface = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=2, placeholder="Masukkan headline berita saham..."),
    outputs=[
        gr.Textbox(label="Sentiment"),
        gr.Textbox(label="Event Type"),
        gr.Textbox(label="IDX Sectors"),
    ],
    title="IndoBERT Multitask: Berita Saham",
    description="Masukkan 1 headline berita saham. Model akan memprediksi sentimen, tipe event, dan sektor IDX terkait.",
)

# Launch only when run as a script (e.g. `python app.py`), not on import.
if __name__ == "__main__":
    iface.launch()