Spaces:
Sleeping
Sleeping
import gradio as gr | |
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig | |
import torch | |
# Configuration
# Hugging Face Hub repository of the checkpoint that is loaded below.
MODEL_NAME: str = "RekaAI/reka-flash-3"
# Text placed ahead of the conversation in every prompt (editable in the UI).
SYSTEM_PROMPT: str = """You are Reka Flash-3, a helpful AI assistant created by Reka AI."""

# Initial values for the generation controls in the "Options" accordion.
DEFAULT_MAX_LENGTH: int = 256  # forwarded to generate() as max_new_tokens
DEFAULT_TEMPERATURE: float = 0.7
# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

# 4-bit NF4 weights with nested ("double") quantization; compute in float16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)

# device_map="auto" leaves weight placement to accelerate; low_cpu_mem_usage
# avoids materialising a full extra copy of the weights in host RAM while loading.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    quantization_config=quantization_config,
)
def _build_prompt(message, chat_history, system_prompt):
    """Render the conversation in Reka's plain-text ``<sep>`` chat format.

    The result is the system prompt (omitted when blank), then each prior
    turn as ``human: ...`` / ``assistant: ...``, then the new user message,
    then an open ``assistant: `` turn for the model to complete, all joined
    by `` <sep> ``.
    """
    speakers = {"user": "human", "assistant": "assistant"}
    segments = [system_prompt] if system_prompt and system_prompt.strip() else []
    for entry in chat_history:
        speaker = speakers.get(entry.get("role"))
        content = entry.get("content")
        # Only plain-text user/assistant turns can be rendered; anything else
        # (e.g. file payloads) is skipped. Assumes Gradio hands history
        # entries over as dicts -- confirm for the installed Gradio version.
        if speaker is None or not isinstance(content, str):
            continue
        segments.append(f"{speaker}: {content}")
    segments.append(f"human: {message}")
    segments.append("assistant: ")
    return " <sep> ".join(segments)


def generate_response(message, chat_history, system_prompt, max_length, temperature, top_p, top_k, repetition_penalty):
    """Generate one assistant reply and append the exchange to the history.

    Args:
        message: The user's new message. Blank or whitespace-only input is
            ignored and no generation is run.
        chat_history: Gradio "messages"-format history (list of dicts with
            "role" and "content" keys), or None.
        system_prompt: Text placed before the conversation in the prompt.
        max_length: Maximum number of new tokens to generate.
        temperature, top_p, top_k, repetition_penalty: Sampling settings
            forwarded to ``model.generate``.

    Returns:
        A ``("", history)`` tuple: the empty string clears the input textbox
        and ``history`` is a new list with the user message and the
        assistant reply appended as {"role", "content"} dicts.
    """
    history = list(chat_history or [])
    if not message or not message.strip():
        return "", history

    prompt = _build_prompt(message, history, system_prompt)
    # With device_map="auto" and 4-bit weights the model lives on the
    # accelerator; the inputs must be on the same device as the model.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        # Gradio sliders may deliver floats; generate() validates top_k as a
        # positive int, so coerce the integer-valued settings explicitly.
        max_new_tokens=int(max_length),
        temperature=float(temperature),
        top_p=float(top_p),
        top_k=int(top_k),
        repetition_penalty=float(repetition_penalty),
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Decode only the newly generated tokens, so the echoed prompt (or a
    # "<sep>" typed by the user) cannot shift where the reply starts; then
    # stop at the first turn separator the model emits.
    prompt_length = inputs["input_ids"].shape[-1]
    generated = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    response = generated.partition("<sep>")[0].strip()

    history.append({"role": "user", "content": message})
    history.append({"role": "assistant", "content": response})
    return "", history
# Gradio Interface
# Builds the chat UI, wires both the Send button and Enter-in-textbox to
# generate_response, and starts the server.
with gr.Blocks(title="Reka Flash-3 Chat") as demo:
    gr.Markdown("# Reka Flash-3 Chat Interface")
    # type="messages": history is a list of {"role", "content"} dicts.
    chatbot = gr.Chatbot(type="messages", height=400, label="Conversation")
    with gr.Row():
        message = gr.Textbox(label="Your Message", placeholder="Ask me anything...")
        submit_btn = gr.Button("Send")
    with gr.Accordion("Options", open=False):
        # Integer-valued settings get step=1: without an explicit step Gradio
        # derives a fractional one from the range (0.1 for 1-100), and
        # transformers rejects a non-integer top_k.
        max_length = gr.Slider(128, 512, value=DEFAULT_MAX_LENGTH, step=1, label="Max New Tokens")
        temperature = gr.Slider(0.1, 2.0, value=DEFAULT_TEMPERATURE, label="Temperature")
        top_p = gr.Slider(0.0, 1.0, value=0.95, label="Top-p")
        top_k = gr.Slider(1, 100, value=50, step=1, label="Top-k")
        repetition_penalty = gr.Slider(0.1, 2.0, value=1.1, label="Repetition Penalty")
        system_prompt = gr.Textbox(label="System Prompt", value=SYSTEM_PROMPT, lines=4)
    # Order must match generate_response's positional parameters.
    inputs = [message, chatbot, system_prompt, max_length, temperature, top_p, top_k, repetition_penalty]
    # generate_response returns ("", history): clear the textbox, refresh the chat.
    outputs = [message, chatbot]
    submit_btn.click(generate_response, inputs=inputs, outputs=outputs)
    message.submit(generate_response, inputs=inputs, outputs=outputs)
demo.launch()