Smart_LLM / app.py
Daemontatox's picture
Update app.py
86de665 verified
raw
history blame
6.7 kB
import torch
import spaces
import gradio as gr
from threading import Thread
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
TextIteratorStreamer,
StoppingCriteria,
StoppingCriteriaList
)
# Hugging Face Hub repository of the model served by this Space.
MODEL_ID = "Daemontatox/Cogito-R1"

# Default system prompt pre-filled in the "System Instructions" box.
# It asks the model to reason inside <think>...</think> and to answer
# inside <you>...</you>.
DEFAULT_SYSTEM_PROMPT = """
You are Cogito-R1, an AI engineered for rigorous, long, transparent reasoning.
Your responses must **strictly follow this protocol:**
1. **THINK FIRST:**
- Begin every interaction by generating a raw, unfiltered internal monologue.
- Enclose this step-by-step reasoning process—including doubts, methodical evaluations, and logical pivots—between `<think>` and `</think>` tags.
- Example: `<think>Analyzing query... Is the user asking for X or Y? Cross-checking definitions... Prioritizing accuracy...</think>`
2. **ANSWER AFTER:**
- Only after completing the `<think>` block, deliver a concise, precise answer enclosed between `<you>` and `</you>` tags.
- This answer must directly reflect conclusions from your reasoning phase.
**RULES:**
- **Tag Compliance:** Omitting or altering `<think>`, `</think>`, `<you>`, or `</you>` tags is **prohibited.**
- **No Shortcuts:** The `<think>` block must detail **every critical step**, even uncertain or exploratory thoughts.
- **Order Enforcement:** Never output an answer without a preceding `<think>` analysis.
Failure to adhere to this structure will result in termination.
"""  # You can modify the default system instructions here

# Custom styling for the chat UI. `.special-tag` is the class that
# format_response() puts on highlighted markers; the last rule hides the footer.
CSS = """
.gr-chatbot { min-height: 500px; border-radius: 15px; }
.special-tag { color: #2ecc71; font-weight: 600; }
footer { display: none !important; }
"""
class StopOnTokens(StoppingCriteria):
    """Stopping criterion that halts generation once an EOS token is produced.

    By default the EOS id is read at call time from the module-level
    ``tokenizer`` (which is created after this class is defined). An explicit
    ``eos_token_id`` may be supplied instead, so the class also works without
    that global.
    """

    def __init__(self, eos_token_id=None):
        super().__init__()
        # None means "use tokenizer.eos_token_id when called".
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
        """Return one flag per sequence: True where the newest token is EOS."""
        eos_id = self.eos_token_id
        if eos_id is None:
            eos_id = tokenizer.eos_token_id
        # Compare the last column for every row rather than only row 0, so the
        # result has one entry per sequence in the batch.
        return input_ids[:, -1] == eos_id
def initialize_model():
    """Load ``MODEL_ID`` with 4-bit quantization, plus its tokenizer.

    Returns:
        tuple: ``(model, tokenizer)``, with the model placed on the GPU
        and switched to eval mode.
    """
    # 4-bit NF4 weights with double quantization; compute runs in bfloat16.
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )

    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    # Fall back to EOS as the padding token only when the tokenizer defines none,
    # instead of overwriting an existing pad token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # device_map="cuda" already places the weights on the GPU. A follow-up
    # model.to("cuda") is redundant, and for bitsandbytes-quantized models
    # transformers can reject it (ValueError) with older bitsandbytes releases.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        device_map="cuda",
        quantization_config=quantization_config,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    )
    # eval() turns off training-only layers such as dropout. It does not
    # disable autograd; generation uses torch.inference_mode for that.
    model.eval()
    return model, tokenizer
def format_response(text):
    """Highlight reasoning/answer markers in model output with styled HTML.

    Handles the bracketed markers (``[Understand]``, ``[Reason]``, ``[/Reason]``,
    ``[Answer]``, ``[/Answer]``). It also handles the ``<think>``/``</think>``/
    ``<you>``/``</you>`` tags that DEFAULT_SYSTEM_PROMPT tells the model to emit.
    Those angle-bracket tags are HTML-escaped inside the highlight. They then
    show up as visible text instead of being passed to the HTML chat renderer
    as raw tags.

    Args:
        text: Model output so far (may be partial while streaming).

    Returns:
        str: ``text`` with each known marker wrapped in
        ``<strong class="special-tag">``.
    """
    from html import escape

    def _highlight(label):
        # Each highlighted marker sits on its own line.
        return f'\n<strong class="special-tag">{label}</strong>\n'

    bracket_markers = ("[Understand]", "[Reason]", "[/Reason]", "[Answer]", "[/Answer]")
    angle_tags = ("<think>", "</think>", "<you>", "</you>")

    # Bracket markers are HTML-safe and shown verbatim (same order as before).
    for marker in bracket_markers:
        text = text.replace(marker, _highlight(marker))
    # Angle tags run second. None of them occur in the <strong> markup the loop
    # above inserts, so no inserted text is re-processed.
    for tag in angle_tags:
        text = text.replace(tag, _highlight(escape(tag)))
    return text
@spaces.GPU(duration=360)
def generate_response(message, chat_history, system_prompt, temperature, max_tokens, top_p, top_k, repetition_penalty):
    """Stream the model's reply to ``message`` into the chat history.

    Args:
        message: The new user message.
        chat_history: Prior turns as ``(user, assistant)`` pairs from the
            Chatbot component; ``None`` is treated as an empty history.
        system_prompt: System instructions placed first in the conversation.
        temperature: Sampling temperature; ``0`` selects greedy decoding.
        max_tokens: Maximum number of new tokens to generate.
        top_p: Nucleus-sampling probability mass (used only when sampling).
        top_k: Top-k cutoff (used only when sampling).
        repetition_penalty: Repetition penalty passed to ``generate``.

    Yields:
        list[tuple[str, str]]: The updated history. While streaming, the last
        reply ends with a "▌" cursor; the final yield has no cursor.

    Raises:
        Exception: Any error raised by ``model.generate`` in the worker thread
        is re-raised here after the stream ends.
    """
    history = list(chat_history or [])

    # Rebuild the whole conversation for the chat template.
    # NOTE(review): assistant turns in the history are the text produced by
    # format_response (with HTML markup), so that markup is fed back to the model.
    conversation = [{"role": "system", "content": system_prompt}]
    for user_msg, bot_msg in history:
        conversation.append({"role": "user", "content": user_msg})
        conversation.append({"role": "assistant", "content": bot_msg})
    conversation.append({"role": "user", "content": message})

    # Assumes the tokenizer provides a chat template.
    input_ids = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    # skip_prompt=True: otherwise the streamer first emits the decoded prompt
    # (system prompt + entire history) ahead of the actual reply.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # Sampling needs a strictly positive temperature; fall back to greedy at 0.
    do_sample = temperature > 0
    generate_kwargs = {
        "input_ids": input_ids,
        # A single unpadded sequence attends to every position. Passing the mask
        # explicitly avoids inferring it from the pad token, which may equal EOS.
        "attention_mask": torch.ones_like(input_ids),
        "streamer": streamer,
        # Sliders may deliver floats; these arguments must be integers.
        "max_new_tokens": int(max_tokens),
        "do_sample": do_sample,
        "repetition_penalty": repetition_penalty,
        "stopping_criteria": StoppingCriteriaList([StopOnTokens()]),
    }
    if do_sample:
        generate_kwargs.update(temperature=temperature, top_p=top_p, top_k=int(top_k))

    errors = []

    def _run_generation():
        # Runs in a background thread so this generator can consume the streamer.
        try:
            with torch.inference_mode():
                model.generate(**generate_kwargs)
        except Exception as exc:
            errors.append(exc)
            # Release the consumer loop below, which would otherwise block on
            # the streamer forever.
            streamer.end()

    worker = Thread(target=_run_generation, daemon=True)
    worker.start()

    partial_message = ""
    history.append((message, ""))
    for new_text in streamer:
        partial_message += new_text
        history[-1] = (message, format_response(partial_message) + "▌")
        yield history

    worker.join()
    if errors:
        raise errors[0]

    # Final update without the cursor.
    history[-1] = (message, format_response(partial_message))
    yield history
# Load the model and tokenizer once at import time. generate_response and
# StopOnTokens read these module-level globals.
model, tokenizer = initialize_model()

with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
<h1 align="center">🧠 AI Reasoning Assistant</h1>
<p align="center">Ask me hard questions and see the reasoning unfold.</p>
""")
    chatbot = gr.Chatbot(label="Conversation", elem_id="chatbot")
    msg = gr.Textbox(label="Your Question", placeholder="Type your question...")

    with gr.Accordion("⚙️ Settings", open=False):
        system_prompt = gr.TextArea(value=DEFAULT_SYSTEM_PROMPT, label="System Instructions")
        temperature = gr.Slider(0, 1, value=0.6, label="Creativity (Temperature)")
        # step=1 keeps the integer-valued generation settings integral.
        max_tokens = gr.Slider(128, 8192, value=4096, step=1, label="Max Response Length")
        top_p = gr.Slider(0.0, 1.0, value=0.95, label="Top P (Nucleus Sampling)")
        top_k = gr.Slider(0, 100, value=50, step=1, label="Top K")
        repetition_penalty = gr.Slider(0.5, 2.0, value=1.1, label="Repetition Penalty")

    clear = gr.Button("Clear History")

    # Submitting the textbox streams the reply into the chatbot.
    msg.submit(
        generate_response,
        [msg, chatbot, system_prompt, temperature, max_tokens, top_p, top_k, repetition_penalty],
        chatbot,
        show_progress="full",
    )
    # Reset to an empty history rather than None, so the next submit receives
    # a list of turns.
    clear.click(lambda: [], None, chatbot, queue=False)

if __name__ == "__main__":
    demo.queue().launch()