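Below is the stopping-criteria `LogitsProcessor` from the Space: once any configured stop sequence appears at the end of a generated sequence, it masks all logits except the EOS token so generation halts on the next step.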
import logging
from typing import List

import torch
from transformers import LogitsProcessor

class StopAfterTokenIsGenerated(LogitsProcessor):
    def __init__(self, stops: List[torch.Tensor], eos_token_id: int):
        super().__init__()
        self.stops = stops
        self.eos_token_id = eos_token_id
        logging.info(f"Stop sequence token ids: {self.stops}")
        self.first_batch = True
    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        """
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
            scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
                Prediction scores of a language modeling head. These can be logits for each vocabulary
                token when not using beam search, or log softmax for each vocabulary token when using
                beam search.

        Return:
            `torch.FloatTensor` of shape `(batch_size, config.vocab_size)`: The processed prediction scores.
        """
        # The first call only sees the prompt, so skip it: a stop sequence
        # already sitting at the end of the prompt should not immediately
        # force EOS before anything has been generated.
        if self.first_batch:
            self.first_batch = False
            return scores

        for seq_no, seq in enumerate(input_ids):
            for stop in self.stops:
                stop = stop.to(device=seq.device, dtype=seq.dtype)
                # If this sequence now ends with a stop sequence, mask every
                # token except EOS so the next generated token ends it.
                if len(seq) >= len(stop) and torch.all(stop == seq[-len(stop):]).item():
                    scores[seq_no, :] = -float("inf")
                    scores[seq_no, self.eos_token_id] = 0
                    logging.info(f"Stopping criteria found: {stop}")
                    break

        return scores
    def reset(self):
        # Re-arm the first-batch skip so the processor can be reused
        # across multiple generate() calls.
        self.first_batch = True
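
A minimal usage sketch, not part of the original Space: it wires the processor into `model.generate` via `LogitsProcessorList`. The model name (`gpt2`), the stop string (`"\nUser:"`), and the generation settings are illustrative placeholders.

if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    # Encode the stop string without special tokens so it matches raw generated ids.
    stop_ids = tokenizer("\nUser:", add_special_tokens=False, return_tensors="pt").input_ids[0]

    stopper = StopAfterTokenIsGenerated(stops=[stop_ids], eos_token_id=tokenizer.eos_token_id)

    inputs = tokenizer("Assistant:", return_tensors="pt")
    output = model.generate(
        **inputs,
        max_new_tokens=64,
        pad_token_id=tokenizer.eos_token_id,
        logits_processor=LogitsProcessorList([stopper]),
    )
    stopper.reset()  # re-arm the processor before the next generate() call
    print(tokenizer.decode(output[0], skip_special_tokens=True))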