ZoroaStrella committed on
Commit
e970aef
·
1 Parent(s): 646a0c2

Add transformers dependency and correct errors

Browse files
Files changed (2) hide show
  1. app.py +65 -109
  2. requirements.txt +3 -1
app.py CHANGED
@@ -1,30 +1,37 @@
1
  import gradio as gr
2
- from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
4
 
5
  # Configuration
6
  MODEL_NAME = "RekaAI/reka-flash-3"
7
- DEFAULT_MAX_LENGTH = 1024
8
  DEFAULT_TEMPERATURE = 0.7
9
 
10
- # System prompt with instructions for reasoning
11
- SYSTEM_PROMPT = """You are Reka Flash-3, a helpful AI assistant created by Reka AI.
12
- Provide detailed, helpful answers while maintaining safety.
13
- Format responses clearly using markdown when appropriate.
14
- When asked a question, think step by step inside <thinking> tags, then provide your final answer after </thinking> tags. For example:
15
-
16
  User: What is 2+2?
17
- Assistant: <thinking>
18
- Let me calculate that. 2 plus 2 equals 4.
19
- </thinking>
20
- The answer is 4."""
21
 
22
- # Load model and tokenizer (assuming CPU-only for zero GPU)
23
  try:
 
 
 
 
 
 
24
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
25
- model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, device_map="cpu", torch_dtype=torch.float32)
 
 
 
 
 
 
26
  except Exception as e:
27
- raise Exception(f"Failed to load model: {str(e)}. Ensure you have access to {MODEL_NAME} and sufficient CPU memory.")
28
 
29
  def generate_response(
30
  message,
@@ -35,21 +42,20 @@ def generate_response(
35
  top_p,
36
  top_k,
37
  repetition_penalty,
38
- presence_penalty,
39
- frequency_penalty,
40
  show_reasoning
41
  ):
42
- """
43
- Generate a response from Reka Flash-3, parsing reasoning and final answer.
44
- """
45
  try:
46
- # Format the prompt with thinking tags
47
- formatted_prompt = f"{system_prompt}\n\nUser: {message}\n\nAssistant: <thinking>\n"
 
 
 
48
 
49
  # Tokenize input
50
- inputs = tokenizer(formatted_prompt, return_tensors="pt").to("cpu")
51
 
52
- # Generate response
53
  outputs = model.generate(
54
  **inputs,
55
  max_new_tokens=max_length,
@@ -57,127 +63,77 @@ def generate_response(
57
  top_p=top_p,
58
  top_k=top_k,
59
  repetition_penalty=repetition_penalty,
60
- presence_penalty=presence_penalty,
61
- frequency_penalty=frequency_penalty,
62
  do_sample=True,
63
- pad_token_id=tokenizer.eos_token_id if tokenizer.eos_token_id else 0
 
64
  )
65
 
66
- # Decode the generated text
67
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
68
- response = response[len(formatted_prompt):] # Remove the prompt from the output
69
 
70
  # Parse reasoning and final answer
71
  if "</thinking>" in response:
72
  reasoning, final_answer = response.split("</thinking>", 1)
73
- reasoning = reasoning.strip()
74
  final_answer = final_answer.strip()
75
  else:
76
  reasoning = ""
77
- final_answer = response.strip()
78
 
79
- # Update chat history with final answer
80
- chat_history.append((message, final_answer))
 
81
 
82
  # Display reasoning if requested
83
- reasoning_display = reasoning if show_reasoning and reasoning else ""
84
- if reasoning_display:
85
- reasoning_display = f"**Reasoning:**\n{reasoning_display}"
86
-
87
  return "", chat_history, reasoning_display
88
 
89
  except Exception as e:
90
- error_msg = f"Error generating response: {str(e)}"
91
  gr.Warning(error_msg)
92
  return "", chat_history, error_msg
93
 
94
- # UI Components
95
- with gr.Blocks(title="Reka Flash-3 Chat Demo", theme=gr.themes.Soft()) as demo:
96
- # Header Section
97
  gr.Markdown("""
98
  # Reka Flash-3 Chat Interface
99
- *Powered by [Reka Core AI](https://www.reka.ai/)*
100
  """)
101
 
102
- # Deployment Notice
103
- with gr.Accordion("Important Deployment Notice", open=True):
104
  gr.Textbox(
105
- value="""To deploy this model on Hugging Face Spaces:
106
- 1. Request the Reka Flash-3 OSS model from Reka AI (https://www.reka.ai/).
107
- 2. Use a Hugging Face Pro subscription for deployment.
108
- 3. Configure your Space with zero GPU (CPU-only) hardware.
109
- 4. Ensure sufficient CPU memory for the 3B parameter model.""",
110
- label="Deployment Instructions",
111
- lines=5,
112
  interactive=False
113
  )
114
 
115
- # Chat Interface
116
  with gr.Row():
117
- chatbot = gr.Chatbot(height=500, label="Conversation")
118
- reasoning_display = gr.Textbox(
119
- label="Model Reasoning",
120
- interactive=False,
121
- visible=True,
122
- lines=10,
123
- max_lines=20
124
- )
125
 
126
- # Input Section
127
  with gr.Row():
128
- message = gr.Textbox(
129
- label="Your Message",
130
- placeholder="Type your message here...",
131
- lines=3,
132
- max_lines=6
133
- )
134
  submit_btn = gr.Button("Send", variant="primary")
135
 
136
- # Normal Options
137
- with gr.Accordion("Normal Options", open=True):
138
- with gr.Row():
139
- max_length = gr.Slider(128, 4096, value=DEFAULT_MAX_LENGTH, label="Max Length", step=128)
140
- temperature = gr.Slider(0.1, 2.0, value=DEFAULT_TEMPERATURE, label="Temperature", step=0.1)
141
-
142
- # Advanced Options
143
- with gr.Accordion("Advanced Options", open=False):
144
- with gr.Row():
145
- top_p = gr.Slider(0.0, 1.0, value=0.95, label="Top-p", step=0.05)
146
- top_k = gr.Slider(1, 100, value=50, label="Top-k", step=1)
147
- repetition_penalty = gr.Slider(0.1, 2.0, value=1.1, label="Repetition Penalty", step=0.1)
148
- with gr.Row():
149
- presence_penalty = gr.Slider(-2.0, 2.0, value=0.0, label="Presence Penalty", step=0.1)
150
- frequency_penalty = gr.Slider(-2.0, 2.0, value=0.0, label="Frequency Penalty", step=0.1)
151
-
152
- # System Prompt
153
- system_prompt = gr.Textbox(
154
- label="System Prompt",
155
- value=SYSTEM_PROMPT,
156
- lines=5,
157
- max_lines=10
158
- )
159
 
160
- # Debug Options
161
- show_reasoning = gr.Checkbox(label="Show Model Reasoning", value=True)
162
-
163
- # Event Handling
164
- inputs = [
165
- message,
166
- chatbot,
167
- system_prompt,
168
- max_length,
169
- temperature,
170
- top_p,
171
- top_k,
172
- repetition_penalty,
173
- presence_penalty,
174
- frequency_penalty,
175
- show_reasoning
176
- ]
177
- outputs = [message, chatbot, reasoning_display]
178
 
 
 
 
179
  submit_btn.click(generate_response, inputs=inputs, outputs=outputs)
180
  message.submit(generate_response, inputs=inputs, outputs=outputs)
181
 
182
- # Launch the interface
183
  demo.launch(debug=True)
 
1
  import gradio as gr
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
3
  import torch
4
 
5
  # Configuration
6
  MODEL_NAME = "RekaAI/reka-flash-3"
7
+ DEFAULT_MAX_LENGTH = 4096 # Default max new tokens (raised from 1024; NOTE: exceeds the 512 upper bound of the Max Length slider)
8
  DEFAULT_TEMPERATURE = 0.7
9
 
10
+ # System prompt with reasoning instructions
11
+ SYSTEM_PROMPT = """You are Reka Flash-3, a helpful AI assistant created by Reka AI.
12
+ When responding, think step-by-step within <thinking> tags and conclude your answer after </thinking>.
13
+ For example:
 
 
14
  User: What is 2+2?
15
+ Assistant: <thinking>Let me calculate that. 2 plus 2 equals 4.</thinking> The answer is 4."""
 
 
 
16
 
17
+ # Load model and tokenizer with 4-bit quantization
18
  try:
19
+ quantization_config = BitsAndBytesConfig(
20
+ load_in_4bit=True,
21
+ bnb_4bit_compute_dtype=torch.float16,
22
+ bnb_4bit_use_double_quant=True,
23
+ bnb_4bit_quant_type="nf4"
24
+ )
25
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
26
+ model = AutoModelForCausalLM.from_pretrained(
27
+ MODEL_NAME,
28
+ quantization_config=quantization_config,
29
+ device_map="auto", # Automatic device placement; presumably resolves to CPU only when no accelerator is visible
30
+ torch_dtype=torch.float16
31
+ )
32
+ tokenizer.pad_token = tokenizer.eos_token # Ensure padding works
33
  except Exception as e:
34
+ raise Exception(f"Failed to load model: {str(e)}. Ensure access to {MODEL_NAME} and sufficient CPU memory.")
35
 
36
  def generate_response(
37
  message,
 
42
  top_p,
43
  top_k,
44
  repetition_penalty,
 
 
45
  show_reasoning
46
  ):
47
+ """Generate a response from Reka Flash-3 with reasoning tags."""
 
 
48
  try:
49
+ # Format chat history and prompt (NOTE: history_str is built here but is not included in the prompt below)
50
+ history_str = ""
51
+ for user_msg, assistant_msg in chat_history:
52
+ history_str += f"human: {user_msg} <sep> assistant: {assistant_msg} <sep> "
53
+ prompt = f"{system_prompt} <sep> human: {message} <sep> assistant: <thinking>\n"
54
 
55
  # Tokenize input
56
+ inputs = tokenizer(prompt, return_tensors="pt").to("cpu")
57
 
58
+ # Generate response by sampling; stops at <sep> or after max_new_tokens
59
  outputs = model.generate(
60
  **inputs,
61
  max_new_tokens=max_length,
 
63
  top_p=top_p,
64
  top_k=top_k,
65
  repetition_penalty=repetition_penalty,
 
 
66
  do_sample=True,
67
+ eos_token_id=tokenizer.convert_tokens_to_ids("<sep>"), # Stop at <sep>
68
+ pad_token_id=tokenizer.eos_token_id
69
  )
70
 
71
+ # Decode and clean response
72
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
73
+ response = response[len(prompt):].split("<sep>")[0].strip() # Extract assistant response
74
 
75
  # Parse reasoning and final answer
76
  if "</thinking>" in response:
77
  reasoning, final_answer = response.split("</thinking>", 1)
78
+ reasoning = reasoning.replace("<thinking>", "").strip()
79
  final_answer = final_answer.strip()
80
  else:
81
  reasoning = ""
82
+ final_answer = response
83
 
84
+ # Update chat history (drop reasoning to save tokens)
85
+ chat_history.append({"role": "user", "content": message})
86
+ chat_history.append({"role": "assistant", "content": final_answer})
87
 
88
  # Display reasoning if requested
89
+ reasoning_display = f"**Reasoning:**\n{reasoning}" if show_reasoning and reasoning else ""
 
 
 
90
  return "", chat_history, reasoning_display
91
 
92
  except Exception as e:
93
+ error_msg = f"Error: {str(e)}"
94
  gr.Warning(error_msg)
95
  return "", chat_history, error_msg
96
 
97
+ # Gradio Interface
98
+ with gr.Blocks(title="Reka Flash-3 Chat", theme=gr.themes.Soft()) as demo:
 
99
  gr.Markdown("""
100
  # Reka Flash-3 Chat Interface
101
+ *Powered by [Reka AI](https://www.reka.ai/)* - A 21B parameter reasoning model optimized for CPU.
102
  """)
103
 
104
+ with gr.Accordion("Deployment Instructions", open=True):
 
105
  gr.Textbox(
106
+ value="""To deploy on Hugging Face Spaces:
107
+ 1. Request access to RekaAI/reka-flash-3 from Reka AI.
108
+ 2. Use a Pro subscription with zero-GPU (CPU-only) hardware.
109
+ 3. Ensure 32GB+ CPU memory for 4-bit quantization.
110
+ 4. Install dependencies: gradio, transformers, torch, bitsandbytes.""",
111
+ label="How to Deploy",
 
112
  interactive=False
113
  )
114
 
 
115
  with gr.Row():
116
+ chatbot = gr.Chatbot(type="messages", height=400, label="Conversation")
117
+ reasoning_display = gr.Textbox(label="Model Reasoning", interactive=False, lines=8)
 
 
 
 
 
 
118
 
 
119
  with gr.Row():
120
+ message = gr.Textbox(label="Your Message", placeholder="Ask me anything...", lines=2)
 
 
 
 
 
121
  submit_btn = gr.Button("Send", variant="primary")
122
 
123
+ with gr.Accordion("Options", open=True):
124
+ max_length = gr.Slider(128, 512, value=DEFAULT_MAX_LENGTH, label="Max Length", step=64)
125
+ temperature = gr.Slider(0.1, 2.0, value=DEFAULT_TEMPERATURE, label="Temperature", step=0.1)
126
+ top_p = gr.Slider(0.0, 1.0, value=0.95, label="Top-p", step=0.05)
127
+ top_k = gr.Slider(1, 100, value=50, label="Top-k", step=1)
128
+ repetition_penalty = gr.Slider(0.1, 2.0, value=1.1, label="Repetition Penalty", step=0.1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
+ system_prompt = gr.Textbox(label="System Prompt", value=SYSTEM_PROMPT, lines=4)
131
+ show_reasoning = gr.Checkbox(label="Show Reasoning", value=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
+ # Event handling
134
+ inputs = [message, chatbot, system_prompt, max_length, temperature, top_p, top_k, repetition_penalty, show_reasoning]
135
+ outputs = [message, chatbot, reasoning_display]
136
  submit_btn.click(generate_response, inputs=inputs, outputs=outputs)
137
  message.submit(generate_response, inputs=inputs, outputs=outputs)
138
 
 
139
  demo.launch(debug=True)
requirements.txt CHANGED
@@ -1,3 +1,5 @@
1
  gradio>=3.50
2
  huggingface_hub==0.25.2
3
- torch
 
 
 
1
  gradio>=3.50
2
  huggingface_hub==0.25.2
3
+ torch
4
+ transformers
5
+ bitsandbytes