File size: 394 Bytes
73384c6
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

class BathSalt1DaedalusPhi3Model(AutoModelForSeq2SeqLM):
    """Thin wrapper that owns a seq2seq model built from ``config``.

    The ``transformers`` auto classes are factories: their ``__init__``
    raises unconditionally, and models are meant to be obtained through
    ``from_pretrained`` / ``from_config``. This wrapper therefore builds the
    concrete model with ``AutoModelForSeq2SeqLM.from_config`` and keeps it on
    ``self.model``, which ``forward`` delegates to.

    Attributes:
        config: The configuration object passed to the constructor.
        model: The concrete model returned by
            ``AutoModelForSeq2SeqLM.from_config(config)``.
    """

    def __init__(self, config):
        """Build the wrapped model from ``config``.

        Building from a configuration instantiates the architecture only; it
        does not load pretrained weights.

        Args:
            config: A ``transformers`` configuration object that
                ``AutoModelForSeq2SeqLM.from_config`` can map to a model
                class.
        """
        # Deliberately NOT calling super().__init__(config): the auto-class
        # base __init__ always raises, because auto classes are not meant to
        # be instantiated directly.
        self.config = config
        # forward() relies on this attribute; the original never assigned it.
        self.model = AutoModelForSeq2SeqLM.from_config(config)

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Run the wrapped model and return its output unchanged.

        Args:
            input_ids: Passed positionally to the wrapped model.
            attention_mask: Forwarded as the ``attention_mask`` keyword;
                ``None`` leaves the choice to the wrapped model.
            labels: Forwarded as the ``labels`` keyword; ``None`` leaves the
                choice to the wrapped model.

        Returns:
            Whatever the wrapped model returns for these arguments.
        """
        return self.model(
            input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )