# LamzaingGraceAI / app.py
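# Gradio demo that streams chat-style generations from a causal LM,
# optionally with a PEFT / LoRA adapter on top of the base model.
#
# Example invocation (the adapter path below is a placeholder):
#   python app.py --model_path_or_id NousResearch/Llama-2-7b-hf --lora_path ./my-lora-adapter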
import argparse
import gradio as gr
from peft import AutoPeftModelForCausalLM
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread

parser = argparse.ArgumentParser()
parser.add_argument("--model_path_or_id",
                    type=str,
                    default="NousResearch/Llama-2-7b-hf",
                    required=False,
                    help="Model ID or path to saved model")
parser.add_argument("--lora_path",
                    type=str,
                    default=None,
                    required=False,
                    help="Path to the saved LoRA adapter")
args = parser.parse_args()
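
# Both branches load the weights in 4-bit (bitsandbytes) with fp16 compute,
# which reduces the memory needed to serve a 7B model.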
if args.lora_path:
    # load base LLM model with PEFT Adapter
    model = AutoPeftModelForCausalLM.from_pretrained(
        args.lora_path,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
        load_in_4bit=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(args.lora_path)
else:
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path_or_id,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
        load_in_4bit=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model_path_or_id)
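
# Build the Gradio UI: a chatbot pane, an instruction textbox, and an
# accordion holding the prompt template and generation parameters.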
with gr.Blocks() as demo:
    gr.HTML(f"""
        <h2> Instruction Chat Bot Demo </h2>
        <h3> Model ID : {args.model_path_or_id} </h3>
        <h3> Peft Adapter : {args.lora_path} </h3>
        """)
    chat_history = gr.Chatbot(label="Instruction Bot")
    msg = gr.Textbox(label="Instruction")
    with gr.Accordion(label="Generation Parameters", open=False):
        prompt_format = gr.Textbox(
            label="Formatting prompt",
            value="{instruction}",
            lines=8)
        with gr.Row():
            max_new_tokens = gr.Number(minimum=25, maximum=500, value=100, label="Max New Tokens")
            temperature = gr.Slider(minimum=0, maximum=1.0, value=0.7, label="Temperature")
    clear = gr.ClearButton([msg, chat_history])
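
    # Callbacks: `user` records the submitted instruction in the chat history,
    # `bot` formats it, runs generation, and streams the reply back into the chat.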
    def user(user_message, history):
        # Append the new instruction to the history (the reply is filled in
        # by `bot`) and clear the input textbox
        return "", history + [[user_message, None]]

    def bot(chat_history, prompt_format, max_new_tokens, temperature):
        # Format the instruction using the format string with key
        # {instruction}
        formatted_inst = prompt_format.format(
            instruction=chat_history[-1][0]
        )

        # Tokenize the input
        input_ids = tokenizer(
            formatted_inst,
            return_tensors="pt",
            truncation=True).input_ids.cuda()

        # Support for streaming of tokens within generate requires
        # generation to run in a separate thread
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
        generation_kwargs = dict(
            input_ids=input_ids,
            streamer=streamer,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_p=0.9,
            temperature=temperature,
            use_cache=True
        )
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        # Stream the partial output into the last chat turn as tokens arrive
        chat_history[-1][1] = ""
        for new_text in streamer:
            chat_history[-1][1] += new_text
            yield chat_history

    msg.submit(user, [msg, chat_history], [msg, chat_history], queue=False).then(
        bot, [chat_history, prompt_format, max_new_tokens, temperature], chat_history
    )

demo.queue()
demo.launch()