Daemontatox committed
Commit 86de665 · verified · 1 Parent(s): 32359f6

Update app.py

Files changed (1)
  1. app.py  +75 -39
app.py CHANGED
@@ -11,14 +11,30 @@ from transformers import (
     StoppingCriteriaList
 )
 
-MODEL_ID ="Daemontatox/Cogito-R1"
+MODEL_ID = "Daemontatox/Cogito-R1"
 
+DEFAULT_SYSTEM_PROMPT = """
+You are Cogito-R1 , an AI engineered for rigorous,Long , transparent reasoning.
+Your responses must **strictly follow this protocol:**
 
-#
-#
+1. **THINK FIRST:**
+- Begin every interaction by generating a raw, unfiltered internal monologue.
+- Enclose this step-by-step reasoning process—including doubts, methodical evaluations, and logical pivots—between `<think>` and `</think>` tags.
+- Example: `<think>Analyzing query... Is the user asking for X or Y? Cross-checking definitions... Prioritizing accuracy...</think>`
 
-DEFAULT_SYSTEM_PROMPT ="""
-"""
+2. **ANSWER AFTER:**
+- Only after completing the `<think>` block, deliver a concise, precise answer enclosed between `<you>` and `</you>` tags.
+- This answer must directly reflect conclusions from your reasoning phase.
+
+**RULES:**
+- **Tag Compliance:** Omitting or altering `<think>`, `</think>`, `<you>`, or `</you>` tags is **prohibited.**
+- **No Shortcuts:** The `<think>` block must detail **every critical step**, even uncertain or exploratory thoughts.
+- **Order Enforcement:** Never output an answer without a preceding `<think>` analysis.
+
+Failure to adhere to this structure will result in termination."
+
+
+""" # You can modify the default system instructions here
 
 CSS = """
 .gr-chatbot { min-height: 500px; border-radius: 15px; }
@@ -28,9 +44,11 @@ footer { display: none !important; }
 
 class StopOnTokens(StoppingCriteria):
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+        # Stop when the EOS token is generated.
         return input_ids[0][-1] == tokenizer.eos_token_id
 
 def initialize_model():
+    # Enable 4-bit quantization for faster inference and lower memory usage.
     quantization_config = BitsAndBytesConfig(
         load_in_4bit=True,
         bnb_4bit_compute_dtype=torch.bfloat16,
@@ -47,84 +65,102 @@ def initialize_model():
         quantization_config=quantization_config,
         torch_dtype=torch.bfloat16,
         trust_remote_code=True
-    ).to("cuda")
+    )
+    model.to("cuda")
+    model.eval()  # set evaluation mode to disable gradients and speed up inference
 
     return model, tokenizer
 
 def format_response(text):
-    return text.replace("[Understand]", '\n<strong class="special-tag">[Understand]</strong>\n') \
-        .replace("[/Reason]", '\n<strong class="special-tag">[/Reason]</strong>\n') \
-        .replace("[/Answer]", '\n<strong class="special-tag">[/Answer]</strong>\n') \
-        .replace("[Reason]", '\n<strong class="special-tag">[Reason]</strong>\n') \
-        .replace("[Answer]", '\n<strong class="special-tag">[Answer]</strong>\n')
+    # List of replacements to format key tokens with HTML for styling.
+    replacements = [
+        ("[Understand]", '\n<strong class="special-tag">[Understand]</strong>\n'),
+        ("[Reason]", '\n<strong class="special-tag">[Reason]</strong>\n'),
+        ("[/Reason]", '\n<strong class="special-tag">[/Reason]</strong>\n'),
+        ("[Answer]", '\n<strong class="special-tag">[Answer]</strong>\n'),
+        ("[/Answer]", '\n<strong class="special-tag">[/Answer]</strong>\n'),
+    ]
+    for old, new in replacements:
+        text = text.replace(old, new)
+    return text
+
 @spaces.GPU(duration=360)
-def generate_response(message, chat_history, system_prompt, temperature, max_tokens):
-    # Create conversation history for model
+def generate_response(message, chat_history, system_prompt, temperature, max_tokens, top_p, top_k, repetition_penalty):
+    # Build the conversation history.
     conversation = [{"role": "system", "content": system_prompt}]
     for user_msg, bot_msg in chat_history:
-        conversation.extend([
-            {"role": "user", "content": user_msg},
-            {"role": "assistant", "content": bot_msg}
-        ])
+        conversation.append({"role": "user", "content": user_msg})
+        conversation.append({"role": "assistant", "content": bot_msg})
     conversation.append({"role": "user", "content": message})
 
-    # Tokenize input
+    # Tokenize the conversation. (This assumes the tokenizer has an apply_chat_template method.)
     input_ids = tokenizer.apply_chat_template(
         conversation,
         add_generation_prompt=True,
         return_tensors="pt"
     ).to(model.device)
 
-    # Setup streaming
+    # Setup the streamer to yield new tokens as they are generated.
     streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
-    generate_kwargs = dict(
-        input_ids=input_ids,
-        streamer=streamer,
-        max_new_tokens=max_tokens,
-        temperature=temperature,
-        stopping_criteria=StoppingCriteriaList([StopOnTokens()])
-    )
-
-    # Start generation thread
-    Thread(target=model.generate, kwargs=generate_kwargs).start()
 
-    # Initialize response buffer
+    # Prepare generation parameters including extra customization options.
+    generate_kwargs = {
+        "input_ids": input_ids,
+        "streamer": streamer,
+        "max_new_tokens": max_tokens,
+        "temperature": temperature,
+        "top_p": top_p,
+        "top_k": top_k,
+        "repetition_penalty": repetition_penalty,
+        "stopping_criteria": StoppingCriteriaList([StopOnTokens()])
+    }
+
+    # Run the generation inside a no_grad block for speed.
+    def generate_inference():
+        with torch.inference_mode():
+            model.generate(**generate_kwargs)
+    Thread(target=generate_inference, daemon=True).start()
+
+    # Stream the output tokens.
     partial_message = ""
    new_history = chat_history + [(message, "")]
-
-    # Stream response
     for new_token in streamer:
         partial_message += new_token
         formatted = format_response(partial_message)
         new_history[-1] = (message, formatted + "▌")
         yield new_history
 
-    # Final update without cursor
+    # Final update without the cursor.
     new_history[-1] = (message, format_response(partial_message))
     yield new_history
 
+# Initialize the model and tokenizer globally.
 model, tokenizer = initialize_model()
 
 with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
     gr.Markdown("""
     <h1 align="center">🧠 AI Reasoning Assistant</h1>
-    <p align="center">Ask me Hard questions</p>
+    <p align="center">Ask me hard questions and see the reasoning unfold.</p>
     """)
 
     chatbot = gr.Chatbot(label="Conversation", elem_id="chatbot")
     msg = gr.Textbox(label="Your Question", placeholder="Type your question...")
-
+
     with gr.Accordion("⚙️ Settings", open=False):
         system_prompt = gr.TextArea(value=DEFAULT_SYSTEM_PROMPT, label="System Instructions")
-        temperature = gr.Slider(0, 1, value=0.6, label="Creativity")
-        max_tokens = gr.Slider(128, 8192, 2048, label="Max Response Length")
+        temperature = gr.Slider(0, 1, value=0.6, label="Creativity (Temperature)")
+        max_tokens = gr.Slider(128, 8192, 4096, label="Max Response Length")
+        top_p = gr.Slider(0.0, 1.0, value=0.95, label="Top P (Nucleus Sampling)")
+        top_k = gr.Slider(0, 100, value=50, label="Top K")
+        repetition_penalty = gr.Slider(0.5, 2.0, value=1.1, label="Repetition Penalty")
 
     clear = gr.Button("Clear History")
 
+    # Link the input textbox with the generation function.
     msg.submit(
         generate_response,
-        [msg, chatbot, system_prompt, temperature, max_tokens],
-        [chatbot],
+        [msg, chatbot, system_prompt, temperature, max_tokens, top_p, top_k, repetition_penalty],
+        chatbot,
        show_progress=True
     )
     clear.click(lambda: None, None, chatbot, queue=False)
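
Note: the updated DEFAULT_SYSTEM_PROMPT instructs the model to wrap its reasoning in `<think>…</think>` tags and its final answer in `<you>…</you>` tags. For reference, the sketch below shows one way a caller could separate those two segments from a finished response. It is not part of this commit; the helper name split_reasoning is hypothetical, and the code assumes the model actually emits the tags as instructed.

import re

# Hypothetical helper (not part of this commit): split a response that follows the
# <think>...</think> / <you>...</you> protocol into its reasoning and answer parts.
def split_reasoning(text: str):
    think = re.search(r"<think>(.*?)</think>", text, re.DOTALL)
    answer = re.search(r"<you>(.*?)</you>", text, re.DOTALL)
    reasoning = think.group(1).strip() if think else ""
    # Fall back to everything after </think> if the answer tags are missing.
    final = answer.group(1).strip() if answer else text.split("</think>")[-1].strip()
    return reasoning, final

# Example:
# split_reasoning("<think>Check the units first...</think><you>42 km</you>")
# -> ("Check the units first...", "42 km")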