import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch

# Configuration
MODEL_NAME = "RekaAI/reka-flash-3"
DEFAULT_MAX_LENGTH = 512  # Reduced for CPU efficiency
DEFAULT_TEMPERATURE = 0.7

# System prompt with reasoning instructions
SYSTEM_PROMPT = """You are Reka Flash-3, a helpful AI assistant created by Reka AI. When responding, think step-by-step within <reasoning> tags and conclude your answer after </reasoning>. For example:
User: What is 2+2?
Assistant: <reasoning> Let me calculate that. 2 plus 2 equals 4. </reasoning> The answer is 4."""

# Load model and tokenizer with 4-bit quantization
try:
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4"
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=quantization_config,
        device_map="auto",  # Maps to CPU when no GPU is available
        torch_dtype=torch.float16
    )
    tokenizer.pad_token = tokenizer.eos_token  # Ensure padding works
except Exception as e:
    raise RuntimeError(
        f"Failed to load model: {e}. Ensure access to {MODEL_NAME} and sufficient CPU memory."
    ) from e


def generate_response(
    message,
    chat_history,
    system_prompt,
    max_length,
    temperature,
    top_p,
    top_k,
    repetition_penalty,
    show_reasoning
):
    """Generate a response from Reka Flash-3 with reasoning tags."""
    try:
        # Format chat history and prompt (multi-round conversation).
        # chat_history uses the "messages" format: a list of
        # {"role": ..., "content": ...} dicts.
        history_str = ""
        for msg in chat_history:
            role = "human" if msg["role"] == "user" else "assistant"
            history_str += f"{role}: {msg['content']} <sep> "
        prompt = f"{system_prompt} {history_str}human: {message} <sep> assistant:"

        # Tokenize input and move it to the model's device
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        # Generate response, treating <sep> as the end-of-turn marker
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_length,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            eos_token_id=tokenizer.convert_tokens_to_ids("<sep>"),  # Stop at <sep>
            pad_token_id=tokenizer.eos_token_id
        )

        # Decode only the newly generated tokens (slicing the decoded string by
        # len(prompt) is unreliable, since tokenization does not round-trip exactly)
        new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
        response = tokenizer.decode(new_tokens, skip_special_tokens=True)
        response = response.split("<sep>")[0].strip()  # Extract assistant response

        # Parse reasoning and final answer
        if "</reasoning>" in response:
            reasoning, final_answer = response.split("</reasoning>", 1)
            reasoning = reasoning.replace("<reasoning>", "").strip()
            final_answer = final_answer.strip()
        else:
            reasoning = ""
            final_answer = response

        # Update chat history (drop reasoning to save tokens)
        chat_history.append({"role": "user", "content": message})
        chat_history.append({"role": "assistant", "content": final_answer})

        # Display reasoning if requested
        reasoning_display = f"**Reasoning:**\n{reasoning}" if show_reasoning and reasoning else ""

        return "", chat_history, reasoning_display
    except Exception as e:
        error_msg = f"Error: {str(e)}"
        gr.Warning(error_msg)
        return "", chat_history, error_msg


# Gradio Interface
with gr.Blocks(title="Reka Flash-3 Chat", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # Reka Flash-3 Chat Interface
    *Powered by [Reka AI](https://www.reka.ai/)* - A 21B parameter reasoning model optimized for CPU.
    """)

    with gr.Accordion("Deployment Instructions", open=True):
        gr.Textbox(
            value="""To deploy on Hugging Face Spaces:
1. Request access to RekaAI/reka-flash-3 from Reka AI.
2. Use a Pro subscription with CPU-only hardware.
3. Ensure 32GB+ CPU memory for 4-bit quantization.
4. Install dependencies: gradio, transformers, torch, bitsandbytes.""",
            label="How to Deploy",
            interactive=False
        )

    with gr.Row():
        chatbot = gr.Chatbot(type="messages", height=400, label="Conversation")
        reasoning_display = gr.Textbox(label="Model Reasoning", interactive=False, lines=8)

    with gr.Row():
        message = gr.Textbox(label="Your Message", placeholder="Ask me anything...", lines=2)
        submit_btn = gr.Button("Send", variant="primary")

    with gr.Accordion("Options", open=True):
        max_length = gr.Slider(128, 512, value=DEFAULT_MAX_LENGTH, label="Max Length", step=64)
        temperature = gr.Slider(0.1, 2.0, value=DEFAULT_TEMPERATURE, label="Temperature", step=0.1)
        top_p = gr.Slider(0.0, 1.0, value=0.95, label="Top-p", step=0.05)
        top_k = gr.Slider(1, 100, value=50, label="Top-k", step=1)
        repetition_penalty = gr.Slider(0.1, 2.0, value=1.1, label="Repetition Penalty", step=0.1)
        system_prompt = gr.Textbox(label="System Prompt", value=SYSTEM_PROMPT, lines=4)
        show_reasoning = gr.Checkbox(label="Show Reasoning", value=True)

    # Event handling
    inputs = [message, chatbot, system_prompt, max_length, temperature, top_p,
              top_k, repetition_penalty, show_reasoning]
    outputs = [message, chatbot, reasoning_display]
    submit_btn.click(generate_response, inputs=inputs, outputs=outputs)
    message.submit(generate_response, inputs=inputs, outputs=outputs)

demo.launch(debug=True)
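
# A minimal requirements.txt sketch for the Space described in the deployment
# instructions above. The version pins are illustrative assumptions, not
# verified against a specific Reka release; accelerate is listed because
# transformers requires it for device_map="auto":
#
#     gradio>=4.0
#     transformers>=4.38
#     torch
#     bitsandbytes
#     accelerate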