File size: 5,611 Bytes
b2c474d
e970aef
646a0c2
b2c474d
ce9b3a4
 
e970aef
ce9b3a4
b2c474d
e970aef
 
 
 
646a0c2
e970aef
646a0c2
e970aef
646a0c2
e970aef
 
 
 
 
 
646a0c2
e970aef
 
 
 
 
 
 
646a0c2
e970aef
b2c474d
ce9b3a4
b2c474d
ce9b3a4
 
 
b2c474d
 
ce9b3a4
 
 
b2c474d
e970aef
646a0c2
e970aef
 
 
 
 
646a0c2
 
e970aef
646a0c2
e970aef
646a0c2
 
 
 
 
 
 
 
e970aef
 
646a0c2
 
e970aef
646a0c2
e970aef
ce9b3a4
646a0c2
 
 
e970aef
646a0c2
 
 
e970aef
ce9b3a4
e970aef
 
 
ce9b3a4
646a0c2
e970aef
646a0c2
 
 
e970aef
646a0c2
 
ce9b3a4
e970aef
 
646a0c2
ce9b3a4
e970aef
ce9b3a4
 
e970aef
646a0c2
e970aef
 
 
 
 
 
646a0c2
 
ce9b3a4
 
e970aef
 
ce9b3a4
 
e970aef
ce9b3a4
 
e970aef
 
 
 
 
 
646a0c2
e970aef
 
646a0c2
e970aef
 
 
646a0c2
 
 
ce9b3a4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch

# --- Configuration ---------------------------------------------------------
# Hugging Face Hub repository id of the model to load.
MODEL_NAME = "RekaAI/reka-flash-3"
# Default number of new tokens generated per reply (used as the UI default
# and forwarded as max_new_tokens).
DEFAULT_MAX_LENGTH = 4096
# Default sampling temperature shown in the UI.
DEFAULT_TEMPERATURE = 0.7

# Default system prompt: asks the model to reason inside <thinking> tags and
# give its final answer after </thinking>, followed by a one-shot example.
SYSTEM_PROMPT = "\n".join(
    (
        "You are Reka Flash-3, a helpful AI assistant created by Reka AI.",
        "When responding, think step-by-step within <thinking> tags and conclude your answer after </thinking>.",
        "For example:",
        "User: What is 2+2?",
        "Assistant: <thinking>Let me calculate that. 2 plus 2 equals 4.</thinking> The answer is 4.",
    )
)

# Load the tokenizer and the model with 4-bit (NF4, double-quantized)
# bitsandbytes quantization at import time; any failure aborts startup.
# NOTE(review): bitsandbytes 4-bit loading has historically required a CUDA
# GPU — confirm the installed bitsandbytes build supports the target hardware.
try:
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=quantization_config,
        # "auto" lets accelerate place weights on the available devices (a GPU
        # first if one is present, otherwise CPU); it does not force the CPU.
        device_map="auto",
        torch_dtype=torch.float16,
    )
    # Reuse EOS as the pad token so tokenizer-side padding works.
    tokenizer.pad_token = tokenizer.eos_token
except Exception as e:
    # RuntimeError subclasses Exception, so existing handlers still match;
    # `from e` keeps the original loader traceback as the direct cause.
    raise RuntimeError(
        f"Failed to load model: {e}. Ensure access to {MODEL_NAME} and sufficient CPU memory."
    ) from e

def _history_to_turns(chat_history):
    """Pair prior chat turns up as ``(user_text, assistant_text)`` tuples.

    Accepts the Gradio ``type="messages"`` history (a list of
    ``{"role": ..., "content": ...}`` dicts) as well as the legacy tuple
    format (a list of ``[user, assistant]`` pairs). Turns lacking an
    assistant reply are skipped, since the prompt format needs both halves.
    """
    turns = []
    pending_user = None
    for entry in chat_history or ():
        if isinstance(entry, dict):
            role = entry.get("role")
            if role == "user":
                pending_user = entry.get("content", "")
            elif role == "assistant" and pending_user is not None:
                turns.append((pending_user, entry.get("content", "")))
                pending_user = None
        else:
            user_msg, assistant_msg = entry
            if assistant_msg is not None:
                turns.append((user_msg, assistant_msg))
    return turns


def _build_prompt(message, chat_history, system_prompt):
    """Render the ``human: ... <sep> assistant: ...`` multi-round prompt.

    Prior turns are inserted between the system prompt and the new message.
    The prompt ends with an opened ``<thinking>`` tag so generation starts
    inside the reasoning section.
    """
    history_str = "".join(
        f"human: {user_msg} <sep> assistant: {assistant_msg} <sep> "
        for user_msg, assistant_msg in _history_to_turns(chat_history)
    )
    return (
        f"{system_prompt} <sep> {history_str}"
        f"human: {message} <sep> assistant: <thinking>\n"
    )


def _split_reasoning(response):
    """Split a reply into ``(reasoning, final_answer)`` at the first ``</thinking>``.

    Without a closing tag the whole reply is the final answer and the
    reasoning is empty.
    """
    reasoning, closing_tag, final_answer = response.partition("</thinking>")
    if not closing_tag:
        return "", response
    return reasoning.replace("<thinking>", "").strip(), final_answer.strip()


def generate_response(
    message,
    chat_history,
    system_prompt,
    max_length,
    temperature,
    top_p,
    top_k,
    repetition_penalty,
    show_reasoning
):
    """Generate a response from Reka Flash-3 with reasoning tags.

    Args:
        message: The new user message.
        chat_history: Gradio ``type="messages"`` history (list of role/content
            dicts); the legacy list-of-pairs format is also accepted. The new
            user/assistant turn is appended to it in place.
        system_prompt: Text placed at the start of the prompt.
        max_length: Maximum number of new tokens to generate.
        temperature, top_p, top_k, repetition_penalty: Sampling parameters
            forwarded to ``model.generate``.
        show_reasoning: When truthy, the extracted ``<thinking>`` text is
            returned for display.

    Returns:
        ``("", chat_history, reasoning_display)``; the empty string clears the
        input box. On any failure a Gradio warning is shown and the error text
        is returned in place of the reasoning.
    """
    if chat_history is None:
        chat_history = []
    try:
        prompt = _build_prompt(message, chat_history, system_prompt)

        # Send inputs to wherever the model lives; with device_map="auto" that
        # is not guaranteed to be the CPU.
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        prompt_len = inputs["input_ids"].shape[-1]

        outputs = model.generate(
            **inputs,
            # Slider values may arrive as floats; token counts and top-k must be ints.
            max_new_tokens=int(max_length),
            temperature=float(temperature),
            top_p=float(top_p),
            top_k=int(top_k),
            repetition_penalty=float(repetition_penalty),
            do_sample=True,
            eos_token_id=tokenizer.convert_tokens_to_ids("<sep>"),  # Stop at <sep>
            pad_token_id=tokenizer.eos_token_id,
        )

        # Decode only the newly generated tokens: slicing the decoded text by
        # len(prompt) breaks whenever decoding does not reproduce the prompt
        # exactly (skipped special tokens, whitespace normalisation, ...).
        new_tokens = outputs[0][prompt_len:]
        response = tokenizer.decode(new_tokens, skip_special_tokens=True)
        response = response.split("<sep>")[0].strip()

        reasoning, final_answer = _split_reasoning(response)

        # Only the final answer is kept in history (reasoning dropped to save tokens).
        chat_history.append({"role": "user", "content": message})
        chat_history.append({"role": "assistant", "content": final_answer})

        reasoning_display = f"**Reasoning:**\n{reasoning}" if show_reasoning and reasoning else ""
        return "", chat_history, reasoning_display

    except Exception as e:
        # UI boundary: surface the error to the user instead of crashing the app.
        error_msg = f"Error: {str(e)}"
        gr.Warning(error_msg)
        return "", chat_history, error_msg

# Gradio Interface: chat window, reasoning pane, sampling options, and launch.
with gr.Blocks(title="Reka Flash-3 Chat", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # Reka Flash-3 Chat Interface
    *Powered by [Reka AI](https://www.reka.ai/)* - A 21B parameter reasoning model optimized for CPU.
    """)

    # NOTE(review): Hugging Face "ZeroGPU" Spaces allocate GPUs on demand, so
    # the "zero-GPU (CPU-only)" wording below looks inaccurate — confirm and
    # reword the deployment text.
    with gr.Accordion("Deployment Instructions", open=True):
        gr.Textbox(
            value="""To deploy on Hugging Face Spaces:
1. Request access to RekaAI/reka-flash-3 from Reka AI.
2. Use a Pro subscription with zero-GPU (CPU-only) hardware.
3. Ensure 32GB+ CPU memory for 4-bit quantization.
4. Install dependencies: gradio, transformers, torch, bitsandbytes.""",
            label="How to Deploy",
            interactive=False
        )

    with gr.Row():
        # type="messages": history is a list of {"role", "content"} dicts.
        chatbot = gr.Chatbot(type="messages", height=400, label="Conversation")
        reasoning_display = gr.Textbox(label="Model Reasoning", interactive=False, lines=8)

    with gr.Row():
        message = gr.Textbox(label="Your Message", placeholder="Ask me anything...", lines=2)
        submit_btn = gr.Button("Send", variant="primary")

    with gr.Accordion("Options", open=True):
        # The upper bound must admit DEFAULT_MAX_LENGTH; a fixed 512 left the
        # initial value (4096) outside the slider's range. 512 stays the floor.
        max_length = gr.Slider(
            128,
            max(512, DEFAULT_MAX_LENGTH),
            value=DEFAULT_MAX_LENGTH,
            label="Max Length",
            step=64,
        )
        temperature = gr.Slider(0.1, 2.0, value=DEFAULT_TEMPERATURE, label="Temperature", step=0.1)
        top_p = gr.Slider(0.0, 1.0, value=0.95, label="Top-p", step=0.05)
        top_k = gr.Slider(1, 100, value=50, label="Top-k", step=1)
        repetition_penalty = gr.Slider(0.1, 2.0, value=1.1, label="Repetition Penalty", step=0.1)

    system_prompt = gr.Textbox(label="System Prompt", value=SYSTEM_PROMPT, lines=4)
    show_reasoning = gr.Checkbox(label="Show Reasoning", value=True)

    # Event handling: the button and pressing Enter share one handler. The
    # component order must match generate_response's parameters/return tuple.
    inputs = [message, chatbot, system_prompt, max_length, temperature, top_p, top_k, repetition_penalty, show_reasoning]
    outputs = [message, chatbot, reasoning_display]
    submit_btn.click(generate_response, inputs=inputs, outputs=outputs)
    message.submit(generate_response, inputs=inputs, outputs=outputs)

demo.launch(debug=True)