import uuid
from typing import Any, Dict, List, Optional, Union

from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available
from ..utils import logging
from .base import PIPELINE_INIT_ARGS, Pipeline


if is_tf_available():
    import tensorflow as tf

if is_torch_available():
    import torch

logger = logging.get_logger(__name__)


class Conversation:
    """
    Utility class containing a conversation and its history. This class is meant to be used as an input to the
    :class:`~transformers.ConversationalPipeline`. The conversation contains a number of utility functions to manage
    the addition of new user inputs and generated model responses. A conversation needs to contain an unprocessed
    user input before being passed to the :class:`~transformers.ConversationalPipeline`. This user input is either
    created when the class is instantiated, or by calling :obj:`conversation.add_user_input("input")` after a
    conversation turn.

    Arguments:
        text (:obj:`str`, `optional`):
            The initial user input to start the conversation. If not provided, a user input needs to be provided
            manually using the :meth:`~transformers.Conversation.add_user_input` method before the conversation can
            begin.
        conversation_id (:obj:`uuid.UUID`, `optional`):
            Unique identifier for the conversation. If not provided, a random UUID4 id will be assigned to the
            conversation.
        past_user_inputs (:obj:`List[str]`, `optional`):
            Past user inputs of the conversation, if any. You don't need to pass this manually when using the
            pipeline interactively, but to recreate a history you need to set both :obj:`past_user_inputs` and
            :obj:`generated_responses` with equal-length lists of strings.
        generated_responses (:obj:`List[str]`, `optional`):
            Past responses of the model, if any. You don't need to pass this manually when using the pipeline
            interactively, but to recreate a history you need to set both :obj:`past_user_inputs` and
            :obj:`generated_responses` with equal-length lists of strings.
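
    For instance, a past exchange can be replayed by passing equal-length histories; the strings below are
    illustrative, reusing the usage example::

        conversation = Conversation(
            "Is it good?",
            past_user_inputs=["Going to the movies tonight - any suggestions?"],
            generated_responses=["The Big Lebowski."],
        )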

    Usage::

        conversation = Conversation("Going to the movies tonight - any suggestions?")

        # Steps usually performed by the model when generating a response:
        # 1. Mark the user input as processed (moved to the history)
        conversation.mark_processed()
        # 2. Append a model response
        conversation.append_response("The Big Lebowski.")

        conversation.add_user_input("Is it good?")
    """

    def __init__(
        self,
        text: Optional[str] = None,
        conversation_id: Optional[uuid.UUID] = None,
        past_user_inputs=None,
        generated_responses=None,
    ):
        if not conversation_id:
            conversation_id = uuid.uuid4()
        if past_user_inputs is None:
            past_user_inputs = []
        if generated_responses is None:
            generated_responses = []

        self.uuid: uuid.UUID = conversation_id
        self.past_user_inputs: List[str] = past_user_inputs
        self.generated_responses: List[str] = generated_responses
        self.new_user_input: Optional[str] = text

    def __eq__(self, other):
        if not isinstance(other, Conversation):
            return False
        if self.uuid == other.uuid:
            return True
        return (
            self.new_user_input == other.new_user_input
            and self.past_user_inputs == other.past_user_inputs
            and self.generated_responses == other.generated_responses
        )

    def add_user_input(self, text: str, overwrite: bool = False):
        """
        Add a user input to the conversation for the next round. This populates the internal :obj:`new_user_input`
        field.

        Args:
            text (:obj:`str`): The user input for the next conversation round.
            overwrite (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not existing and unprocessed user input should be overwritten when this function is called.
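
        Example (illustrative inputs; the second call shows the :obj:`overwrite` behavior)::

            conversation.add_user_input("Is it an action movie?")
            # A second call while the first input is still unprocessed needs `overwrite=True`:
            conversation.add_user_input("Is it a comedy?", overwrite=True)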
""" | |
if self.new_user_input: | |
if overwrite: | |
logger.warning( | |
f'User input added while unprocessed input was existing: "{self.new_user_input}" was overwritten ' | |
f'with: "{text}".' | |
) | |
self.new_user_input = text | |
else: | |
logger.warning( | |
f'User input added while unprocessed input was existing: "{self.new_user_input}" new input ' | |
f'ignored: "{text}". Set `overwrite` to True to overwrite unprocessed user input' | |
) | |
else: | |
self.new_user_input = text | |

    def mark_processed(self):
        """
        Mark the conversation as processed (moves the content of :obj:`new_user_input` to :obj:`past_user_inputs`)
        and empties the :obj:`new_user_input` field.
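
        Example (a minimal sketch of the state change)::

            conversation.add_user_input("Is it good?")
            conversation.mark_processed()
            assert conversation.new_user_input is None
            assert conversation.past_user_inputs[-1] == "Is it good?"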
""" | |
if self.new_user_input: | |
self.past_user_inputs.append(self.new_user_input) | |
self.new_user_input = None | |

    def append_response(self, response: str):
        """
        Append a response to the list of generated responses.

        Args:
            response (:obj:`str`): The model generated response.
        """
        self.generated_responses.append(response)

    def iter_texts(self):
        """
        Iterates over all blobs of the conversation.

        Returns: Iterator of (is_user, text_chunk) in chronological order of the conversation. ``is_user`` is a
        :obj:`bool`, ``text_chunk`` is a :obj:`str`.
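
        Example (mirrors what :meth:`__repr__` prints; the conversation is the one from the class-level usage)::

            for is_user, text in conversation.iter_texts():
                print("user" if is_user else "bot", ">>", text)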
""" | |
for user_input, generated_response in zip(self.past_user_inputs, self.generated_responses): | |
yield True, user_input | |
yield False, generated_response | |
if self.new_user_input: | |
yield True, self.new_user_input | |

    def __repr__(self):
        """
        Generates a string representation of the conversation.

        Return:
            :obj:`str`:

            Example::

                Conversation id: 7d15686b-dc94-49f2-9c4b-c9eac6a1f114
                user >> Going to the movies tonight - any suggestions?
                bot >> The Big Lebowski
        """
        output = f"Conversation id: {self.uuid} \n"
        for is_user, text in self.iter_texts():
            name = "user" if is_user else "bot"
            output += f"{name} >> {text} \n"
        return output


class ConversationalPipeline(Pipeline):
    """
    Multi-turn conversational pipeline.

    This conversational pipeline can currently be loaded from :func:`~transformers.pipeline` using the following task
    identifier: :obj:`"conversational"`.

    The models that this pipeline can use are models that have been fine-tuned on a multi-turn conversational task,
    currently: `'microsoft/DialoGPT-small'`, `'microsoft/DialoGPT-medium'`, `'microsoft/DialoGPT-large'`. See the
    up-to-date list of available models on `huggingface.co/models
    <https://huggingface.co/models?filter=conversational>`__.

    Usage::

        conversational_pipeline = pipeline("conversational")

        conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
        conversation_2 = Conversation("What's the last book you have read?")

        conversational_pipeline([conversation_1, conversation_2])

        conversation_1.add_user_input("Is it an action movie?")
        conversation_2.add_user_input("What is the genre of this book?")

        conversational_pipeline([conversation_1, conversation_2])
    """

    def __init__(self, min_length_for_response=32, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # We need at least an eos_token
        assert self.tokenizer.eos_token_id is not None, "ConversationalPipeline tokenizer should have an EOS token set"
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.min_length_for_response = min_length_for_response

    def __call__(
        self,
        conversations: Union[Conversation, List[Conversation]],
        clean_up_tokenization_spaces=True,
        **generate_kwargs
    ):
        r"""
        Generate responses for the conversation(s) given as inputs.

        Args:
            conversations (a :class:`~transformers.Conversation` or a list of :class:`~transformers.Conversation`):
                Conversations to generate responses for.
            clean_up_tokenization_spaces (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Whether or not to clean up the potential extra spaces in the text output.
            generate_kwargs:
                Additional keyword arguments to pass along to the generate method of the model (see the generate
                method corresponding to your framework `here <./model.html#generative-models>`__).

        Returns:
            :class:`~transformers.Conversation` or a list of :class:`~transformers.Conversation`: Conversation(s)
            with updated generated responses for those containing a new user input.
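
        Example (any keyword accepted by the model's :obj:`generate` method can be forwarded; ``do_sample`` and
        ``top_k`` below are illustrative choices)::

            conversation = Conversation("Going to the movies tonight - any suggestions?")
            conversational_pipeline(conversation, do_sample=True, top_k=50)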
""" | |
if isinstance(conversations, Conversation): | |
conversations = [conversations] | |
# Input validation | |
if isinstance(conversations, list): | |
for conversation in conversations: | |
assert isinstance( | |
conversation, Conversation | |
), "ConversationalPipeline expects a Conversation or list of Conversations as an input" | |
if conversation.new_user_input is None: | |
raise ValueError( | |
f"Conversation with UUID {type(conversation.uuid)} does not contain new user input to process. " | |
"Add user inputs with the conversation's `add_user_input` method" | |
) | |
assert ( | |
self.tokenizer.pad_token_id is not None or self.tokenizer.eos_token_id is not None | |
), "Please make sure that the tokenizer has a pad_token_id or eos_token_id when using a batch input" | |
else: | |
raise ValueError("ConversationalPipeline expects a Conversation or list of Conversations as an input") | |
with self.device_placement(): | |
inputs = self._parse_and_tokenize(conversations) | |
if self.framework == "pt": | |
inputs = self.ensure_tensor_on_device(**inputs) | |
input_length = inputs["input_ids"].shape[-1] | |
elif self.framework == "tf": | |
input_length = tf.shape(inputs["input_ids"])[-1].numpy() | |
generated_responses = self.model.generate( | |
inputs["input_ids"], | |
attention_mask=inputs["attention_mask"], | |
**generate_kwargs, | |
) | |
if self.model.config.is_encoder_decoder: | |
if self.framework == "pt": | |
history = torch.cat((inputs["input_ids"], generated_responses[:, 1:]), 1) | |
elif self.framework == "tf": | |
history = tf.concat([inputs["input_ids"], generated_responses[:, 1:]], 1) | |
else: | |
history = generated_responses | |
history = self._clean_padding_history(history) | |
if self.model.config.is_encoder_decoder: | |
start_position = 1 | |
else: | |
start_position = input_length | |
output = [] | |
for conversation_index, conversation in enumerate(conversations): | |
conversation.mark_processed() | |
conversation.generated_responses.append( | |
self.tokenizer.decode( | |
generated_responses[conversation_index][start_position:], | |
skip_special_tokens=True, | |
clean_up_tokenization_spaces=clean_up_tokenization_spaces, | |
) | |
) | |
output.append(conversation) | |
if len(output) == 1: | |
return output[0] | |
else: | |
return output | |

    def _clean_padding_history(self, generated_tensor) -> List[List[int]]:
        """
        Cleans the padding history. Padding may be generated in two places when multiple conversations are provided
        as an input:

        - at the end of the concatenated history and new user input, so that all inputs to the model have the same
          length
        - at the end of the generated response, as some responses will be longer than others

        This method cleans up these padding tokens so that the history for each conversation is not impacted by the
        batching process.
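
        For instance (illustrative token ids)::

            # pad_token_id == eos_token_id == 50256:
            #   [464, 2159, 50256, 50256, 50256]  ->  [464, 2159, 50256]  (a run of pads collapses to one)
            # pad_token_id == 0 != eos_token_id:
            #   [464, 2159, 0, 0]                 ->  [464, 2159]         (pads are removed entirely)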
""" | |
outputs = [] | |
for sequence in generated_tensor: | |
sequence_tokens = [] | |
is_previous_pad = False | |
for token in sequence: | |
if token == self.tokenizer.pad_token_id: | |
if self.tokenizer.pad_token_id != self.tokenizer.eos_token_id: | |
continue | |
if is_previous_pad: | |
continue | |
else: | |
is_previous_pad = True | |
else: | |
is_previous_pad = False | |
if self.framework == "pt": | |
sequence_tokens.append(token.item()) | |
else: | |
sequence_tokens.append(int(token.numpy())) | |
outputs.append(sequence_tokens) | |
return outputs | |

    def _legacy_parse_and_tokenize(self, conversation: Conversation) -> List[int]:
        eos_token_id = self.tokenizer.eos_token_id
        input_ids = []
        # Encode each turn and separate turns with the EOS token.
        for is_user, text in conversation.iter_texts():
            input_ids.extend(self.tokenizer.encode(text, add_special_tokens=False) + [eos_token_id])
        # Keep only the most recent tokens if the history exceeds the model's maximum input length.
        if len(input_ids) > self.tokenizer.model_max_length:
            input_ids = input_ids[-self.tokenizer.model_max_length :]
        return input_ids

    def _parse_and_tokenize(self, conversations: List[Conversation]) -> Dict[str, Any]:
        if hasattr(self.tokenizer, "_build_conversation_input_ids"):
            input_ids = [self.tokenizer._build_conversation_input_ids(conversation) for conversation in conversations]
        else:
            # If the tokenizer cannot handle conversations, we fall back to the legacy behavior
            input_ids = [self._legacy_parse_and_tokenize(conversation) for conversation in conversations]
        inputs = self.tokenizer.pad(
            {"input_ids": input_ids}, padding="longest", return_attention_mask=True, return_tensors=self.framework
        )
        return inputs