TohidA's picture
Updated Gradio app with new model
4e52062
raw
history blame
3.78 kB
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from transformers import AutoTokenizer
from peft import PeftModel, PeftConfig
# Adapter metadata; only `base_model_name_or_path` is read from it below,
# to pick the tokenizer checkpoint.
config = PeftConfig.from_pretrained("TohidA/LlamaInstructMona")

# Load the base causal LM and wrap it with the fine-tuned PEFT adapter weights.
model = PeftModel.from_pretrained(
    AutoModelForCausalLM.from_pretrained("mlabonne/llama-2-7b-miniguanaco"),
    "TohidA/LlamaInstructMona",
)

# Run on the GPU when one is present; otherwise the model stays where it was loaded.
if torch.cuda.is_available():
    model = model.cuda()

# The tokenizer is taken from the base checkpoint recorded in the adapter config.
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
def prompt(instruction, input=''):
    """Build the text prompt that is fed to the model.

    Args:
        instruction: The task description inserted under "### Instruction:".
        input: Optional extra context. When it is anything other than the
            empty string, an "### Input:" section is added and the preamble
            mentions the paired input.

    Returns:
        The full prompt string, ending with "### Response:" and a newline.
    """
    if input != '':
        # Variant with additional context supplied by the caller.
        return (
            "Below is an instruction that describes a task, paired with an input that provides further context. "
            "Write a response that appropriately completes the request. "
            f"\n\n### Instruction:\n{instruction}"
            f"\n\n### Input:\n{input}"
            "\n\n### Response:\n"
        )

    # Instruction-only variant (note the single space kept after the instruction).
    return (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request. "
        f"\n\n### Instruction:\n{instruction} "
        "\n\n### Response:\n"
    )
# Reuse the end-of-sequence token as the padding token (both the token string
# and its id are set explicitly).
# NOTE(review): presumably needed because the base tokenizer ships without a
# dedicated pad token — confirm against the checkpoint's tokenizer config.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
def instruct(instruction, input='', temperature=0.7, top_p=0.95, top_k=4, max_new_tokens=128, do_sample=False, penalty_alpha=0.6, repetition_penalty=1., stop="\n\n"):
    """Generate a response for an instruction and return it with the prompt used.

    Args:
        instruction: Task description passed to `prompt`.
        input: Optional extra context passed to `prompt`.
        temperature: Sampling temperature; only used when sampling. A value
            that is not strictly positive falls back to non-sampling decoding
            instead of being forwarded to `generate`.
        top_p: Nucleus-sampling threshold; only used when sampling.
        top_k: Top-k cutoff; only used when sampling.
        max_new_tokens: Upper bound on the number of generated tokens.
        do_sample: Whether to sample instead of decoding deterministically.
        penalty_alpha: Accepted so the signature matches the UI inputs, but it
            is not forwarded to `generate` (forwarding it would change the
            decoding strategy).
        repetition_penalty: Repetition penalty forwarded to `generate`.
        stop: Text at which the response is cut off. An empty (or `None`)
            value disables truncation.

    Returns:
        A `(response, prompt_text)` tuple: the cleaned model response and the
        exact prompt produced by `prompt(instruction, input)`.
    """
    response_marker = "### Response:"

    # Build the prompt once; it is both tokenized and returned to the caller.
    full_prompt = prompt(instruction, input)

    # Follow the model's actual device: the model is only moved to the GPU
    # when CUDA is available, so a hard-coded `.cuda()` breaks CPU-only hosts.
    device = next(model.parameters()).device
    input_ids = tokenizer(full_prompt.strip(), return_tensors='pt').input_ids.to(device)

    generate_kwargs = dict(
        input_ids=input_ids,
        return_dict_in_generate=True,
        output_scores=True,
        # UI sliders may hand numbers over as floats; `generate` needs an int here.
        max_new_tokens=int(max_new_tokens),
        repetition_penalty=repetition_penalty,
    )

    # A non-positive temperature cannot be used for sampling, so treat it as a
    # request for deterministic decoding rather than letting `generate` fail.
    non_positive_temperature = temperature is not None and temperature <= 0
    sampling = bool(do_sample) and not non_positive_temperature
    generate_kwargs["do_sample"] = sampling
    if sampling:
        # The sampling knobs only matter (and are only forwarded) when sampling.
        generate_kwargs.update(
            temperature=temperature,
            top_p=top_p,
            top_k=int(top_k),
        )

    # Mixed precision is only meaningful on the GPU; keep it off elsewhere.
    with torch.autocast(device_type=device.type, enabled=device.type == "cuda"):
        outputs = model.generate(**generate_kwargs)

    # Decode only the newly generated tokens. Splitting the full decoded text
    # on the response marker picks the wrong segment whenever the user's own
    # text contains that marker.
    new_tokens = outputs.sequences[0][input_ids.shape[-1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)

    # As before, drop anything after a response marker emitted by the model itself.
    response = response.split(response_marker)[0].strip()

    # `str.split` rejects an empty separator, and `None` would split on
    # whitespace, so only truncate when a real stop string was given.
    if stop:
        response = response.split(stop)[0].strip()

    return response, full_prompt
import locale
# Replace `locale.getpreferredencoding` with a stub that always reports UTF-8.
# The patch is applied before gradio is imported.
# NOTE(review): presumably a workaround for hosts whose default locale is not
# UTF-8 — confirm it is still required in the deployment environment.
locale.getpreferredencoding = lambda: "UTF-8"
import gradio as gr
# --- Input widgets -----------------------------------------------------------
# Free-text fields for the optional context and the instruction itself.
input_text = gr.Textbox(label="Input")
instruction_text = gr.Textbox(label="Instruction")

# Generation controls exposed to the user.
temperature = gr.Slider(
    label="Temperature",
    minimum=0,
    maximum=1,
    value=0.7,
    step=0.05,
)
top_p = gr.Slider(
    label="Top-P",
    minimum=0,
    maximum=1,
    value=0.95,
    step=0.01,
)
top_k = gr.Slider(
    label="Top-K",
    minimum=0,
    maximum=128,
    value=40,
    step=1,
)
max_new_tokens = gr.Slider(
    label="Tokens",
    minimum=1,
    maximum=256,
    value=64,
)
do_sample = gr.Checkbox(label="Do Sample", value=True)
penalty_alpha = gr.Slider(
    minimum=0,
    maximum=1,
    value=0.5,
)
repetition_penalty = gr.Slider(
    minimum=1.,
    maximum=2.,
    value=1.,
    step=0.1,
)
stop = gr.Textbox(label="Stopping Criteria", value="")

# --- Output widgets ----------------------------------------------------------
output_prompt = gr.Textbox(label="Prompt")
output_text = gr.Textbox(label="Output")

# Markdown shown under the title of the demo page.
description = """
The [TohidA/InstructLlamaMONA-withMONAdataset](https://hf.co/TohidA/LlamaInstructMona). A Llama chat 7B model finetuned on an [instruction dataset](https://huggingface.co/mlabonne/llama-2-7b-miniguanaco), then finetuned with the RL/PPO using a [Reward model](https://huggingface.co/TohidA/MONAreward) which is a BERT classifier trained on [Monda dataset](https://huggingface.co/datasets/TohidA/MONA), with [low rank adaptation](https://arxiv.org/abs/2106.09685) for a single epoch.
"""

# Wire the widgets to `instruct` and start the app. The input list follows the
# positional parameter order of `instruct`; the two outputs receive the
# (response, prompt) pair it returns.
gr.Interface(
    fn=instruct,
    inputs=[
        instruction_text,
        input_text,
        temperature,
        top_p,
        top_k,
        max_new_tokens,
        do_sample,
        penalty_alpha,
        repetition_penalty,
        stop,
    ],
    outputs=[output_text, output_prompt],
    title="InstructLlamaMONA 7B Gradio Demo",
    description=description,
).launch(
    debug=True,
    share=True,
)