import torch
import spaces
import gradio as gr
from threading import Thread
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TextIteratorStreamer,
    StoppingCriteria,
    StoppingCriteriaList
)
MODEL_ID = "FuseAI/FuseO1-DeepSeekR1-QwQ-SkyT1-32B-Preview"
DEFAULT_SYSTEM_PROMPT = """
**Role:** You are an Expert Coding Assistant.
Your responses MUST follow this structured workflow:
```
[Understand]: Analyze the problem, identify constraints, and clarify objectives.
[Plan]: Outline a technical methodology with numbered steps (algorithms, tools, etc.).
[Reason]: Execute the plan using code snippets, equations, or logic flows.
[Verify]: Validate correctness via tests, edge cases, or formal proofs.
[Conclude]: Summarize results with key insights/recommendations.
```
**Rules:**
1. Use markdown code blocks for all code/equations (e.g., `python`, `javascript`, `latex`).
2. Prioritize computational thinking (e.g., "To solve X, we can model it as a graph problem because...").
3. Structure EVERY answer using the exact tags: [Understand], [Plan], [Reason], [Verify], [Conclude].
4. Never combine steps - keep sections distinct.
5. Use technical precision over verbose explanations.
**Example Output Format:**
[Understand]
- Key problem: "Develop a function to find prime numbers..."
- Constraints: O(n log n) time, memory < 500MB.
[Plan]
1. Implement Sieve of Eratosthenes
2. Optimize memory via bitwise array
3. Handle edge case: n < 2
[Reason]
```python
def count_primes(n: int) -> int:
    if n <= 2:
        return 0
    sieve = [True] * n
    # ... (full implementation)
```
[Verify]
Test Cases:
- n=10 → Primes [2,3,5,7] → Output 4 ✔️
- n=1 → Output 0 ✔️
- Benchmark: 1e6 in 0.8s ✅
[Conclude]
Solution achieves O(n log log n) time with bitwise compression. Recommended for large-scale prime detection.

Always use code to solve your problems.
"""
CSS = """
.gr-chatbot { min-height: 500px; border-radius: 15px; }
.special-tag { color: #2ecc71; font-weight: 600; }
footer { display: none !important; }
"""
class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Stop as soon as the most recently generated token is the EOS token.
        return input_ids[0][-1] == tokenizer.eos_token_id
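
# Load the tokenizer and the model with 4-bit NF4 quantization (bitsandbytes)
# so the 32B checkpoint fits on a single GPU; compute runs in bfloat16.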
def initialize_model():
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )

    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token

    # device_map already places the quantized weights on the GPU; calling
    # .to("cuda") on a 4-bit bitsandbytes model is not supported.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        device_map="cuda",
        quantization_config=quantization_config,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True
    )

    return model, tokenizer
def format_response(text):
    return text.replace("[Understand]", '\n<strong class="special-tag">[Understand]</strong>\n') \
               .replace("[Plan]", '\n<strong class="special-tag">[Plan]</strong>\n') \
               .replace("[Conclude]", '\n<strong class="special-tag">[Conclude]</strong>\n') \
               .replace("[Reason]", '\n<strong class="special-tag">[Reason]</strong>\n') \
               .replace("[Verify]", '\n<strong class="special-tag">[Verify]</strong>\n')
@spaces.GPU(duration=360)
def generate_response(message, chat_history, system_prompt, temperature, max_tokens):
    # Create conversation history for the model
    conversation = [{"role": "system", "content": system_prompt}]
    for user_msg, bot_msg in chat_history:
        conversation.extend([
            {"role": "user", "content": user_msg},
            {"role": "assistant", "content": bot_msg}
        ])
    conversation.append({"role": "user", "content": message})

    # Tokenize input
    input_ids = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(model.device)

    # Set up streaming; skip_prompt keeps the echoed prompt out of the chat window
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_tokens,
        temperature=temperature,
        do_sample=True,  # temperature only takes effect when sampling is enabled
        stopping_criteria=StoppingCriteriaList([StopOnTokens()])
    )

    # Start generation thread
    Thread(target=model.generate, kwargs=generate_kwargs).start()

    # Initialize response buffer
    partial_message = ""
    new_history = chat_history + [(message, "")]

    # Stream response, appending a cursor character while tokens arrive
    for new_token in streamer:
        partial_message += new_token
        formatted = format_response(partial_message)
        new_history[-1] = (message, formatted + "▌")
        yield new_history

    # Final update without cursor
    new_history[-1] = (message, format_response(partial_message))
    yield new_history
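
# Load the model and tokenizer once at startup so the Gradio callback can reuse them.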
model, tokenizer = initialize_model()
with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    <h1 align="center">🧠 AI Reasoning Assistant</h1>
    <p align="center">Ask me hard questions</p>
    """)

    chatbot = gr.Chatbot(label="Conversation", elem_id="chatbot")
    msg = gr.Textbox(label="Your Question", placeholder="Type your question...")

    with gr.Accordion("⚙️ Settings", open=False):
        system_prompt = gr.TextArea(value=DEFAULT_SYSTEM_PROMPT, label="System Instructions")
        temperature = gr.Slider(0, 1, value=0.5, label="Creativity")
        max_tokens = gr.Slider(128, 4096, value=2048, label="Max Response Length")

    clear = gr.Button("Clear History")

    msg.submit(
        generate_response,
        [msg, chatbot, system_prompt, temperature, max_tokens],
        [chatbot],
        show_progress=True
    )
    clear.click(lambda: None, None, chatbot, queue=False)
if __name__ == "__main__":
demo.queue().launch()