plipustel committed on
Commit
1df3627
·
verified ·
1 Parent(s): c2a82aa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -10
app.py CHANGED
@@ -1,21 +1,26 @@
1
  import torch
2
  from transformers import AutoTokenizer, AutoModel
 
3
  from model import IndoBERTMultitask
4
  import gradio as gr
5
 
6
- # Load tokenizer
7
  tokenizer = AutoTokenizer.from_pretrained("plipustel/idxstockBERT")
8
 
9
- # Load model
10
  base_model = AutoModel.from_pretrained("indobenchmark/indobert-base-p1")
 
 
11
  model = IndoBERTMultitask(base_model=base_model, num_event_labels=43)
12
- model.load_state_dict(torch.load("pytorch_model.bin", map_location="cpu"))
 
 
 
13
  model.eval()
14
 
15
  # Label mapping
16
  sentiment_map = {0: "Negatif", 1: "Positif", 2: "Netral"}
17
 
18
- # Event mapping (ganti sesuai mapping di dataset Tuan)
19
  event_map = {
20
  0: "corporate action", 1: "divestment", 2: "ipo", 3: "regulation", 4: "insider",
21
  5: "market outlook", 6: "forex", 7: "commodity", 8: "finance report", 9: "debt",
@@ -29,33 +34,27 @@ event_map = {
29
  40: "supply chain", 41: "geopolitical", 42: "misc"
30
  }
31
 
32
- # IDX sector labels
33
  idx_labels = [
34
  "idx_energy", "idx_basic", "idx_indust", "idx_noncyc", "idx_cyclic",
35
  "idx_health", "idx_finance", "idx_propert", "idx_techno", "idx_infra", "idx_trans"
36
  ]
37
 
38
- # Inference function
39
  def predict(text):
40
  with torch.no_grad():
41
  encoded = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=64)
42
  out_sent, out_event, out_idx = model(encoded["input_ids"], encoded["attention_mask"])
43
 
44
- # Sentiment
45
  sent_label = torch.argmax(out_sent, dim=1).item()
46
  sent_result = sentiment_map[sent_label]
47
 
48
- # Event
49
  event_label = torch.argmax(out_event, dim=1).item()
50
  event_result = event_map.get(event_label, "unknown")
51
 
52
- # IDX (multi-label)
53
  idx_probs = torch.sigmoid(out_idx).squeeze(0)
54
  idx_result = [label for i, label in enumerate(idx_labels) if idx_probs[i] > 0.5]
55
 
56
  return sent_result, event_result, ", ".join(idx_result)
57
 
58
- # Gradio UI
59
  iface = gr.Interface(
60
  fn=predict,
61
  inputs=gr.Textbox(lines=2, placeholder="Masukkan headline berita saham..."),
 
1
  import torch
2
  from transformers import AutoTokenizer, AutoModel
3
+ from huggingface_hub import hf_hub_download
4
  from model import IndoBERTMultitask
5
  import gradio as gr
6
 
7
+ # Load tokenizer dari Hugging Face Hub
8
  tokenizer = AutoTokenizer.from_pretrained("plipustel/idxstockBERT")
9
 
10
+ # Load base IndoBERT
11
  base_model = AutoModel.from_pretrained("indobenchmark/indobert-base-p1")
12
+
13
+ # Load model multitask custom
14
  model = IndoBERTMultitask(base_model=base_model, num_event_labels=43)
15
+
16
+ # Download bobot model dari Hugging Face Hub
17
+ model_path = hf_hub_download(repo_id="plipustel/idxstockBERT", filename="pytorch_model.bin")
18
+ model.load_state_dict(torch.load(model_path, map_location="cpu"))
19
  model.eval()
20
 
21
  # Label mapping
22
  sentiment_map = {0: "Negatif", 1: "Positif", 2: "Netral"}
23
 
 
24
  event_map = {
25
  0: "corporate action", 1: "divestment", 2: "ipo", 3: "regulation", 4: "insider",
26
  5: "market outlook", 6: "forex", 7: "commodity", 8: "finance report", 9: "debt",
 
34
  40: "supply chain", 41: "geopolitical", 42: "misc"
35
  }
36
 
 
37
  idx_labels = [
38
  "idx_energy", "idx_basic", "idx_indust", "idx_noncyc", "idx_cyclic",
39
  "idx_health", "idx_finance", "idx_propert", "idx_techno", "idx_infra", "idx_trans"
40
  ]
41
 
 
42
  def predict(text):
43
  with torch.no_grad():
44
  encoded = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=64)
45
  out_sent, out_event, out_idx = model(encoded["input_ids"], encoded["attention_mask"])
46
 
 
47
  sent_label = torch.argmax(out_sent, dim=1).item()
48
  sent_result = sentiment_map[sent_label]
49
 
 
50
  event_label = torch.argmax(out_event, dim=1).item()
51
  event_result = event_map.get(event_label, "unknown")
52
 
 
53
  idx_probs = torch.sigmoid(out_idx).squeeze(0)
54
  idx_result = [label for i, label in enumerate(idx_labels) if idx_probs[i] > 0.5]
55
 
56
  return sent_result, event_result, ", ".join(idx_result)
57
 
 
58
  iface = gr.Interface(
59
  fn=predict,
60
  inputs=gr.Textbox(lines=2, placeholder="Masukkan headline berita saham..."),