|
|
|
class ChatState:
    """Accumulate a multi-turn conversation and render it as a single prompt.

    Each turn is wrapped in the start/end tags of the selected chat template
    and appended to ``history``. ``get_full_prompt`` concatenates the optional
    system text, the stored turns and an opening model tag so that the model
    continues with its own reply.

    Attributes:
        model: Object used by ``send_message``; it must expose
            ``generate(prompt, max_length=..., strip_prompt=...)``.
            NOTE(review): presumably a KerasHub ``*CausalLM`` instance, given
            the template names below -- confirm against the caller.
        system: System text placed in front of the conversation (may be "").
        history: List of already-tagged turn strings, oldest first.
    """

    # Template name -> (start_system, start_user, start_model,
    #                   end_system, end_user, end_model, banner).
    # With chat_template="auto" the name is the class name of ``model``.
    _TEMPLATES = {
        "Llama3CausalLM": (
            "<|start_header_id|>system<|end_header_id|>\n\n",
            "<|start_header_id|>user<|end_header_id|>\n\n",
            "<|start_header_id|>assistant<|end_header_id|>\n\n",
            "<|eot_id|>",
            "<|eot_id|>",
            "<|eot_id|>",
            "Using chat template for: Llama",
        ),
        "GemmaCausalLM": (
            "",
            "<start_of_turn>user\n",
            "<start_of_turn>model\n",
            "\n",
            "<end_of_turn>\n",
            "<end_of_turn>\n",
            "Using chat template for: Gemma",
        ),
        "MistralCausalLM": (
            "",
            "[INST]",
            "",
            "<s>",
            "[/INST]",
            "</s>",
            "Using chat template for: Mistral",
        ),
        "Vicuna": (
            "",
            "USER: ",
            "ASSISTANT: ",
            "\n\n",
            "\n",
            "</s>\n",
            "Using chat template for : Vicuna",
        ),
    }

    def __init__(self, model, system="", chat_template="auto"):
        """Select the turn tags and start with an empty history.

        Args:
            model: Object whose ``generate`` method produces replies. When
                ``chat_template`` is "auto", its class name picks the template.
            system: Optional system text prepended to every prompt.
            chat_template: "auto", or one of the keys of ``_TEMPLATES``
                ("Llama3CausalLM", "GemmaCausalLM", "MistralCausalLM",
                "Vicuna").

        Raises:
            ValueError: If the resolved template name is not supported.
        """
        if chat_template == "auto":
            chat_template = type(model).__name__

        tags = self._TEMPLATES.get(chat_template)
        if tags is None:
            # The original used `assert (0, "...")`, which asserts a non-empty
            # tuple and therefore never fired; fail loudly here instead of
            # with an AttributeError on first use.
            raise ValueError(
                "Unknown turn tags for this model class: "
                f"{chat_template!r} (supported: {sorted(self._TEMPLATES)})"
            )

        # Trailing double underscores mean these names are NOT mangled, so
        # they stay reachable under exactly these names from outside.
        (
            self.__START_TURN_SYSTEM__,
            self.__START_TURN_USER__,
            self.__START_TURN_MODEL__,
            self.__END_TURN_SYSTEM__,
            self.__END_TURN_USER__,
            self.__END_TURN_MODEL__,
            banner,
        ) = tags
        print(banner)

        self.model = model
        self.system = system
        self.history = []

    def add_to_history_as_user(self, message):
        """Append ``message`` to the history wrapped in the user turn tags."""
        self.history.append(
            self.__START_TURN_USER__ + message + self.__END_TURN_USER__
        )

    def add_to_history_as_model(self, message):
        """Append ``message`` to the history wrapped in the model turn tags."""
        self.history.append(
            self.__START_TURN_MODEL__ + message + self.__END_TURN_MODEL__
        )

    def get_history(self):
        """Return all stored turns concatenated in order ("" when empty)."""
        return "".join(self.history)

    def get_full_prompt(self):
        """Return the prompt to send to the model.

        The result is the history followed by an opening model tag, preceded
        by the tagged system text when ``system`` is non-empty.
        """
        prompt = self.get_history() + self.__START_TURN_MODEL__
        # Truthiness check also tolerates system=None, which len() did not.
        if self.system:
            prompt = (
                self.__START_TURN_SYSTEM__
                + self.system
                + self.__END_TURN_SYSTEM__
                + prompt
            )
        return prompt

    def send_message(self, message, max_length=2048):
        """Record a user message, query the model and record its reply.

        Args:
            message: The user's message.
            max_length: Forwarded to ``model.generate`` as ``max_length``;
                defaults to the previously hard-coded 2048.

        Returns:
            A ``(message, response)`` tuple: the user's message as given and
            the value returned by ``model.generate``.
        """
        self.add_to_history_as_user(message)
        prompt = self.get_full_prompt()
        # strip_prompt=True is passed so that, presumably, only the newly
        # generated text comes back -- verify against the model's generate().
        response = self.model.generate(
            prompt, max_length=max_length, strip_prompt=True
        )
        self.add_to_history_as_model(response)
        return (message, response)
|
|