# Hugging Face Space (running on ZeroGPU)
import torch
import spaces
import gradio as gr
from threading import Thread
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TextIteratorStreamer,
    StoppingCriteria,
    StoppingCriteriaList
)
MODEL_ID = "FuseAI/FuseO1-DeepSeekR1-QwQ-SkyT1-32B-Preview"
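# Default system prompt: forces every answer through the tagged
# [Understand] → [Plan] → [Reason] → [Verify] → [Conclude] workflow.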
DEFAULT_SYSTEM_PROMPT = """
**Role:** You are an Expert Coding Assistant.
Your responses MUST follow this structured workflow:
```
[Understand]: Analyze the problem, identify constraints, and clarify objectives.
[Plan]: Outline a technical methodology with numbered steps (algorithms, tools, etc.).
[Reason]: Execute the plan using code snippets, equations, or logic flows.
[Verify]: Validate correctness via tests, edge cases, or formal proofs.
[Conclude]: Summarize results with key insights/recommendations.
```
**Rules:**
1. Use markdown code blocks for all code/equations (e.g., `python`, `javascript`, `latex`).
2. Prioritize computational thinking (e.g., "To solve X, we can model it as a graph problem because...").
3. Structure EVERY answer using the exact tags: [Understand], [Plan], [Reason], [Verify], [Conclude].
4. Never combine steps - keep sections distinct.
5. Use technical precision over verbose explanations.
**Example Output Format:**
[Understand]
- Key problem: "Develop a function to find prime numbers..."
- Constraints: O(n log n) time, memory < 500MB.
[Plan]
1. Implement Sieve of Eratosthenes
2. Optimize memory via bitwise array
3. Handle edge case: n < 2
[Reason]
```python
def count_primes(n: int) -> int:
    if n <= 2:
        return 0
    sieve = [True] * n
    # ... (full implementation)
```
[Verify]
Test Cases:
- n=10 → Primes [2,3,5,7] → Output 4 ✔️
- n=1 → Output 0 ✔️
- Benchmark: 1e6 in 0.8s ✅
[Conclude]
Solution achieves O(n log log n) time with bitwise compression. Recommended for large-scale prime detection.
Always use code to solve your problems.
"""
CSS = """
.gr-chatbot { min-height: 500px; border-radius: 15px; }
.special-tag { color: #2ecc71; font-weight: 600; }
footer { display: none !important; }
"""
class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        return input_ids[0][-1] == tokenizer.eos_token_id
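# Load tokenizer and model once. The 32B checkpoint is quantized to 4-bit NF4 via
# bitsandbytes so it fits in a single GPU's memory.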
def initialize_model():
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        device_map="cuda",
        quantization_config=quantization_config,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True
    )
    # device_map already places the quantized weights on the GPU; a 4-bit
    # bitsandbytes model must not be moved again with .to("cuda").
    return model, tokenizer
def format_response(text):
    # Wrap each workflow tag in a styled <strong> span for display in the chatbot
    for tag in ("[Understand]", "[Plan]", "[Reason]", "[Verify]", "[Conclude]"):
        text = text.replace(tag, f'\n<strong class="special-tag">{tag}</strong>\n')
    return text
@spaces.GPU  # request a GPU slot on ZeroGPU for the duration of this call
def generate_response(message, chat_history, system_prompt, temperature, max_tokens):
    # Create conversation history for the model
    conversation = [{"role": "system", "content": system_prompt}]
    for user_msg, bot_msg in chat_history:
        conversation.extend([
            {"role": "user", "content": user_msg},
            {"role": "assistant", "content": bot_msg}
        ])
    conversation.append({"role": "user", "content": message})
    # Tokenize input using the model's chat template
    input_ids = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(model.device)
    # Set up streaming (skip_prompt avoids echoing the prompt back into the chat)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=int(max_tokens),  # slider values arrive as floats
        temperature=temperature,
        do_sample=temperature > 0,  # sampling is only meaningful with a non-zero temperature
        stopping_criteria=StoppingCriteriaList([StopOnTokens()])
    )
    # Start generation in a background thread so tokens can be streamed
    Thread(target=model.generate, kwargs=generate_kwargs).start()
    # Initialize response buffer
    partial_message = ""
    new_history = chat_history + [(message, "")]
    # Stream the response, showing a cursor while generation is in progress
    for new_token in streamer:
        partial_message += new_token
        formatted = format_response(partial_message)
        new_history[-1] = (message, formatted + "▌")
        yield new_history
    # Final update without the cursor
    new_history[-1] = (message, format_response(partial_message))
    yield new_history
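# Instantiate the model and tokenizer once at startup so they are shared across requests.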
model, tokenizer = initialize_model()
with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    <h1 align="center">🧠 AI Reasoning Assistant</h1>
    <p align="center">Ask me hard questions</p>
    """)
    chatbot = gr.Chatbot(label="Conversation", elem_id="chatbot")
    msg = gr.Textbox(label="Your Question", placeholder="Type your question...")
    with gr.Accordion("⚙️ Settings", open=False):
        system_prompt = gr.TextArea(value=DEFAULT_SYSTEM_PROMPT, label="System Instructions")
        temperature = gr.Slider(0, 1, value=0.5, label="Creativity")
        max_tokens = gr.Slider(128, 4096, value=2048, label="Max Response Length")
    clear = gr.Button("Clear History")
    msg.submit(
        generate_response,
        [msg, chatbot, system_prompt, temperature, max_tokens],
        [chatbot],
        show_progress=True
    )
    clear.click(lambda: None, None, chatbot, queue=False)

if __name__ == "__main__":
    demo.queue().launch()