# app.py (Gradio Client) import gradio as gr import importlib.util # Still needed for CONFIG for max_seq_len from tokenizers import Tokenizer # Still needed for local tokenization from huggingface_hub import hf_hub_download import os import requests # For making HTTP requests import json # --- Configuration and Local Tokenizer (Still needed for UI-side processing) --- model_repo = "TimurHromek/HROM-V1" INFERENCE_SERVER_URL = "http://localhost:5000/generate" # CHANGE THIS to your actual https://inference.stormsurge.xyz/generate # 1. Import trainer module components (ONLY for CONFIG if needed locally) trainer_file = hf_hub_download(repo_id=model_repo, filename="HROM-V1.5_Trainer.py") spec = importlib.util.spec_from_file_location("HROM_Trainer", trainer_file) trainer_module = importlib.util.module_from_spec(spec) spec.loader.exec_module(trainer_module) CONFIG = trainer_module.CONFIG # We need CONFIG["max_seq_len"] # 2. Load tokenizer (locally for encoding user input) tokenizer_file = hf_hub_download(repo_id=model_repo, filename="tokenizer/hrom_tokenizer.json") tokenizer = Tokenizer.from_file(tokenizer_file) max_response_length_config = 200 # Max tokens the *server* should generate for one response # Model and SafetyManager are NOT loaded/used on the Gradio client side for generation anymore. def process_message(user_input, chat_history, token_history, seed, temperature, top_k): if not user_input.strip(): chat_history.append((user_input, "Please provide some input.")) return chat_history, token_history, seed, temperature, top_k # Pass back params # 1. Process user input and update token_history user_turn_text = f" {user_input} " user_tokens = tokenizer.encode(user_turn_text).ids token_history.extend(user_tokens) # 2. Add assistant marker to token_history for the server to start generating after it assistant_start_token = tokenizer.token_to_id("") token_history.append(assistant_start_token) # 3. Prepare input sequence for the server (truncation) # The server expects the full context it needs to start generating the assistant's reply. # The max_response_length_config is for the *output*, so the input can be # max_seq_len - max_response_length_config. # The token_history already includes from previous turns or initial state. current_input_for_server = token_history.copy() max_input_len_for_server = CONFIG["max_seq_len"] - max_response_length_config if len(current_input_for_server) > max_input_len_for_server: # If too long, truncate from the beginning, but ensure is kept if present # More robust: find first after initial part if truncating heavily. # Simple truncation for now: num_tokens_to_remove = len(current_input_for_server) - max_input_len_for_server # Keep if it's the first token if current_input_for_server and current_input_for_server[0] == tokenizer.token_to_id(""): current_input_for_server = [tokenizer.token_to_id("")] + current_input_for_server[1+num_tokens_to_remove:] else: current_input_for_server = current_input_for_server[num_tokens_to_remove:] # Update token_history to reflect the truncated version sent to server # This is important so the client's token_history matches what the server 'sees' as context token_history = current_input_for_server.copy() # 4. Call the inference server payload = { "token_history": current_input_for_server, # This now includes ...... "max_response_length": max_response_length_config, "temperature": temperature, "top_k": top_k, "seed": seed if seed > 0 else None # Send None if seed is 0 or negative (for random) } assistant_response_text = "" assistant_response_token_ids = [] # Store IDs of the assistant's response chat_history.append((user_input, "")) # Add user message, prepare for streaming assistant try: with requests.post(INFERENCE_SERVER_URL, json=payload, stream=True, timeout=120) as r: r.raise_for_status() # Raise an exception for HTTP errors for line in r.iter_lines(): if line: decoded_line = line.decode('utf-8') if decoded_line.startswith('data: '): try: event_data_json = decoded_line[len('data: '):] event_data = json.loads(event_data_json) if event_data.get("type") == "token": token_text = event_data.get("text", "") token_id = event_data.get("token_id") assistant_response_text += token_text if token_id is not None: assistant_response_token_ids.append(token_id) chat_history[-1] = (user_input, assistant_response_text) yield chat_history, token_history, seed, temperature, top_k # Update UI progressively elif event_data.get("type") == "eos": # End of sentence token received eos_token_id = event_data.get("token_id") if eos_token_id is not None: assistant_response_token_ids.append(eos_token_id) # The server should have sent . We add it to token history. break # Stop processing more tokens for this response elif event_data.get("type") == "stop": reason = event_data.get("reason", "unknown reason") assistant_response_text += f"\n[Generation stopped: {reason}]" chat_history[-1] = (user_input, assistant_response_text) yield chat_history, token_history, seed, temperature, top_k break elif event_data.get("type") == "stream_end": # Server explicitly signals end of stream break elif event_data.get("type") == "error": err_msg = event_data.get("message", "Unknown server error") assistant_response_text += f"\n[Server Error: {err_msg}]" chat_history[-1] = (user_input, assistant_response_text) yield chat_history, token_history, seed, temperature, top_k break except json.JSONDecodeError: print(f"Failed to parse JSON: {decoded_line}") except Exception as e: print(f"Error processing stream line: {e}") assistant_response_text += f"\n[Client Error: {e}]" chat_history[-1] = (user_input, assistant_response_text) yield chat_history, token_history, seed, temperature, top_k break # Stop on error except requests.exceptions.RequestException as e: assistant_response_text = f"Error connecting to inference server: {e}" chat_history[-1] = (user_input, assistant_response_text) # No new tokens to add to token_history from assistant yield chat_history, token_history, seed, temperature, top_k return # Exit the generator # After stream is complete (or broken): # Update the main token_history with the assistant's generated tokens # The assistant_start_token was already added before calling the server. # The assistant_response_token_ids are the tokens *after* . token_history.extend(assistant_response_token_ids) # Ensure is at the end of the assistant's part in token_history if not already # (The server stream should ideally send eos_token_id for this) if not assistant_response_token_ids or assistant_response_token_ids[-1] != tokenizer.token_to_id(""): if event_data.get("type") != "eos": # if it wasn't already an EOS event that added it token_history.append(tokenizer.token_to_id("")) # Final update after generation is fully done if not assistant_response_text.strip(): # If nothing was generated chat_history[-1] = (user_input, "I couldn't generate a proper response.") # Update seed for next turn if it was used (randomize if seed was > 0) # If seed was <=0, it means use random, so keep it that way. if seed > 0: new_seed = seed + 1 # Or any other logic to change the seed for next turn else: new_seed = seed # Keep as random yield chat_history, token_history, new_seed, temperature, top_k def clear_history(): # Initial token_history should start with initial_token_history = [tokenizer.token_to_id("")] return [], initial_token_history, -1, 0.7, 50 # Cleared history, initial tokens, default seed/temp/top_k with gr.Blocks() as demo: gr.Markdown("# HROM-V1 Chatbot (Remote Inference)") with gr.Row(): with gr.Column(scale=1): seed_slider = gr.Slider(minimum=-1, maximum=99999, value=-1, step=1, label="Seed (-1 for random)") temp_slider = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.05, label="Temperature") top_k_slider = gr.Slider(minimum=0, maximum=100, value=50, step=1, label="Top-K (0 for no Top-K)") with gr.Column(scale=3): chatbot = gr.Chatbot(height=500, label="Chat") msg = gr.Textbox(label="Your Message", placeholder="Type your message here...") # token_state stores the *entire conversation history as token IDs* # It should be initialized with the token. initial_tokens = [tokenizer.token_to_id("")] token_state = gr.State(initial_tokens) # Parameters for generation # seed_state = gr.State(-1) # -1 for random # temp_state = gr.State(0.7) # top_k_state = gr.State(50) # 0 to disable # Chain actions: submit text -> process_message (yields updates) -> clear textbox msg.submit( process_message, [msg, chatbot, token_state, seed_slider, temp_slider, top_k_slider], [chatbot, token_state, seed_slider, temp_slider, top_k_slider], # Pass params back to update state if needed queue=True # Enable queue for streaming ).then( lambda: "", outputs=msg # Clear textbox ) clear_btn = gr.Button("Clear Chat History") clear_btn.click( clear_history, outputs=[chatbot, token_state, seed_slider, temp_slider, top_k_slider], queue=False ) demo.queue().launch(debug=True) # .queue() is important for streaming updates