TimurHromek commited on
Commit
60624d8
·
verified ·
1 Parent(s): 280f2ed

Reverted to working build of chat interface.

Browse files
Files changed (1) hide show
  1. app.py +15 -131
app.py CHANGED
@@ -3,6 +3,7 @@ import torch
3
  import importlib.util
4
  from tokenizers import Tokenizer
5
  from huggingface_hub import hf_hub_download
 
6
 
7
  # Download and import model components from HF Hub
8
  model_repo = "TimurHromek/HROM-V1"
@@ -35,53 +36,39 @@ model = load_model()
35
  safety = SafetyManager(model, tokenizer)
36
  max_response_length = 200
37
 
38
- def generate_response(model, tokenizer, input_ids, safety_manager, max_length=200, temperature=1.0):
39
  device = next(model.parameters()).device
40
  generated_ids = input_ids.copy()
41
  for _ in range(max_length):
42
  input_tensor = torch.tensor([generated_ids], device=device)
43
  with torch.no_grad():
44
  logits = model(input_tensor)
45
-
46
- # Get last token logits and apply temperature
47
- next_token_logits = logits[0, -1, :]
48
- if temperature != 1.0:
49
- next_token_logits = next_token_logits / temperature
50
- probs = torch.softmax(next_token_logits, dim=-1)
51
-
52
- # Sample next token
53
- next_token = torch.multinomial(probs, num_samples=1).item()
54
-
55
- # Stop if end token is generated
56
  if next_token == tokenizer.token_to_id("</s>"):
57
  break
58
-
59
- # Safety check
60
  current_text = tokenizer.decode(generated_ids + [next_token])
61
  if not safety_manager.content_filter(current_text):
62
  break
63
-
64
  generated_ids.append(next_token)
65
  return generated_ids[len(input_ids):]
66
 
67
- def process_message(user_input, chat_history, token_history, temperature, max_context_length):
68
  # Process user input
69
  user_turn = f"<user> {user_input} </s>"
70
  user_tokens = tokenizer.encode(user_turn).ids
71
  token_history.extend(user_tokens)
72
 
73
- # Prepare input sequence with context limit
74
  input_sequence = [tokenizer.token_to_id("<s>")] + token_history
75
 
76
- # Truncate based on max context length
77
- max_input_len = max_context_length
78
  if len(input_sequence) > max_input_len:
79
  input_sequence = input_sequence[-max_input_len:]
80
  token_history = input_sequence[1:]
81
 
82
- # Generate response with temperature
83
- response_ids = generate_response(model, tokenizer, input_sequence, safety,
84
- max_response_length, temperature)
85
 
86
  # Process assistant response
87
  assistant_text = "I couldn't generate a proper response."
@@ -104,125 +91,22 @@ def process_message(user_input, chat_history, token_history, temperature, max_co
104
  def clear_history():
105
  return [], []
106
 
107
- css = """
108
- :root {
109
- --background: white;
110
- --text: black;
111
- --border: #e0e0e0;
112
- --button-bg: #f0f0f0;
113
- --button-hover: #e0e0e0;
114
- --chatbot-bg: #f8f8f8;
115
- }
116
-
117
- .dark {
118
- --background: #1a1a1a;
119
- --text: white;
120
- --border: #404040;
121
- --button-bg: #404040;
122
- --button-hover: #505050;
123
- --chatbot-bg: #262626;
124
- }
125
-
126
- body {
127
- background: var(--background) !important;
128
- color: var(--text) !important;
129
- transition: all 0.3s ease;
130
- }
131
-
132
- .gr-box {
133
- border-color: var(--border) !important;
134
- background: var(--background) !important;
135
- }
136
-
137
- .gr-button {
138
- background: var(--button-bg) !important;
139
- color: var(--text) !important;
140
- border-color: var(--border) !important;
141
- }
142
-
143
- .gr-button:hover {
144
- background: var(--button-hover) !important;
145
- }
146
-
147
- #chatbot {
148
- background: var(--chatbot-bg) !important;
149
- border-color: var(--border) !important;
150
- min-height: 500px;
151
- }
152
-
153
- .gr-textbox input {
154
- color: var(--text) !important;
155
- }
156
-
157
- .dark .gr-markdown {
158
- color: var(--text) !important;
159
- }
160
-
161
- .settings-panel {
162
- border-left: 1px solid var(--border) !important;
163
- padding-left: 20px !important;
164
- }
165
- """
166
-
167
- with gr.Blocks(css=css, title="HROM-V1.5 Chatbot") as demo:
168
- current_theme = gr.State("light")
169
-
170
- with gr.Row():
171
- with gr.Column(scale=3):
172
- gr.Markdown("# HROM-V1.5 Chatbot")
173
- chatbot = gr.Chatbot(height=500, elem_id="chatbot")
174
- msg = gr.Textbox(label="Your Message",
175
- placeholder="Type your message...",
176
- show_label=False,
177
- container=False)
178
-
179
- with gr.Column(scale=1, min_width=300, elem_classes="settings-panel"):
180
- with gr.Accordion("⚙️ Settings", open=False):
181
- with gr.Row():
182
- theme_btn = gr.Button("🌙 Dark Theme", variant="secondary")
183
- with gr.Row():
184
- temperature = gr.Slider(0.1, 2.0, value=1.0, step=0.1,
185
- label="Temperature (higher = more creative)")
186
- with gr.Row():
187
- max_context = gr.Slider(100, CONFIG["max_seq_len"] - max_response_length,
188
- value=CONFIG["max_seq_len"] - max_response_length, step=1,
189
- label="Context Window Size")
190
- with gr.Row():
191
- clear_btn = gr.Button("🧹 Clear History", variant="secondary")
192
-
193
  token_state = gr.State([])
194
- theme_css = gr.HTML("<style></style>")
195
-
196
- def toggle_theme(theme):
197
- new_theme = "dark" if theme == "light" else "light"
198
- btn_text = "🌞 Light Theme" if new_theme == "light" else "🌙 Dark Theme"
199
- css = """
200
- <style>
201
- body { background: %s !important; color: %s !important; }
202
- .dark-mode { display: %s !important; }
203
- </style>
204
- """ % (
205
- "var(--background)",
206
- "var(--text)",
207
- "block" if new_theme == "dark" else "none"
208
- )
209
- return new_theme, btn_text, css
210
-
211
- theme_btn.click(
212
- toggle_theme,
213
- current_theme,
214
- [current_theme, theme_btn, theme_css]
215
- )
216
 
217
  msg.submit(
218
  process_message,
219
- [msg, chatbot, token_state, temperature, max_context],
220
  [chatbot, token_state],
221
  queue=False
222
  ).then(
223
  lambda: "", None, msg
224
  )
225
 
 
226
  clear_btn.click(
227
  clear_history,
228
  outputs=[chatbot, token_state],
 
3
  import importlib.util
4
  from tokenizers import Tokenizer
5
  from huggingface_hub import hf_hub_download
6
+ import os
7
 
8
  # Download and import model components from HF Hub
9
  model_repo = "TimurHromek/HROM-V1"
 
36
  safety = SafetyManager(model, tokenizer)
37
  max_response_length = 200
38
 
39
+ def generate_response(model, tokenizer, input_ids, safety_manager, max_length=200):
40
  device = next(model.parameters()).device
41
  generated_ids = input_ids.copy()
42
  for _ in range(max_length):
43
  input_tensor = torch.tensor([generated_ids], device=device)
44
  with torch.no_grad():
45
  logits = model(input_tensor)
46
+ next_token = logits.argmax(-1)[:, -1].item()
 
 
 
 
 
 
 
 
 
 
47
  if next_token == tokenizer.token_to_id("</s>"):
48
  break
 
 
49
  current_text = tokenizer.decode(generated_ids + [next_token])
50
  if not safety_manager.content_filter(current_text):
51
  break
 
52
  generated_ids.append(next_token)
53
  return generated_ids[len(input_ids):]
54
 
55
+ def process_message(user_input, chat_history, token_history):
56
  # Process user input
57
  user_turn = f"<user> {user_input} </s>"
58
  user_tokens = tokenizer.encode(user_turn).ids
59
  token_history.extend(user_tokens)
60
 
61
+ # Prepare input sequence
62
  input_sequence = [tokenizer.token_to_id("<s>")] + token_history
63
 
64
+ # Truncate if needed
65
+ max_input_len = CONFIG["max_seq_len"] - max_response_length
66
  if len(input_sequence) > max_input_len:
67
  input_sequence = input_sequence[-max_input_len:]
68
  token_history = input_sequence[1:]
69
 
70
+ # Generate response
71
+ response_ids = generate_response(model, tokenizer, input_sequence, safety, max_response_length)
 
72
 
73
  # Process assistant response
74
  assistant_text = "I couldn't generate a proper response."
 
91
  def clear_history():
92
  return [], []
93
 
94
+ with gr.Blocks() as demo:
95
+ gr.Markdown("# HROM-V1 Chatbot")
96
+ chatbot = gr.Chatbot(height=500)
97
+ msg = gr.Textbox(label="Your Message")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  token_state = gr.State([])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
  msg.submit(
101
  process_message,
102
+ [msg, chatbot, token_state],
103
  [chatbot, token_state],
104
  queue=False
105
  ).then(
106
  lambda: "", None, msg
107
  )
108
 
109
+ clear_btn = gr.Button("Clear Chat History")
110
  clear_btn.click(
111
  clear_history,
112
  outputs=[chatbot, token_state],