elapt1c committed on
Commit ac2e861 · verified · 1 Parent(s): 08771f3

Update app.py

Files changed (1)
  1. app.py +180 -81
app.py CHANGED
@@ -1,116 +1,215 @@
 import gradio as gr
-import torch
-import importlib.util
-from tokenizers import Tokenizer
 from huggingface_hub import hf_hub_download
 import os

-# Download and import model components from HF Hub
 model_repo = "TimurHromek/HROM-V1"

-# 1. Import trainer module components
 trainer_file = hf_hub_download(repo_id=model_repo, filename="HROM-V1.5_Trainer.py")
 spec = importlib.util.spec_from_file_location("HROM_Trainer", trainer_file)
 trainer_module = importlib.util.module_from_spec(spec)
 spec.loader.exec_module(trainer_module)
-HROM = trainer_module.HROM
-CONFIG = trainer_module.CONFIG
-SafetyManager = trainer_module.SafetyManager

-# 2. Load tokenizer
 tokenizer_file = hf_hub_download(repo_id=model_repo, filename="tokenizer/hrom_tokenizer.json")
 tokenizer = Tokenizer.from_file(tokenizer_file)

-# 3. Load model checkpoint
-checkpoint_file = hf_hub_download(repo_id=model_repo, filename="HROM-V1.5_Trained-Model/HROM-V1.5.pt")
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-def load_model():
-    model = HROM().to(device)
-    checkpoint = torch.load(checkpoint_file, map_location=device)
-    model.load_state_dict(checkpoint['model'])
-    model.eval()
-    return model
-
-model = load_model()
-safety = SafetyManager(model, tokenizer)
-max_response_length = 200
-
-def generate_response(model, tokenizer, input_ids, safety_manager, max_length=200):
-    device = next(model.parameters()).device
-    generated_ids = input_ids.copy()
-    for _ in range(max_length):
-        input_tensor = torch.tensor([generated_ids], device=device)
-        with torch.no_grad():
-            logits = model(input_tensor)
-        next_token = logits.argmax(-1)[:, -1].item()
-        if next_token == tokenizer.token_to_id("</s>"):
-            break
-        current_text = tokenizer.decode(generated_ids + [next_token])
-        if not safety_manager.content_filter(current_text):
-            break
-        generated_ids.append(next_token)
-    return generated_ids[len(input_ids):]
-
-def process_message(user_input, chat_history, token_history):
-    # Process user input
-    user_turn = f"<user> {user_input} </s>"
-    user_tokens = tokenizer.encode(user_turn).ids
     token_history.extend(user_tokens)

-    # Prepare input sequence
-    input_sequence = [tokenizer.token_to_id("<s>")] + token_history
-
-    # Truncate if needed
-    max_input_len = CONFIG["max_seq_len"] - max_response_length
-    if len(input_sequence) > max_input_len:
-        input_sequence = input_sequence[-max_input_len:]
-        token_history = input_sequence[1:]
-
-    # Generate response
-    response_ids = generate_response(model, tokenizer, input_sequence, safety, max_response_length)

-    # Process assistant response
-    assistant_text = "I couldn't generate a proper response."
-    if response_ids:
-        if response_ids[0] == tokenizer.token_to_id("<assistant>"):
-            try:
-                end_idx = response_ids.index(tokenizer.token_to_id("</s>"))
-                assistant_text = tokenizer.decode(response_ids[1:end_idx])
-                token_history.extend(response_ids[:end_idx+1])
-            except ValueError:
-                assistant_text = tokenizer.decode(response_ids[1:])
-                token_history.extend(response_ids)
         else:
-            assistant_text = tokenizer.decode(response_ids)
-            token_history.extend(response_ids)

-    chat_history.append((user_input, assistant_text))
-    return chat_history, token_history

 def clear_history():
-    return [], []

 with gr.Blocks() as demo:
-    gr.Markdown("# HROM-V1 Chatbot")
-    chatbot = gr.Chatbot(height=500)
-    msg = gr.Textbox(label="Your Message")
-    token_state = gr.State([])

     msg.submit(
         process_message,
-        [msg, chatbot, token_state],
-        [chatbot, token_state],
-        queue=False
     ).then(
-        lambda: "", None, msg
     )

     clear_btn = gr.Button("Clear Chat History")
     clear_btn.click(
         clear_history,
-        outputs=[chatbot, token_state],
         queue=False
     )

-demo.launch()
+# app.py (Gradio Client)
 import gradio as gr
+import importlib.util # Still needed for CONFIG for max_seq_len
+from tokenizers import Tokenizer # Still needed for local tokenization
 from huggingface_hub import hf_hub_download
 import os
+import requests # For making HTTP requests
+import json

+# --- Configuration and Local Tokenizer (Still needed for UI-side processing) ---
 model_repo = "TimurHromek/HROM-V1"
+INFERENCE_SERVER_URL = "http://localhost:5000/generate" # CHANGE THIS to your actual https://inference.stormsurge.xyz/generate

+# 1. Import trainer module components (ONLY for CONFIG if needed locally)
 trainer_file = hf_hub_download(repo_id=model_repo, filename="HROM-V1.5_Trainer.py")
 spec = importlib.util.spec_from_file_location("HROM_Trainer", trainer_file)
 trainer_module = importlib.util.module_from_spec(spec)
 spec.loader.exec_module(trainer_module)
+CONFIG = trainer_module.CONFIG # We need CONFIG["max_seq_len"]

+# 2. Load tokenizer (locally for encoding user input)
 tokenizer_file = hf_hub_download(repo_id=model_repo, filename="tokenizer/hrom_tokenizer.json")
 tokenizer = Tokenizer.from_file(tokenizer_file)

+max_response_length_config = 200 # Max tokens the *server* should generate for one response
+
+# Model and SafetyManager are NOT loaded/used on the Gradio client side for generation anymore.
+
+def process_message(user_input, chat_history, token_history, seed, temperature, top_k):
+    if not user_input.strip():
+        chat_history.append((user_input, "Please provide some input."))
+        return chat_history, token_history, seed, temperature, top_k # Pass back params
+
+    # 1. Process user input and update token_history
+    user_turn_text = f"<user> {user_input} </s>"
+    user_tokens = tokenizer.encode(user_turn_text).ids
     token_history.extend(user_tokens)
+
+    # 2. Add assistant marker to token_history for the server to start generating after it
+    assistant_start_token = tokenizer.token_to_id("<assistant>")
+    token_history.append(assistant_start_token)
+
+    # 3. Prepare input sequence for the server (truncation)
+    # The server expects the full context it needs to start generating the assistant's reply.
+    # The max_response_length_config is for the *output*, so the input can be
+    # max_seq_len - max_response_length_config.
+    # The token_history already includes <s> from previous turns or initial state.

+    current_input_for_server = token_history.copy()
+    max_input_len_for_server = CONFIG["max_seq_len"] - max_response_length_config

+    if len(current_input_for_server) > max_input_len_for_server:
+        # If too long, truncate from the beginning, but ensure <s> is kept if present
+        # More robust: find first <s> after initial part if truncating heavily.
+        # Simple truncation for now:
+        num_tokens_to_remove = len(current_input_for_server) - max_input_len_for_server
+        # Keep <s> if it's the first token
+        if current_input_for_server and current_input_for_server[0] == tokenizer.token_to_id("<s>"):
+            current_input_for_server = [tokenizer.token_to_id("<s>")] + current_input_for_server[1+num_tokens_to_remove:]
         else:
+            current_input_for_server = current_input_for_server[num_tokens_to_remove:]
+
+        # Update token_history to reflect the truncated version sent to server
+        # This is important so the client's token_history matches what the server 'sees' as context
+        token_history = current_input_for_server.copy()
+
+
+    # 4. Call the inference server
+    payload = {
+        "token_history": current_input_for_server, # This now includes <s>...<user>...</s><assistant>
+        "max_response_length": max_response_length_config,
+        "temperature": temperature,
+        "top_k": top_k,
+        "seed": seed if seed > 0 else None # Send None if seed is 0 or negative (for random)
+    }
+
+    assistant_response_text = ""
+    assistant_response_token_ids = [] # Store IDs of the assistant's response
+    chat_history.append((user_input, "")) # Add user message, prepare for streaming assistant
+
+    try:
+        with requests.post(INFERENCE_SERVER_URL, json=payload, stream=True, timeout=120) as r:
+            r.raise_for_status() # Raise an exception for HTTP errors
+            for line in r.iter_lines():
+                if line:
+                    decoded_line = line.decode('utf-8')
+                    if decoded_line.startswith('data: '):
+                        try:
+                            event_data_json = decoded_line[len('data: '):]
+                            event_data = json.loads(event_data_json)
+
+                            if event_data.get("type") == "token":
+                                token_text = event_data.get("text", "")
+                                token_id = event_data.get("token_id")
+                                assistant_response_text += token_text
+                                if token_id is not None:
+                                    assistant_response_token_ids.append(token_id)
+                                chat_history[-1] = (user_input, assistant_response_text)
+                                yield chat_history, token_history, seed, temperature, top_k # Update UI progressively
+
+                            elif event_data.get("type") == "eos":
+                                # End of sentence token received
+                                eos_token_id = event_data.get("token_id")
+                                if eos_token_id is not None:
+                                    assistant_response_token_ids.append(eos_token_id)
+                                # The server should have sent </s>. We add it to token history.
+                                break # Stop processing more tokens for this response
+
+                            elif event_data.get("type") == "stop":
+                                reason = event_data.get("reason", "unknown reason")
+                                assistant_response_text += f"\n[Generation stopped: {reason}]"
+                                chat_history[-1] = (user_input, assistant_response_text)
+                                yield chat_history, token_history, seed, temperature, top_k
+                                break
+
+                            elif event_data.get("type") == "stream_end":
+                                # Server explicitly signals end of stream
+                                break
+
+                            elif event_data.get("type") == "error":
+                                err_msg = event_data.get("message", "Unknown server error")
+                                assistant_response_text += f"\n[Server Error: {err_msg}]"
+                                chat_history[-1] = (user_input, assistant_response_text)
+                                yield chat_history, token_history, seed, temperature, top_k
+                                break
+
+                        except json.JSONDecodeError:
+                            print(f"Failed to parse JSON: {decoded_line}")
+                        except Exception as e:
+                            print(f"Error processing stream line: {e}")
+                            assistant_response_text += f"\n[Client Error: {e}]"
+                            chat_history[-1] = (user_input, assistant_response_text)
+                            yield chat_history, token_history, seed, temperature, top_k
+                            break # Stop on error
+
+    except requests.exceptions.RequestException as e:
+        assistant_response_text = f"Error connecting to inference server: {e}"
+        chat_history[-1] = (user_input, assistant_response_text)
+        # No new tokens to add to token_history from assistant
+        yield chat_history, token_history, seed, temperature, top_k
+        return # Exit the generator
+
+    # After stream is complete (or broken):
+    # Update the main token_history with the assistant's generated tokens
+    # The assistant_start_token was already added before calling the server.
+    # The assistant_response_token_ids are the tokens *after* <assistant>.
+    token_history.extend(assistant_response_token_ids)
+
+    # Ensure </s> is at the end of the assistant's part in token_history if not already
+    # (The server stream should ideally send eos_token_id for this)
+    if not assistant_response_token_ids or assistant_response_token_ids[-1] != tokenizer.token_to_id("</s>"):
+        if event_data.get("type") != "eos": # if it wasn't already an EOS event that added it
+            token_history.append(tokenizer.token_to_id("</s>"))
+
+
+    # Final update after generation is fully done
+    if not assistant_response_text.strip(): # If nothing was generated
+        chat_history[-1] = (user_input, "I couldn't generate a proper response.")

+    # Update seed for next turn if it was used (randomize if seed was > 0)
+    # If seed was <=0, it means use random, so keep it that way.
+    if seed > 0:
+        new_seed = seed + 1 # Or any other logic to change the seed for next turn
+    else:
+        new_seed = seed # Keep as random
+
+    yield chat_history, token_history, new_seed, temperature, top_k
+

 def clear_history():
+    # Initial token_history should start with <s>
+    initial_token_history = [tokenizer.token_to_id("<s>")]
+    return [], initial_token_history, -1, 0.7, 50 # Cleared history, initial tokens, default seed/temp/top_k
+

 with gr.Blocks() as demo:
+    gr.Markdown("# HROM-V1 Chatbot (Remote Inference)")
+
+    with gr.Row():
+        with gr.Column(scale=1):
+            seed_slider = gr.Slider(minimum=-1, maximum=99999, value=-1, step=1, label="Seed (-1 for random)")
+            temp_slider = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.05, label="Temperature")
+            top_k_slider = gr.Slider(minimum=0, maximum=100, value=50, step=1, label="Top-K (0 for no Top-K)")
+        with gr.Column(scale=3):
+            chatbot = gr.Chatbot(height=500, label="Chat")
+            msg = gr.Textbox(label="Your Message", placeholder="Type your message here...")
+
+    # token_state stores the *entire conversation history as token IDs*
+    # It should be initialized with the <s> token.
+    initial_tokens = [tokenizer.token_to_id("<s>")]
+    token_state = gr.State(initial_tokens)

+    # Parameters for generation
+    # seed_state = gr.State(-1) # -1 for random
+    # temp_state = gr.State(0.7)
+    # top_k_state = gr.State(50) # 0 to disable
+
+    # Chain actions: submit text -> process_message (yields updates) -> clear textbox
     msg.submit(
         process_message,
+        [msg, chatbot, token_state, seed_slider, temp_slider, top_k_slider],
+        [chatbot, token_state, seed_slider, temp_slider, top_k_slider], # Pass params back to update state if needed
+        queue=True # Enable queue for streaming
     ).then(
+        lambda: "", outputs=msg # Clear textbox
     )

     clear_btn = gr.Button("Clear Chat History")
     clear_btn.click(
         clear_history,
+        outputs=[chatbot, token_state, seed_slider, temp_slider, top_k_slider],
         queue=False
     )

+demo.queue().launch(debug=True) # .queue() is important for streaming updates
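
Note: the updated client assumes an inference server that accepts the JSON payload built in process_message (token_history, max_response_length, temperature, top_k, seed) and streams back "data: {...}" lines with event types "token", "eos", "stop", "stream_end", and "error". That server is not part of this commit. The sketch below is a minimal stand-in, assuming Flask, that speaks the same protocol with a canned reply; every name in it is illustrative, and a real deployment behind INFERENCE_SERVER_URL would run the HROM model instead.

# dummy_server.py -- illustrative stand-in only, NOT part of this commit.
# Assumes Flask is installed; returns canned tokens instead of model output.
import json
from flask import Flask, Response, request

app = Flask(__name__)

@app.route("/generate", methods=["POST"])
def generate():
    payload = request.get_json(force=True)
    # Fields sent by the Gradio client: token_history, max_response_length,
    # temperature, top_k, seed. Only max_response_length is used here.
    max_new_tokens = int(payload.get("max_response_length", 200))

    def event_stream():
        # Canned "tokens"; the token_id values are placeholders, not real
        # HROM tokenizer ids.
        canned = [("Hello", 101), (" from", 102), (" the dummy server.", 103)]
        for text, token_id in canned[:max_new_tokens]:
            yield "data: " + json.dumps({"type": "token", "text": text, "token_id": token_id}) + "\n\n"
        # The client appends this id as </s>; 2 is a placeholder for
        # tokenizer.token_to_id("</s>") on the real server.
        yield "data: " + json.dumps({"type": "eos", "token_id": 2}) + "\n\n"
        yield "data: " + json.dumps({"type": "stream_end"}) + "\n\n"

    return Response(event_stream(), mimetype="text/event-stream")

if __name__ == "__main__":
    # Matches the client's default INFERENCE_SERVER_URL (http://localhost:5000/generate).
    app.run(host="127.0.0.1", port=5000)

Running this stand-in and leaving INFERENCE_SERVER_URL at its localhost default lets the streaming UI path be exercised without a GPU or the trained checkpoint.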