# idxstockBERT / model.py
# Uploaded by: plipustel
# Commit: "Create model.py" (fb33a38, verified)
import torch
import torch.nn as nn
class IndoBERTMultitask(nn.Module):
    """Multi-task classification heads on top of a shared BERT-style encoder.

    One pooled vector (the first token of the encoder's ``last_hidden_state``,
    after dropout) feeds three independent linear heads:

    * sentiment classification (``num_sentiment_labels`` classes, default 3),
    * event classification (``num_event_labels`` classes, default 43),
    * IDX sector tagging, treated as multi-label: ``num_idx_sectors``
      separate one-logit heads (default 11) whose outputs are concatenated.

    All heads return raw logits; no softmax or sigmoid is applied here.
    """

    def __init__(
        self,
        base_model,
        hidden_size=768,
        num_event_labels=43,
        num_sentiment_labels=3,
        num_idx_sectors=11,
        dropout=0.3,
    ):
        """Build the encoder wrapper and task heads.

        Args:
            base_model: Encoder module. It is called as
                ``base_model(input_ids=..., attention_mask=...)`` and its
                result must expose a ``last_hidden_state`` attribute.
            hidden_size: Width of the encoder's hidden states. It must match
                the output width of ``base_model`` (the default 768 is an
                assumption about the base model; TODO confirm).
            num_event_labels: Number of event classes.
            num_sentiment_labels: Number of sentiment classes.
            num_idx_sectors: Number of IDX sector heads, each producing one
                logit. Must be at least 1.
            dropout: Dropout probability applied to the pooled vector.

        Raises:
            ValueError: If ``num_idx_sectors`` is less than 1.
        """
        super().__init__()
        if num_idx_sectors < 1:
            # torch.cat in forward() cannot concatenate an empty list.
            raise ValueError(
                f"num_idx_sectors must be >= 1, got {num_idx_sectors}"
            )

        self.bert = base_model
        self.dropout = nn.Dropout(dropout)

        # Task heads
        self.sentiment_head = nn.Linear(hidden_size, num_sentiment_labels)
        self.event_head = nn.Linear(hidden_size, num_event_labels)

        # IDX sectors (multi-label): one Linear(hidden_size, 1) per sector.
        # This is kept as a ModuleList instead of a single
        # Linear(hidden_size, n) so parameter names (idx_heads.<k>.weight /
        # .bias) stay identical to the original layout.
        self.idx_heads = nn.ModuleList(
            [nn.Linear(hidden_size, 1) for _ in range(num_idx_sectors)]
        )

    def forward(self, input_ids, attention_mask):
        """Encode the inputs and return logits for every task.

        Args:
            input_ids: Token ids, passed through to ``base_model``.
            attention_mask: Attention mask, passed through to ``base_model``.

        Returns:
            A tuple ``(sentiment_logits, event_logits, idx_logits)``. The IDX
            tensor is the per-sector heads' outputs concatenated along dim 1,
            so its last dimension is ``num_idx_sectors``.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Take the first (CLS) token. This assumes last_hidden_state is
        # (batch, seq_len, hidden_size); TODO confirm for the base model.
        pooled = self.dropout(outputs.last_hidden_state[:, 0])

        sentiment_logits = self.sentiment_head(pooled)
        event_logits = self.event_head(pooled)
        # Each sector head yields (batch, 1); stack them side by side.
        idx_logits = torch.cat([head(pooled) for head in self.idx_heads], dim=1)
        return sentiment_logits, event_logits, idx_logits