# app.py (Gradio Client)
import gradio as gr
import importlib.util  # Still needed to load CONFIG (for max_seq_len)
from tokenizers import Tokenizer # Still needed for local tokenization
from huggingface_hub import hf_hub_download
import os
import requests # For making HTTP requests
import json
# --- Configuration and Local Tokenizer (Still needed for UI-side processing) ---
model_repo = "TimurHromek/HROM-V1"
INFERENCE_SERVER_URL = "http://localhost:5000/generate"  # CHANGE THIS to your actual endpoint, e.g. https://inference.stormsurge.xyz/generate
# 1. Import trainer module components (only CONFIG is needed locally)
trainer_file = hf_hub_download(repo_id=model_repo, filename="HROM-V1.5_Trainer.py")
spec = importlib.util.spec_from_file_location("HROM_Trainer", trainer_file)
trainer_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(trainer_module)
CONFIG = trainer_module.CONFIG # We need CONFIG["max_seq_len"]
# 2. Load tokenizer (locally for encoding user input)
tokenizer_file = hf_hub_download(repo_id=model_repo, filename="tokenizer/hrom_tokenizer.json")
tokenizer = Tokenizer.from_file(tokenizer_file)
max_response_length_config = 200 # Max tokens the *server* should generate for one response
# The model and SafetyManager are no longer loaded on the Gradio client side; generation happens on the inference server.
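# Assumed contract with the inference server (inferred from the payload construction and
# stream handling in process_message below; the server implementation is not part of this file):
#   POST <INFERENCE_SERVER_URL> with JSON:
#     {"token_history": [...], "max_response_length": int,
#      "temperature": float, "top_k": int, "seed": int or None}
#   The server streams server-sent-event style lines, one JSON object per "data:" line,
#   with a "type" of "token", "eos", "stop", "stream_end", or "error", for example:
#     data: {"type": "token", "text": " Hello", "token_id": 1234}
#     data: {"type": "eos", "token_id": 2}
#   (Token IDs above are illustrative, not real vocabulary IDs.)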
def process_message(user_input, chat_history, token_history, seed, temperature, top_k):
    if not user_input.strip():
        chat_history.append((user_input, "Please provide some input."))
        yield chat_history, token_history, seed, temperature, top_k  # Pass params back unchanged
        return
# 1. Process user input and update token_history
user_turn_text = f"<user> {user_input} </s>"
user_tokens = tokenizer.encode(user_turn_text).ids
token_history.extend(user_tokens)
# 2. Add assistant marker to token_history for the server to start generating after it
assistant_start_token = tokenizer.token_to_id("<assistant>")
token_history.append(assistant_start_token)
# 3. Prepare input sequence for the server (truncation)
# The server expects the full context it needs to start generating the assistant's reply.
# The max_response_length_config is for the *output*, so the input can be
# max_seq_len - max_response_length_config.
# The token_history already includes <s> from previous turns or initial state.
current_input_for_server = token_history.copy()
max_input_len_for_server = CONFIG["max_seq_len"] - max_response_length_config
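    # Illustrative budget (the real value comes from CONFIG["max_seq_len"]): if max_seq_len
    # were 512, at most 512 - 200 = 312 tokens of context would be sent to the server.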
if len(current_input_for_server) > max_input_len_for_server:
        # Too long: truncate from the beginning, keeping the leading <s> if present.
        # (A more robust approach would resume at the first <s> after the removed span;
        # simple truncation is used here.)
num_tokens_to_remove = len(current_input_for_server) - max_input_len_for_server
# Keep <s> if it's the first token
if current_input_for_server and current_input_for_server[0] == tokenizer.token_to_id("<s>"):
current_input_for_server = [tokenizer.token_to_id("<s>")] + current_input_for_server[1+num_tokens_to_remove:]
else:
current_input_for_server = current_input_for_server[num_tokens_to_remove:]
# Update token_history to reflect the truncated version sent to server
# This is important so the client's token_history matches what the server 'sees' as context
token_history = current_input_for_server.copy()
# 4. Call the inference server
payload = {
"token_history": current_input_for_server, # This now includes <s>...<user>...</s><assistant>
"max_response_length": max_response_length_config,
"temperature": temperature,
"top_k": top_k,
"seed": seed if seed > 0 else None # Send None if seed is 0 or negative (for random)
}
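    # Illustrative payload (token IDs are made up; real IDs depend on the tokenizer):
    #   {"token_history": [2, 15, 87, ..., 4], "max_response_length": 200,
    #    "temperature": 0.7, "top_k": 50, "seed": None}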
assistant_response_text = ""
assistant_response_token_ids = [] # Store IDs of the assistant's response
chat_history.append((user_input, "")) # Add user message, prepare for streaming assistant
try:
with requests.post(INFERENCE_SERVER_URL, json=payload, stream=True, timeout=120) as r:
r.raise_for_status() # Raise an exception for HTTP errors
for line in r.iter_lines():
if line:
decoded_line = line.decode('utf-8')
if decoded_line.startswith('data: '):
try:
event_data_json = decoded_line[len('data: '):]
event_data = json.loads(event_data_json)
if event_data.get("type") == "token":
token_text = event_data.get("text", "")
token_id = event_data.get("token_id")
assistant_response_text += token_text
if token_id is not None:
assistant_response_token_ids.append(token_id)
chat_history[-1] = (user_input, assistant_response_text)
yield chat_history, token_history, seed, temperature, top_k # Update UI progressively
elif event_data.get("type") == "eos":
# End of sentence token received
eos_token_id = event_data.get("token_id")
if eos_token_id is not None:
assistant_response_token_ids.append(eos_token_id)
# The server should have sent </s>. We add it to token history.
break # Stop processing more tokens for this response
elif event_data.get("type") == "stop":
reason = event_data.get("reason", "unknown reason")
assistant_response_text += f"\n[Generation stopped: {reason}]"
chat_history[-1] = (user_input, assistant_response_text)
yield chat_history, token_history, seed, temperature, top_k
break
elif event_data.get("type") == "stream_end":
# Server explicitly signals end of stream
break
elif event_data.get("type") == "error":
err_msg = event_data.get("message", "Unknown server error")
assistant_response_text += f"\n[Server Error: {err_msg}]"
chat_history[-1] = (user_input, assistant_response_text)
yield chat_history, token_history, seed, temperature, top_k
break
except json.JSONDecodeError:
print(f"Failed to parse JSON: {decoded_line}")
except Exception as e:
print(f"Error processing stream line: {e}")
assistant_response_text += f"\n[Client Error: {e}]"
chat_history[-1] = (user_input, assistant_response_text)
yield chat_history, token_history, seed, temperature, top_k
break # Stop on error
except requests.exceptions.RequestException as e:
assistant_response_text = f"Error connecting to inference server: {e}"
chat_history[-1] = (user_input, assistant_response_text)
# No new tokens to add to token_history from assistant
yield chat_history, token_history, seed, temperature, top_k
return # Exit the generator
# After stream is complete (or broken):
# Update the main token_history with the assistant's generated tokens
# The assistant_start_token was already added before calling the server.
# The assistant_response_token_ids are the tokens *after* <assistant>.
token_history.extend(assistant_response_token_ids)
    # Ensure the assistant's turn in token_history ends with </s>
    # (ideally the server's "eos" event already supplied it).
    eos_id = tokenizer.token_to_id("</s>")
    if not assistant_response_token_ids or assistant_response_token_ids[-1] != eos_id:
        token_history.append(eos_id)
# Final update after generation is fully done
if not assistant_response_text.strip(): # If nothing was generated
chat_history[-1] = (user_input, "I couldn't generate a proper response.")
    # Advance the seed for the next turn when a fixed seed was used;
    # a non-positive seed means "random", so leave it unchanged.
    if seed > 0:
        new_seed = seed + 1  # Vary the seed so consecutive turns are not identical
    else:
        new_seed = seed  # Keep as random
yield chat_history, token_history, new_seed, temperature, top_k
def clear_history():
# Initial token_history should start with <s>
initial_token_history = [tokenizer.token_to_id("<s>")]
return [], initial_token_history, -1, 0.7, 50 # Cleared history, initial tokens, default seed/temp/top_k
with gr.Blocks() as demo:
gr.Markdown("# HROM-V1 Chatbot (Remote Inference)")
with gr.Row():
with gr.Column(scale=1):
seed_slider = gr.Slider(minimum=-1, maximum=99999, value=-1, step=1, label="Seed (-1 for random)")
temp_slider = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.05, label="Temperature")
top_k_slider = gr.Slider(minimum=0, maximum=100, value=50, step=1, label="Top-K (0 for no Top-K)")
with gr.Column(scale=3):
chatbot = gr.Chatbot(height=500, label="Chat")
msg = gr.Textbox(label="Your Message", placeholder="Type your message here...")
# token_state stores the *entire conversation history as token IDs*
# It should be initialized with the <s> token.
initial_tokens = [tokenizer.token_to_id("<s>")]
token_state = gr.State(initial_tokens)
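    # A conversation in token_state looks roughly like this (special tokens from this tokenizer;
    # the message text is illustrative):
    #   <s> <user> hi </s> <assistant> hello! </s> <user> ...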
# Parameters for generation
# seed_state = gr.State(-1) # -1 for random
# temp_state = gr.State(0.7)
# top_k_state = gr.State(50) # 0 to disable
# Chain actions: submit text -> process_message (yields updates) -> clear textbox
msg.submit(
process_message,
[msg, chatbot, token_state, seed_slider, temp_slider, top_k_slider],
[chatbot, token_state, seed_slider, temp_slider, top_k_slider], # Pass params back to update state if needed
queue=True # Enable queue for streaming
).then(
lambda: "", outputs=msg # Clear textbox
)
clear_btn = gr.Button("Clear Chat History")
clear_btn.click(
clear_history,
outputs=[chatbot, token_state, seed_slider, temp_slider, top_k_slider],
queue=False
)
demo.queue().launch(debug=True) # .queue() is important for streaming updates
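# Quick smoke test of the assumed server contract (run separately from this app, with
# INFERENCE_SERVER_URL and the tokenizer loaded as above; fields match the payload in
# process_message):
#   payload = {"token_history": [tokenizer.token_to_id("<s>")], "max_response_length": 8,
#              "temperature": 0.7, "top_k": 50, "seed": None}
#   with requests.post(INFERENCE_SERVER_URL, json=payload, stream=True, timeout=30) as r:
#       for line in r.iter_lines():
#           print(line.decode("utf-8"))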