Luna OpenLabs commited on
Commit
5a1602b
·
verified ·
1 Parent(s): 480ec49

Create model/luna_model.py

Browse files
Files changed (1) hide show
  1. model/luna_model.py +15 -0
model/luna_model.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # model/luna_model.py
2
+ import torch
3
+ import torch.nn as nn
4
+ from transformers import BertTokenizer, BertModel
5
+
6
+ class LunaAI(nn.Module):
7
+ def __init__(self):
8
+ super(LunaAI, self).__init__()
9
+ self.bert = BertModel.from_pretrained('bert-base-uncased')
10
+ self.classifier = nn.Linear(768, 2) # Adjust for number of classes
11
+
12
+ def forward(self, input_ids, attention_mask):
13
+ outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
14
+ logits = self.classifier(outputs.pooler_output)
15
+ return logits