import re
import os
import logging
from typing import TypeVar
from functools import cache

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    pipeline,
)
from peft import PeftModel
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from langchain_community.chat_models import ChatLiteLLM
from langchain.chains import LLMChain
from langchain.output_parsers import PydanticOutputParser
from langchain.prompts import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    PromptTemplate,
)
from langchain.schema import BaseOutputParser, OutputParserException

from message_classes import ActionType, AgentAction
from utils import format_docstring
from langchain_callback_handler import LoggingCallbackHandler

HF_TOKEN_KEY_FILE = "./hf_token.key"

OutputType = TypeVar("OutputType", bound=object)

log = logging.getLogger("generate")
logging_handler = LoggingCallbackHandler("langchain")

def generate_action(
    model_name: str,
    history: str,
    turn_number: int,
    action_types: list[ActionType],
    agent: str,
    temperature: float = 0.7,
) -> AgentAction:
    """
    Use langchain to generate the agent's next action for the given turn.
    """
    try:
        # Normal case: the model acts as the agent.
        template = """
            Imagine you are {agent}, your task is to act/speak as {agent} would, keeping in mind {agent}'s social goal.
            You can find {agent}'s goal (or background) in the 'Here is the context of the interaction' field.
            Note that {agent}'s goal is only visible to you.
            You should try your best to achieve {agent}'s goal in a way that aligns with their character traits.
            Additionally, maintaining the conversation's naturalness and realism is essential (e.g., do not repeat what other people have already said before).
            {history}.
            You are at Turn #{turn_number}. Your available action types are
            {action_list}.
            Note: You can "leave" this conversation if 1. you have achieved your social goals, 2. this conversation makes you uncomfortable, 3. you find it uninteresting/you lose your patience, 4. or for other reasons you want to leave.
            Please only generate a JSON string including the action type and the argument.
            Your action should follow the given format:
            {format_instructions}
        """
        return generate(
            model_name=model_name,
            template=template,
            input_values=dict(
                agent=agent,
                turn_number=str(turn_number),
                history=history,
                action_list=" ".join(action_types),
            ),
            output_parser=PydanticOutputParser(pydantic_object=AgentAction),
            temperature=temperature,
        )
    except Exception:
        # Fall back to a no-op action if generation or parsing fails.
        return AgentAction(action_type="none", argument="")

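# Usage sketch (illustrative only, not part of the module's call graph): the
# history string, agent name, and action-type values below are assumptions for
# the example, not values taken from the project's data.
#
#   action = generate_action(
#       model_name="cmu-lti/sotopia-pi-mistral-7b-BC_SR",
#       history="Here is the context of the interaction: ...\nTurn #0: ...",
#       turn_number=1,
#       action_types=["speak", "leave"],
#       agent="Agent A",
#       temperature=0.7,
#   )
#   print(action.action_type, action.argument)
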
@cache  # load each model/tokenizer pair only once per process
def prepare_model(model_name, hf_token_key_file=HF_TOKEN_KEY_FILE):
    compute_type = torch.float16

    # Read the Hugging Face token from the key file if present, otherwise fall
    # back to the HF_TOKEN environment variable.
    if os.path.exists(hf_token_key_file):
        with open(hf_token_key_file, "r") as f:
            hf_token = f.read().strip()
    else:
        hf_token = os.environ["HF_TOKEN"]

    if model_name == "cmu-lti/sotopia-pi-mistral-7b-BC_SR":
        tokenizer = AutoTokenizer.from_pretrained(
            "mistralai/Mistral-7B-Instruct-v0.1", token=hf_token
        )
        model = AutoModelForCausalLM.from_pretrained(
            "mistralai/Mistral-7B-Instruct-v0.1",
            cache_dir="./.cache",
            device_map="cuda",
            token=hf_token,
        )
        model = PeftModel.from_pretrained(model, model_name).to("cuda")
    elif model_name == "cmu-lti/sotopia-pi-mistral-7b-BC_SR_4bit":
        tokenizer = AutoTokenizer.from_pretrained(
            "mistralai/Mistral-7B-Instruct-v0.1", token=hf_token
        )
        model = AutoModelForCausalLM.from_pretrained(
            "mistralai/Mistral-7B-Instruct-v0.1",
            cache_dir="./.cache",
            device_map="cuda",
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=compute_type,
            ),
            token=hf_token,
        )
        model = PeftModel.from_pretrained(model, model_name).to("cuda")
    elif model_name == "mistralai/Mistral-7B-Instruct-v0.1":
        tokenizer = AutoTokenizer.from_pretrained(
            "mistralai/Mistral-7B-Instruct-v0.1", token=hf_token
        )
        model = AutoModelForCausalLM.from_pretrained(
            "mistralai/Mistral-7B-Instruct-v0.1",
            cache_dir="./.cache",
            device_map="cuda",
            token=hf_token,
        )
    else:
        raise RuntimeError(f"Model {model_name} not supported")
    return model, tokenizer

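# Loading sketch (hedged): the sotopia-pi checkpoints are PEFT adapters loaded
# on top of Mistral-7B-Instruct-v0.1; this needs a CUDA device and an HF token
# read from ./hf_token.key or the HF_TOKEN environment variable.
#
#   model, tokenizer = prepare_model("cmu-lti/sotopia-pi-mistral-7b-BC_SR_4bit")
#   print(type(model).__name__, type(tokenizer).__name__)
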
def obtain_chain_hf(
    model_name: str,
    template: str,
    input_variables: list[str],
    temperature: float = 0.7,
    max_retries: int = 6,
    max_tokens: int = 2700,
) -> LLMChain:
    human_message_prompt = HumanMessagePromptTemplate(
        prompt=PromptTemplate(template=template, input_variables=input_variables)
    )
    chat_prompt_template = ChatPromptTemplate.from_messages([human_message_prompt])
    model, tokenizer = prepare_model(model_name)
    # Beam-sample decoding (3 beams) with a negative length penalty to keep
    # generations short; only the newly generated text is returned.
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=100,
        temperature=temperature,
        return_full_text=False,
        do_sample=True,
        num_beams=3,
        length_penalty=-1.0,
    )
    hf = HuggingFacePipeline(pipeline=pipe)
    chain = LLMChain(llm=hf, prompt=chat_prompt_template)
    return chain

def generate(
    model_name: str,
    template: str,
    input_values: dict[str, str],
    output_parser: BaseOutputParser[OutputType],
    temperature: float = 0.7,
) -> OutputType:
    input_variables = re.findall(r"{(.*?)}", template)
    assert (
        set(input_variables) == set(list(input_values.keys()) + ["format_instructions"])
        or set(input_variables) == set(list(input_values.keys()))
    ), f"The variables in the template must match input_values except for format_instructions. Got {sorted(input_values.keys())}, expected {sorted(input_variables)}"
    # Process the template and build the chain for the requested model.
    template = format_docstring(template)
    chain = obtain_chain(model_name, template, input_variables, temperature)
    if "format_instructions" not in input_values:
        input_values["format_instructions"] = output_parser.get_format_instructions()
    result = chain.predict([logging_handler], **input_values)
    prompt = logging_handler.retrive_prompt()
    try:
        parsed_result = output_parser.parse(result)
    except KeyboardInterrupt:
        raise
    except Exception as e:
        log.debug(
            f"[red] Failed to parse result: {result}\nEncountered exception {e}\nReformatting and reparsing",
            extra={"markup": True},
        )
        reformat_parsed_result = format_bad_output(
            result, format_instructions=output_parser.get_format_instructions()
        )
        parsed_result = output_parser.parse(reformat_parsed_result)
    log.info(f"Generated result: {parsed_result}")
    return parsed_result

def format_bad_output(
    ill_formed_output: str,
    format_instructions: str,
    model_name: str = "gpt-3.5-turbo",
) -> str:
    template = """
        Given a string that cannot be parsed by a JSON parser, reformat it into a string that can be parsed by a JSON parser.
        Original string: {ill_formed_output}
        Format instructions: {format_instructions}
        Please only generate the JSON:
    """
    chain = obtain_chain(
        model_name=model_name,
        template=template,
        input_variables=re.findall(r"{(.*?)}", template),
    )
    input_values = {
        "ill_formed_output": ill_formed_output,
        "format_instructions": format_instructions,
    }
    reformat = chain.predict([logging_handler], **input_values)
    log.info(f"Reformatted output: {reformat}")
    return reformat

def obtain_chain(
    model_name: str,
    template: str,
    input_variables: list[str],
    temperature: float = 0.7,
    max_retries: int = 6,
) -> LLMChain:
    """
    Build an LLMChain for the given model and prompt template.
    """
    if model_name in [
        "cmu-lti/sotopia-pi-mistral-7b-BC_SR",
        "cmu-lti/sotopia-pi-mistral-7b-BC_SR_4bit",
    ]:
        # Local Hugging Face checkpoints go through a transformers pipeline.
        return obtain_chain_hf(
            model_name=model_name,
            template=template,
            input_variables=input_variables,
            temperature=temperature,
            max_retries=max_retries,
        )
    # API-served models go through LiteLLM.
    model_name = _return_fixed_model_version(model_name)
    chat = ChatLiteLLM(
        model=model_name,
        temperature=temperature,
        max_tokens=2700,  # tweak as needed
        max_retries=max_retries,
    )
    human_message_prompt = HumanMessagePromptTemplate(
        prompt=PromptTemplate(template=template, input_variables=input_variables)
    )
    chat_prompt_template = ChatPromptTemplate.from_messages([human_message_prompt])
    chain = LLMChain(llm=chat, prompt=chat_prompt_template)
    return chain

def _return_fixed_model_version(model_name: str) -> str:
    return {
        "gpt-3.5-turbo": "gpt-3.5-turbo-0613",
        "gpt-3.5-turbo-finetuned": "ft:gpt-3.5-turbo-0613:academicscmu::8nY2zgdt",
        "gpt-3.5-turbo-ft-MF": "ft:gpt-3.5-turbo-0613:academicscmu::8nuER4bO",
        "gpt-4": "gpt-4-0613",
        "gpt-4-turbo": "gpt-4-1106-preview",
    }[model_name]
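
if __name__ == "__main__":
    # Minimal smoke test (hedged sketch): exercises generate_action end to end.
    # The history text, agent name, and action-type values are illustrative
    # assumptions, not project data; running this requires a CUDA device and
    # HF credentials for the sotopia-pi checkpoint.
    toy_history = (
        "Here is the context of the interaction: two agents meet for the "
        'first time.\nTurn #0: Agent B said: "Hi, nice to meet you!"'
    )
    action = generate_action(
        model_name="cmu-lti/sotopia-pi-mistral-7b-BC_SR_4bit",
        history=toy_history,
        turn_number=1,
        action_types=["speak", "leave"],  # assumed ActionType values
        agent="Agent A",
        temperature=0.7,
    )
    print(action)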