import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer


class BathSalt1DaedalusPhi3Model(torch.nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        # AutoModelForSeq2SeqLM is a factory class and cannot be subclassed directly,
        # so build the underlying encoder-decoder model from the config instead.
        self.model = AutoModelForSeq2SeqLM.from_config(config)

    def forward(self, input_ids, attention_mask=None, labels=None):
        # Delegate to the wrapped model; when labels are given, the output includes the loss.
        outputs = self.model(input_ids, attention_mask=attention_mask, labels=labels)
        return outputs
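

# A minimal usage sketch. Assumptions not in the original file: "t5-small" is only a
# stand-in encoder-decoder checkpoint used to obtain a config and tokenizer; swap in
# whichever seq2seq checkpoint this wrapper is actually meant to load.
if __name__ == "__main__":
    from transformers import AutoConfig

    config = AutoConfig.from_pretrained("t5-small")         # hypothetical stand-in
    tokenizer = AutoTokenizer.from_pretrained("t5-small")   # hypothetical stand-in
    model = BathSalt1DaedalusPhi3Model(config)               # randomly initialized weights

    batch = tokenizer("translate English to German: Hello", return_tensors="pt")
    labels = tokenizer("Hallo", return_tensors="pt").input_ids
    outputs = model(batch.input_ids, attention_mask=batch.attention_mask, labels=labels)
    print(outputs.loss)  # the wrapped model returns a loss when labels are provided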