import gradio as gr
import spaces
import torch
import transformers

model_name = "ModularityAI/gemma-2b-datascience-it-raft"

# Load the fine-tuned Gemma 2B once at startup; bfloat16 weights halve GPU memory use.
pipeline = transformers.pipeline(
    "text-generation",
    model=model_name,
    model_kwargs={"torch_dtype": torch.bfloat16},
    device="cuda",
)

def format_test_question(q):
    # Wrap the question in the single-turn prompt format used during fine-tuning.
    return f"<bos><start_of_turn>user {q} <end_of_turn>model "
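
# For example, format_test_question("What is a p-value?") produces:
#   "<bos><start_of_turn>user What is a p-value? <end_of_turn>model "
# and the model is expected to continue the text with its answer.
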
@spaces.GPU
def chat_function(message, history, max_new_tokens, temperature):
    prompt = format_test_question(message)
    print(prompt)  # log the full prompt for debugging
    # Offset the slider value so temperature is never 0, which is
    # invalid when do_sample=True.
    temp = temperature + 0.1
    outputs = pipeline(
        prompt,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temp,
        top_p=0.9,
    )
    # Strip the prompt so only the newly generated reply is returned.
    return outputs[0]["generated_text"][len(prompt):]
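
# Called directly (outside the UI), a hypothetical invocation would be:
#   chat_function("Explain overfitting.", history=[], max_new_tokens=512, temperature=0.7)
# gr.ChatInterface supplies message and history; the two sliders below supply the rest.
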
gr.ChatInterface(
    chat_function,
    chatbot=gr.Chatbot(height=400),
    textbox=gr.Textbox(placeholder="Enter message here", container=False, scale=7),
    title="Gemma 2B Data Science QA RAFT Demo",
    description="""
    This Space is for chatting with Gemma 2B fine-tuned for data science QA using RAFT. Find the model here: https://huggingface.co/ModularityAI/gemma-2b-datascience-it-raft
    Feel free to adjust the generation settings under "Additional Inputs".
    Fine-tuning notebook: https://www.kaggle.com/code/hanzlajavaid/gemma-finetuning-raft-technique
    """,
    theme="Monochrome",
    additional_inputs=[
        gr.Slider(512, 4096, label="Max New Tokens"),
        gr.Slider(0, 1, label="Temperature"),
    ],
).launch()
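
# To run locally (a sketch; assumes a CUDA GPU and that this file is saved
# as app.py, the Hugging Face Spaces default):
#   pip install gradio spaces torch transformers accelerate
#   python app.py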