plipustel commited on
Commit
fb33a38
·
verified ·
1 Parent(s): c1076c4

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +28 -0
model.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ class IndoBERTMultitask(nn.Module):
5
+ def __init__(self, base_model, hidden_size=768, num_event_labels=43):
6
+ super(IndoBERTMultitask, self).__init__()
7
+ self.bert = base_model
8
+ self.dropout = nn.Dropout(0.3)
9
+
10
+ # Task heads
11
+ self.sentiment_head = nn.Linear(hidden_size, 3) # 3 kelas
12
+ self.event_head = nn.Linear(hidden_size, num_event_labels) # Event (43 kelas)
13
+
14
+ # IDX sektor (multi-label 11 sektor)
15
+ self.idx_heads = nn.ModuleList([
16
+ nn.Linear(hidden_size, 1) for _ in range(11)
17
+ ])
18
+
19
+ def forward(self, input_ids, attention_mask):
20
+ outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
21
+ pooled = outputs.last_hidden_state[:, 0] # Ambil token CLS
22
+ pooled = self.dropout(pooled)
23
+
24
+ sentiment_logits = self.sentiment_head(pooled)
25
+ event_logits = self.event_head(pooled)
26
+ idx_logits = torch.cat([head(pooled) for head in self.idx_heads], dim=1)
27
+
28
+ return sentiment_logits, event_logits, idx_logits