from transformers import AutoTokenizer, Pipeline
import torch


class PairTextClassificationPipeline(Pipeline):
    """Scores a (premise, hypothesis) pair by formatting it into a prompt."""

    def __init__(self, model, tokenizer=None, **kwargs):
        # Initialize the tokenizer first, falling back to the model's own checkpoint
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(model.config._name_or_path)
        # Make sure we store the tokenizer before calling super().__init__
        self.tokenizer = tokenizer
        super().__init__(model=model, tokenizer=tokenizer, **kwargs)
        self.prompt = (
            "Determine if the hypothesis is true given the premise.\n\n"
            "Premise: {text1}\n\n"
            "Hypothesis: {text2}"
        )

    def _sanitize_parameters(self, **kwargs):
        # No extra parameters are forwarded to preprocess, _forward, or postprocess
        preprocess_kwargs = {}
        return preprocess_kwargs, {}, {}

    def preprocess(self, inputs):
        # The pipeline is called with (premise, hypothesis) pairs;
        # preprocess receives one pair at a time
        pair_dict = {"text1": inputs[0], "text2": inputs[1]}
        formatted_prompt = self.prompt.format(**pair_dict)
        model_inputs = self.tokenizer(
            formatted_prompt,
            return_tensors="pt",
            padding=True,
        )
        return model_inputs

    def _forward(self, model_inputs):
        model_outputs = self.model(**model_inputs)
        return model_outputs

    def postprocess(self, model_outputs):
        # Expects token-level logits of shape (batch, seq_len, num_labels)
        logits = model_outputs.logits
        logits = logits[:, 0, :]  # keep only the first ([CLS]) token's logits
        transformed_probs = torch.softmax(logits, dim=-1)
        raw_scores = transformed_probs[:, 1]  # probability of class 1
        return raw_scores.item()
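
# --- Usage sketch ---
# A minimal, illustrative example of driving the pipeline, not part of the class
# above. It assumes a model whose logits are token-level with two labels, i.e.
# shape (batch, seq_len, 2), to match postprocess(); "bert-base-uncased" is a
# placeholder checkpoint (an untuned head will produce meaningless scores).
if __name__ == "__main__":
    from transformers import AutoModelForTokenClassification

    model = AutoModelForTokenClassification.from_pretrained(
        "bert-base-uncased", num_labels=2
    )
    pipe = PairTextClassificationPipeline(model=model)

    # A single (premise, hypothesis) pair yields one probability-of-class-1 score
    score = pipe(("A man is playing a guitar.", "A person is making music."))
    print(score)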