Spaces:
Sleeping
Sleeping
# app.py (Gradio Client) | |
import gradio as gr | |
import importlib.util # Still needed for CONFIG for max_seq_len | |
from tokenizers import Tokenizer # Still needed for local tokenization | |
from huggingface_hub import hf_hub_download | |
import os | |
import requests # For making HTTP requests | |
import json | |
# --- Configuration and local tokenizer (still needed for UI-side processing) ---
model_repo = "TimurHromek/HROM-V1"

# Streaming generation endpoint. Set the INFERENCE_SERVER_URL environment
# variable (e.g. https://inference.stormsurge.xyz/generate) to point at the real
# server without editing this file; the localhost URL is only the default.
INFERENCE_SERVER_URL = os.environ.get("INFERENCE_SERVER_URL", "http://localhost:5000/generate")

# 1. Import the trainer module solely to read CONFIG (CONFIG["max_seq_len"]).
# NOTE(security): this downloads a Python file from the Hub and executes it at
# import time, so any code in it runs with this app's privileges. If the repo is
# not under your control, pin a trusted revision via hf_hub_download(revision=...).
trainer_file = hf_hub_download(repo_id=model_repo, filename="HROM-V1.5_Trainer.py")
spec = importlib.util.spec_from_file_location("HROM_Trainer", trainer_file)
if spec is None or spec.loader is None:
    # spec_from_file_location returns None when it cannot build a loader for the path.
    raise ImportError(f"Cannot load trainer module from {trainer_file}")
trainer_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(trainer_module)
CONFIG = trainer_module.CONFIG  # only CONFIG["max_seq_len"] is read by this app

# 2. Load the tokenizer locally so user input can be encoded on the client side.
tokenizer_file = hf_hub_download(repo_id=model_repo, filename="tokenizer/hrom_tokenizer.json")
tokenizer = Tokenizer.from_file(tokenizer_file)

# Maximum number of tokens the *server* should generate for one response.
# The client reserves this much of CONFIG["max_seq_len"] for the reply.
max_response_length_config = 200

# The model and SafetyManager are NOT loaded here; generation happens on the server.
def _truncate_context(tokens, max_len, bos_id):
    """Return the newest ``max_len`` tokens of ``tokens``, keeping a leading BOS.

    When ``tokens`` already fits, a copy is returned unchanged. Otherwise the
    oldest tokens are dropped; if the first token is ``bos_id`` it is kept in
    front and the remainder is cut so the result is still ``max_len`` long
    (for ``max_len >= 1``).
    """
    overflow = len(tokens) - max_len
    if overflow <= 0:
        return list(tokens)
    if tokens and tokens[0] == bos_id:
        return [bos_id] + tokens[1 + overflow:]
    return tokens[overflow:]


def _parse_sse_data(raw_line):
    """Decode one streamed line and return its ``data:`` JSON object, or ``None`` to skip it.

    Blank keep-alive lines, lines without the ``data: `` prefix, malformed JSON
    (logged to stdout) and JSON values that are not objects are all skipped.
    A ``UnicodeDecodeError`` from decoding is left to the caller.
    """
    if not raw_line:
        return None
    decoded_line = raw_line.decode("utf-8")
    if not decoded_line.startswith("data: "):
        return None
    try:
        event = json.loads(decoded_line[len("data: "):])
    except json.JSONDecodeError:
        print(f"Failed to parse JSON: {decoded_line}")
        return None
    return event if isinstance(event, dict) else None


def process_message(user_input, chat_history, token_history, seed, temperature, top_k):
    """Stream one assistant reply from the remote inference server into the chat.

    This is a generator: every ``yield`` is a Gradio update of
    ``(chat_history, token_history, seed, temperature, top_k)``.

    Args:
        user_input: Text typed by the user.
        chat_history: List of ``(user, assistant)`` tuples shown in the Chatbot.
            The new turn is appended to it in place.
        token_history: Token ids of the whole conversation so far (normally
            starting with ``<s>``). It is never mutated; updated copies are
            yielded instead, so a failed request leaves the caller's list intact.
        seed: Sampling seed. Values <= 0 mean "random" and are sent as ``None``.
        temperature: Sampling temperature forwarded to the server.
        top_k: Top-k value forwarded to the server.

    Yields:
        ``(chat_history, token_history, seed, temperature, top_k)`` tuples:
        one per streamed token or status message, plus a final one. The final
        one carries the next turn's seed (``seed + 1`` when a fixed seed > 0
        was used). On a connection error the pre-turn token history is
        yielded back instead.
    """
    if chat_history is None:
        chat_history = []
    if not user_input or not user_input.strip():
        chat_history.append((user_input, "Please provide some input."))
        # This function is a generator: a plain ``return value`` would be
        # discarded by Gradio, so the notice has to be yielded to reach the UI.
        yield chat_history, token_history, seed, temperature, top_k
        return

    bos_id = tokenizer.token_to_id("<s>")
    eos_id = tokenizer.token_to_id("</s>")
    # Snapshot of the history before this turn, restored if the server is unreachable.
    history_before_turn = list(token_history)

    # 1-2. Append the user turn and the <assistant> marker the server continues from.
    context = history_before_turn + tokenizer.encode(f"<user> {user_input} </s>").ids
    context.append(tokenizer.token_to_id("<assistant>"))

    # 3. Reserve max_response_length_config tokens of the window for the reply.
    # The truncated context becomes the new token history so the client's
    # state matches exactly what the server was given.
    max_input_len = CONFIG["max_seq_len"] - max_response_length_config
    token_history = _truncate_context(context, max_input_len, bos_id)

    # 4. Call the inference server.
    payload = {
        "token_history": list(token_history),  # <s> ... <user> ... </s> <assistant>
        "max_response_length": max_response_length_config,
        "temperature": temperature,
        "top_k": top_k,
        "seed": seed if seed > 0 else None,  # None = random seed (seed <= 0)
    }

    response_text = ""
    response_ids = []  # ids of the reply tokens, i.e. everything after <assistant>
    eos_recorded = False  # set when an "eos" event contributed its own token id
    chat_history.append((user_input, ""))  # placeholder row filled while streaming

    try:
        with requests.post(INFERENCE_SERVER_URL, json=payload, stream=True, timeout=120) as r:
            r.raise_for_status()
            for raw_line in r.iter_lines():
                try:
                    event = _parse_sse_data(raw_line)
                    if event is None:
                        continue
                    event_type = event.get("type")
                    if event_type == "token":
                        response_text += event.get("text") or ""
                        token_id = event.get("token_id")
                        if token_id is not None:
                            response_ids.append(token_id)
                        chat_history[-1] = (user_input, response_text)
                        yield chat_history, token_history, seed, temperature, top_k
                    elif event_type == "eos":
                        token_id = event.get("token_id")
                        if token_id is not None:
                            response_ids.append(token_id)
                            eos_recorded = True
                        break
                    elif event_type == "stop":
                        reason = event.get("reason", "unknown reason")
                        response_text += f"\n[Generation stopped: {reason}]"
                        chat_history[-1] = (user_input, response_text)
                        yield chat_history, token_history, seed, temperature, top_k
                        break
                    elif event_type == "stream_end":
                        break
                    elif event_type == "error":
                        err_msg = event.get("message", "Unknown server error")
                        response_text += f"\n[Server Error: {err_msg}]"
                        chat_history[-1] = (user_input, response_text)
                        yield chat_history, token_history, seed, temperature, top_k
                        break
                except Exception as e:
                    # Best-effort per-line handling: surface the problem and stop streaming.
                    print(f"Error processing stream line: {e}")
                    response_text += f"\n[Client Error: {e}]"
                    chat_history[-1] = (user_input, response_text)
                    yield chat_history, token_history, seed, temperature, top_k
                    break
    except requests.exceptions.RequestException as e:
        error_text = f"Error connecting to inference server: {e}"
        # Keep any reply text that streamed in before the failure.
        shown = f"{response_text}\n{error_text}" if response_text else error_text
        chat_history[-1] = (user_input, shown)
        # The turn never completed: restore the pre-turn history so the context
        # does not end in a dangling <assistant> marker on the next message.
        yield chat_history, history_before_turn, seed, temperature, top_k
        return

    # 5. Commit the reply and close the assistant turn with </s>, unless the
    # reply already ends with it or an "eos" event supplied the closing token.
    token_history = token_history + response_ids
    if not eos_recorded and (not response_ids or response_ids[-1] != eos_id):
        token_history.append(eos_id)

    if not response_text.strip():
        chat_history[-1] = (user_input, "I couldn't generate a proper response.")

    # A fixed seed (> 0) advances by one per turn so replies vary; seeds <= 0 stay random.
    new_seed = seed + 1 if seed > 0 else seed
    yield chat_history, token_history, new_seed, temperature, top_k
def clear_history():
    """Reset the conversation to its initial state.

    Returns:
        A tuple in the order of the Clear button's outputs:
        ``(chat_history, token_history, seed, temperature, top_k)``. That is an
        empty chat, a token history holding only the ``<s>`` id, seed ``-1``
        (random), temperature ``0.7`` and top-k ``50``.
    """
    fresh_tokens = [tokenizer.token_to_id("<s>")]
    default_seed, default_temperature, default_top_k = -1, 0.7, 50
    return [], fresh_tokens, default_seed, default_temperature, default_top_k
with gr.Blocks() as demo:
    gr.Markdown("# HROM-V1 Chatbot (Remote Inference)")
    with gr.Row():
        # Sampling controls. Their defaults (-1, 0.7, 50) must stay in sync with
        # the values clear_history() returns.
        with gr.Column(scale=1):
            seed_slider = gr.Slider(
                minimum=-1, maximum=99999, value=-1, step=1, label="Seed (-1 for random)"
            )
            temp_slider = gr.Slider(
                minimum=0.1, maximum=2.0, value=0.7, step=0.05, label="Temperature"
            )
            top_k_slider = gr.Slider(
                minimum=0, maximum=100, value=50, step=1, label="Top-K (0 for no Top-K)"
            )
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(height=500, label="Chat")
            msg = gr.Textbox(label="Your Message", placeholder="Type your message here...")

    # The entire conversation as token ids, initialised with the <s> token
    # (the same value clear_history() resets it to).
    initial_tokens = [tokenizer.token_to_id("<s>")]
    token_state = gr.State(initial_tokens)

    # Submit -> process_message streams updates. The handle is kept separately
    # (not the result of .then) so the Clear button can cancel the stream itself.
    submit_event = msg.submit(
        process_message,
        [msg, chatbot, token_state, seed_slider, temp_slider, top_k_slider],
        # The sliders are outputs too, so the advanced seed shows for the next turn.
        [chatbot, token_state, seed_slider, temp_slider, top_k_slider],
        queue=True,  # generator (streaming) handlers require the queue
    )
    # Clear the textbox once the reply has finished streaming.
    submit_event.then(lambda: "", outputs=msg)

    clear_btn = gr.Button("Clear Chat History")
    clear_btn.click(
        clear_history,
        outputs=[chatbot, token_state, seed_slider, temp_slider, top_k_slider],
        queue=False,
        # Stop any reply still streaming. Otherwise its next yield would write
        # the old conversation back into the freshly cleared chat and state.
        cancels=[submit_event],
    )

demo.queue().launch(debug=True)  # .queue() is required for streaming updates