# app.py (Gradio Client)

import gradio as gr
import importlib.util  # Still needed to load CONFIG for max_seq_len
from tokenizers import Tokenizer  # Still needed for local tokenization
from huggingface_hub import hf_hub_download
import os
import requests  # For making HTTP requests to the inference server
import json

# --- Configuration and Local Tokenizer (still needed for UI-side processing) ---
model_repo = "TimurHromek/HROM-V1"
INFERENCE_SERVER_URL = "http://localhost:5000/generate"  # CHANGE THIS to your actual endpoint, e.g. https://inference.stormsurge.xyz/generate
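# The endpoint could also be read from an environment variable so the same file
# works locally and on a deployed Space. This is only a sketch; the variable name
# HROM_INFERENCE_URL is an assumption, not something the original app defines:
# INFERENCE_SERVER_URL = os.environ.get("HROM_INFERENCE_URL", INFERENCE_SERVER_URL)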

# 1. Import trainer module components (ONLY for CONFIG, if needed locally)
trainer_file = hf_hub_download(repo_id=model_repo, filename="HROM-V1.5_Trainer.py")
spec = importlib.util.spec_from_file_location("HROM_Trainer", trainer_file)
trainer_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(trainer_module)
CONFIG = trainer_module.CONFIG  # We only need CONFIG["max_seq_len"]

# 2. Load the tokenizer (locally, for encoding user input)
tokenizer_file = hf_hub_download(repo_id=model_repo, filename="tokenizer/hrom_tokenizer.json")
tokenizer = Tokenizer.from_file(tokenizer_file)

max_response_length_config = 200  # Max tokens the *server* should generate for one response

# The model and SafetyManager are NOT loaded or used on the Gradio client side anymore;
# generation happens entirely on the remote inference server.
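
# The client below assumes the inference server accepts a JSON POST of the form
#   {"token_history": [...], "max_response_length": 200,
#    "temperature": 0.7, "top_k": 50, "seed": 42 or null}
# and answers with a Server-Sent-Events stream whose lines look like
#   data: {"type": "token", "text": " Hello", "token_id": 1234}
#   data: {"type": "eos", "token_id": 2}
#   data: {"type": "stream_end"}
# plus "stop" (with a "reason") and "error" (with a "message") events. The field
# names mirror exactly what process_message sends and parses below; the concrete
# values shown here are illustrative only.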

def process_message(user_input, chat_history, token_history, seed, temperature, top_k):
    if not user_input.strip():
        chat_history.append((user_input, "Please provide some input."))
        # Yield (not return) so the generator still pushes this update to the UI,
        # passing the parameters back unchanged.
        yield chat_history, token_history, seed, temperature, top_k
        return

    # 1. Process the user input and update token_history
    user_turn_text = f"<user> {user_input} </s>"
    user_tokens = tokenizer.encode(user_turn_text).ids
    token_history.extend(user_tokens)

    # 2. Append the <assistant> marker so the server starts generating right after it
    assistant_start_token = tokenizer.token_to_id("<assistant>")
    token_history.append(assistant_start_token)
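
    # At this point token_history, decoded back to text, looks roughly like
    #   <s> <user> {earlier message} </s> <assistant> {earlier reply} </s> <user> {new message} </s> <assistant>
    # i.e. the server is expected to continue generating right after the final <assistant> marker.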

    # 3. Prepare the input sequence for the server (truncation).
    # The server expects the full context it needs to start generating the assistant's reply.
    # max_response_length_config budgets the *output*, so the input may use at most
    # max_seq_len - max_response_length_config tokens.
    # token_history already includes <s> from previous turns or the initial state.
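    # For example, if CONFIG["max_seq_len"] were 512 (the real value comes from the
    # downloaded trainer CONFIG), the context sent to the server could hold at most
    # 512 - 200 = 312 tokens.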
    current_input_for_server = token_history.copy()
    max_input_len_for_server = CONFIG["max_seq_len"] - max_response_length_config

    if len(current_input_for_server) > max_input_len_for_server:
        # If too long, truncate from the beginning, but keep <s> if it is the first token.
        # A more robust approach would truncate at turn boundaries; simple truncation is
        # used for now (a commented sketch of the turn-boundary version follows below).
        num_tokens_to_remove = len(current_input_for_server) - max_input_len_for_server
        if current_input_for_server and current_input_for_server[0] == tokenizer.token_to_id("<s>"):
            current_input_for_server = [tokenizer.token_to_id("<s>")] + current_input_for_server[1 + num_tokens_to_remove:]
        else:
            current_input_for_server = current_input_for_server[num_tokens_to_remove:]
        # Keep the client's token_history in sync with the (truncated) context the server actually sees.
        token_history = current_input_for_server.copy()
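
        # Sketch of the turn-boundary truncation mentioned above (an alternative to the
        # simple truncation, not used here): drop whole turns from the oldest end,
        # assuming each turn ends with </s> as in the encoding above. A final hard cut
        # would still be needed if a single remaining turn exceeded the budget on its own.
        #
        #   eos_id = tokenizer.token_to_id("</s>")
        #   body = current_input_for_server[1:]  # everything after <s>
        #   while len(body) + 1 > max_input_len_for_server and eos_id in body:
        #       body = body[body.index(eos_id) + 1:]  # drop the oldest complete turn
        #   current_input_for_server = [tokenizer.token_to_id("<s>")] + body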

    # 4. Call the inference server
    payload = {
        "token_history": current_input_for_server,  # <s> ... <user> ... </s> <assistant>
        "max_response_length": max_response_length_config,
        "temperature": temperature,
        "top_k": top_k,
        "seed": seed if seed > 0 else None,  # Send None if seed is 0 or negative (i.e. use a random seed)
    }

    assistant_response_text = ""
    assistant_response_token_ids = []  # Token IDs of the assistant's response
    eos_received = False  # Set to True when the server sends an "eos" event
    chat_history.append((user_input, ""))  # Add the user message; the assistant's text streams in below

    try:
        with requests.post(INFERENCE_SERVER_URL, json=payload, stream=True, timeout=120) as r:
            r.raise_for_status()  # Raise an exception for HTTP errors
            for line in r.iter_lines():
                if line:
                    decoded_line = line.decode('utf-8')
                    if decoded_line.startswith('data: '):
                        try:
                            event_data_json = decoded_line[len('data: '):]
                            event_data = json.loads(event_data_json)

                            if event_data.get("type") == "token":
                                token_text = event_data.get("text", "")
                                token_id = event_data.get("token_id")
                                assistant_response_text += token_text
                                if token_id is not None:
                                    assistant_response_token_ids.append(token_id)
                                chat_history[-1] = (user_input, assistant_response_text)
                                yield chat_history, token_history, seed, temperature, top_k  # Update the UI progressively

                            elif event_data.get("type") == "eos":
                                # End-of-sequence received; the server sends its </s> token ID,
                                # which we keep so it ends up in token_history below.
                                eos_token_id = event_data.get("token_id")
                                if eos_token_id is not None:
                                    assistant_response_token_ids.append(eos_token_id)
                                eos_received = True
                                break  # Stop processing tokens for this response

                            elif event_data.get("type") == "stop":
                                reason = event_data.get("reason", "unknown reason")
                                assistant_response_text += f"\n[Generation stopped: {reason}]"
                                chat_history[-1] = (user_input, assistant_response_text)
                                yield chat_history, token_history, seed, temperature, top_k
                                break

                            elif event_data.get("type") == "stream_end":
                                # The server explicitly signals the end of the stream.
                                break

                            elif event_data.get("type") == "error":
                                err_msg = event_data.get("message", "Unknown server error")
                                assistant_response_text += f"\n[Server Error: {err_msg}]"
                                chat_history[-1] = (user_input, assistant_response_text)
                                yield chat_history, token_history, seed, temperature, top_k
                                break

                        except json.JSONDecodeError:
                            print(f"Failed to parse JSON: {decoded_line}")
                        except Exception as e:
                            print(f"Error processing stream line: {e}")
                            assistant_response_text += f"\n[Client Error: {e}]"
                            chat_history[-1] = (user_input, assistant_response_text)
                            yield chat_history, token_history, seed, temperature, top_k
                            break  # Stop on error
    except requests.exceptions.RequestException as e:
        assistant_response_text = f"Error connecting to inference server: {e}"
        chat_history[-1] = (user_input, assistant_response_text)
        # No assistant tokens to add to token_history.
        yield chat_history, token_history, seed, temperature, top_k
        return  # Exit the generator

    # After the stream completes (or breaks):
    # extend the main token_history with the assistant's generated tokens.
    # The <assistant> marker was already appended before calling the server;
    # assistant_response_token_ids are the tokens generated *after* <assistant>.
    token_history.extend(assistant_response_token_ids)

    # Ensure the assistant's part of token_history ends with </s>
    # (ideally the server's "eos" event has already supplied it).
    if not assistant_response_token_ids or assistant_response_token_ids[-1] != tokenizer.token_to_id("</s>"):
        if not eos_received:  # Only append if no "eos" event already added it
            token_history.append(tokenizer.token_to_id("</s>"))

    # Final update after generation has fully finished.
    if not assistant_response_text.strip():  # If nothing was generated
        chat_history[-1] = (user_input, "I couldn't generate a proper response.")

    # Advance the seed for the next turn if a fixed seed was used;
    # a seed <= 0 means "random", so leave it unchanged.
    if seed > 0:
        new_seed = seed + 1  # Or any other scheme for varying the seed between turns
    else:
        new_seed = seed  # Keep as random

    yield chat_history, token_history, new_seed, temperature, top_k

def clear_history():
    # The initial token_history should start with <s>.
    initial_token_history = [tokenizer.token_to_id("<s>")]
    return [], initial_token_history, -1, 0.7, 50  # Cleared chat, initial tokens, default seed/temperature/top-k
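
# clear_history's return tuple maps positionally onto the outputs wired up for the
# "Clear Chat History" button below: [chatbot, token_state, seed_slider, temp_slider,
# top_k_slider], i.e. an empty chat, a token history containing only <s>, and the
# default seed/temperature/top-k values.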

with gr.Blocks() as demo:
    gr.Markdown("# HROM-V1 Chatbot (Remote Inference)")

    with gr.Row():
        with gr.Column(scale=1):
            seed_slider = gr.Slider(minimum=-1, maximum=99999, value=-1, step=1, label="Seed (-1 for random)")
            temp_slider = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.05, label="Temperature")
            top_k_slider = gr.Slider(minimum=0, maximum=100, value=50, step=1, label="Top-K (0 to disable)")
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(height=500, label="Chat")
            msg = gr.Textbox(label="Your Message", placeholder="Type your message here...")

    # token_state stores the *entire conversation history as token IDs*,
    # initialized with the <s> token.
    initial_tokens = [tokenizer.token_to_id("<s>")]
    token_state = gr.State(initial_tokens)

    # Generation parameters (now driven by the sliders above rather than hidden state):
    # seed_state = gr.State(-1)   # -1 for random
    # temp_state = gr.State(0.7)
    # top_k_state = gr.State(50)  # 0 to disable

    # Chain actions: submit text -> process_message (yields streaming updates) -> clear the textbox
    msg.submit(
        process_message,
        [msg, chatbot, token_state, seed_slider, temp_slider, top_k_slider],
        [chatbot, token_state, seed_slider, temp_slider, top_k_slider],  # Parameters are passed back so their state stays updated
        queue=True  # The queue must be enabled for streaming
    ).then(
        lambda: "", outputs=msg  # Clear the textbox
    )
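
    # Because process_message is a generator, each yield above pushes a partial
    # update to the Chatbot component, which produces the token-by-token streaming
    # effect in the UI; this relies on the Gradio queue being enabled.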

    clear_btn = gr.Button("Clear Chat History")
    clear_btn.click(
        clear_history,
        outputs=[chatbot, token_state, seed_slider, temp_slider, top_k_slider],
        queue=False
    )

demo.queue().launch(debug=True)  # .queue() is important for streaming updates
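
# To sanity-check the inference server without the UI, a request of the shape this
# client sends can be issued directly with curl (schematic; the token IDs and the
# endpoint are placeholders that must match your server and tokenizer):
#
#   curl -N -X POST http://localhost:5000/generate \
#        -H "Content-Type: application/json" \
#        -d '{"token_history": [...], "max_response_length": 200,
#             "temperature": 0.7, "top_k": 50, "seed": null}'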