# app.py (Gradio Client)
import gradio as gr
import importlib.util # Still needed to load CONFIG (for max_seq_len)
from tokenizers import Tokenizer # Still needed for local tokenization
from huggingface_hub import hf_hub_download
import os
import requests # For making HTTP requests
import json

# --- Configuration and Local Tokenizer (Still needed for UI-side processing) ---
model_repo = "TimurHromek/HROM-V1"
INFERENCE_SERVER_URL = "http://localhost:5000/generate" # CHANGE THIS to your actual endpoint, e.g. https://inference.stormsurge.xyz/generate
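
# The inference server itself is not part of this file. Based on how this client builds its
# request payload and parses the response stream below, the /generate endpoint is assumed to
# accept a JSON body shaped like:
#
#   {
#     "token_history": [1, 5, 42, ...],   # full prompt as token IDs, ending with <assistant>
#     "max_response_length": 200,
#     "temperature": 0.7,
#     "top_k": 50,
#     "seed": 1234                        # or null for a random seed
#   }
#
# and to reply with a Server-Sent-Events style stream of lines such as:
#
#   data: {"type": "token", "text": " Hello", "token_id": 123}
#   data: {"type": "eos", "token_id": 2}
#   data: {"type": "stop", "reason": "length"}
#   data: {"type": "error", "message": "..."}
#   data: {"type": "stream_end"}
#
# The field names and event types mirror what process_message() expects; the concrete values
# shown (token IDs, the "length" reason) are purely illustrative.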

# 1. Import trainer module components (ONLY for CONFIG if needed locally)
trainer_file = hf_hub_download(repo_id=model_repo, filename="HROM-V1.5_Trainer.py")
spec = importlib.util.spec_from_file_location("HROM_Trainer", trainer_file)
trainer_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(trainer_module)
CONFIG = trainer_module.CONFIG # We need CONFIG["max_seq_len"]

# 2. Load tokenizer (locally for encoding user input)
tokenizer_file = hf_hub_download(repo_id=model_repo, filename="tokenizer/hrom_tokenizer.json")
tokenizer = Tokenizer.from_file(tokenizer_file)
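
# Sanity check (illustrative addition, not required by the original flow): the prompt format
# below relies on these special tokens existing in the vocabulary. Tokenizer.token_to_id()
# returns None for unknown tokens, which would silently corrupt token_history, so fail fast.
for _special in ("<s>", "</s>", "<user>", "<assistant>"):
    if tokenizer.token_to_id(_special) is None:
        raise ValueError(f"Tokenizer is missing expected special token: {_special}")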

max_response_length_config = 200 # Max tokens the *server* should generate for one response

# Model and SafetyManager are NOT loaded/used on the Gradio client side for generation anymore.
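
# For reference, the token_history kept in gr.State decodes roughly to the following layout
# (the transcript text is an assumed example; the special-token strings come from the tokenizer):
#
#   <s> <user> Hi there </s> <assistant> Hello! How can I help? </s> <user> ... </s> <assistant>
#
# Each call to process_message() appends one "<user> ... </s>" turn plus an "<assistant>" marker,
# streams the whole sequence to the server, then appends the generated reply followed by "</s>".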

def process_message(user_input, chat_history, token_history, seed, temperature, top_k):
    if not user_input.strip():
        chat_history.append((user_input, "Please provide some input."))
        # This is a generator function, so the early exit must yield instead of returning a value.
        yield chat_history, token_history, seed, temperature, top_k
        return

    # 1. Process user input and update token_history
    user_turn_text = f"<user> {user_input} </s>"
    user_tokens = tokenizer.encode(user_turn_text).ids
    token_history.extend(user_tokens)

    # 2. Add assistant marker to token_history for the server to start generating after it
    assistant_start_token = tokenizer.token_to_id("<assistant>")
    token_history.append(assistant_start_token)

    # 3. Prepare input sequence for the server (truncation)
    # The server expects the full context it needs to start generating the assistant's reply.
    # max_response_length_config budgets the *output*, so the input may hold at most
    # max_seq_len - max_response_length_config tokens.
    # The token_history already includes <s> from previous turns or the initial state.
    
    current_input_for_server = token_history.copy()
    max_input_len_for_server = CONFIG["max_seq_len"] - max_response_length_config
    
    if len(current_input_for_server) > max_input_len_for_server:
        # If too long, truncate from the beginning, keeping the leading <s> if present.
        # (A more robust approach would re-anchor on a turn boundary; simple truncation for now.)
        num_tokens_to_remove = len(current_input_for_server) - max_input_len_for_server
        # Keep <s> if it's the first token
        if current_input_for_server and current_input_for_server[0] == tokenizer.token_to_id("<s>"):
            current_input_for_server = [tokenizer.token_to_id("<s>")] + current_input_for_server[1+num_tokens_to_remove:]
        else:
            current_input_for_server = current_input_for_server[num_tokens_to_remove:]
        
        # Update token_history to reflect the truncated version sent to server
        # This is important so the client's token_history matches what the server 'sees' as context
        token_history = current_input_for_server.copy()
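        # Worked example (assuming CONFIG["max_seq_len"] were 512, an assumed value; the real one
        # comes from the downloaded trainer config): the input budget would be 512 - 200 = 312
        # tokens, so a 400-token history would drop 400 - 312 = 88 tokens immediately after the
        # leading <s>, keeping the most recent turns intact.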


    # 4. Call the inference server
    payload = {
        "token_history": current_input_for_server, # This now includes <s>...<user>...</s><assistant>
        "max_response_length": max_response_length_config,
        "temperature": temperature,
        "top_k": top_k,
        "seed": seed if seed > 0 else None # Send None if seed is 0 or negative (for random)
    }

    assistant_response_text = ""
    assistant_response_token_ids = [] # Store IDs of the assistant's response
    chat_history.append((user_input, "")) # Add user message, prepare for streaming assistant

    try:
        with requests.post(INFERENCE_SERVER_URL, json=payload, stream=True, timeout=120) as r:
            r.raise_for_status() # Raise an exception for HTTP errors
            for line in r.iter_lines():
                if line:
                    decoded_line = line.decode('utf-8')
                    if decoded_line.startswith('data: '):
                        try:
                            event_data_json = decoded_line[len('data: '):]
                            event_data = json.loads(event_data_json)

                            if event_data.get("type") == "token":
                                token_text = event_data.get("text", "")
                                token_id = event_data.get("token_id")
                                assistant_response_text += token_text
                                if token_id is not None:
                                    assistant_response_token_ids.append(token_id)
                                chat_history[-1] = (user_input, assistant_response_text)
                                yield chat_history, token_history, seed, temperature, top_k # Update UI progressively
                            
                            elif event_data.get("type") == "eos":
                                # End-of-sequence token (</s>) received
                                eos_token_id = event_data.get("token_id")
                                if eos_token_id is not None:
                                    assistant_response_token_ids.append(eos_token_id)
                                # The server should have sent </s>; its ID is collected here and folded into token_history after the stream ends.
                                break # Stop processing more tokens for this response
                            
                            elif event_data.get("type") == "stop":
                                reason = event_data.get("reason", "unknown reason")
                                assistant_response_text += f"\n[Generation stopped: {reason}]"
                                chat_history[-1] = (user_input, assistant_response_text)
                                yield chat_history, token_history, seed, temperature, top_k
                                break

                            elif event_data.get("type") == "stream_end":
                                # Server explicitly signals end of stream
                                break
                            
                            elif event_data.get("type") == "error":
                                err_msg = event_data.get("message", "Unknown server error")
                                assistant_response_text += f"\n[Server Error: {err_msg}]"
                                chat_history[-1] = (user_input, assistant_response_text)
                                yield chat_history, token_history, seed, temperature, top_k
                                break

                        except json.JSONDecodeError:
                            print(f"Failed to parse JSON: {decoded_line}")
                        except Exception as e:
                            print(f"Error processing stream line: {e}")
                            assistant_response_text += f"\n[Client Error: {e}]"
                            chat_history[-1] = (user_input, assistant_response_text)
                            yield chat_history, token_history, seed, temperature, top_k
                            break # Stop on error

    except requests.exceptions.RequestException as e:
        assistant_response_text = f"Error connecting to inference server: {e}"
        chat_history[-1] = (user_input, assistant_response_text)
        # No new tokens to add to token_history from assistant
        yield chat_history, token_history, seed, temperature, top_k
        return # Exit the generator

    # After stream is complete (or broken):
    # Update the main token_history with the assistant's generated tokens
    # The assistant_start_token was already added before calling the server.
    # The assistant_response_token_ids are the tokens *after* <assistant>.
    token_history.extend(assistant_response_token_ids)

    # Ensure </s> terminates the assistant's part of token_history.
    # (The server stream should ideally send an "eos" event carrying this token; checking the last
    # collected token ID avoids referencing event_data, which may be unbound if the stream
    # produced no parsable events.)
    eos_token_id = tokenizer.token_to_id("</s>")
    if not assistant_response_token_ids or assistant_response_token_ids[-1] != eos_token_id:
        token_history.append(eos_token_id)


    # Final update after generation is fully done
    if not assistant_response_text.strip(): # If nothing was generated
        chat_history[-1] = (user_input, "I couldn't generate a proper response.")
    
    # Advance the seed for the next turn when a fixed seed was used, so successive replies differ.
    # A seed <= 0 means "use a random seed", so leave it unchanged.
    if seed > 0:
        new_seed = seed + 1 # Or any other logic to change the seed for next turn
    else:
        new_seed = seed # Keep as random

    yield chat_history, token_history, new_seed, temperature, top_k


def clear_history():
    # Initial token_history should start with <s>
    initial_token_history = [tokenizer.token_to_id("<s>")]
    return [], initial_token_history, -1, 0.7, 50 # Cleared history, initial tokens, default seed/temp/top_k


with gr.Blocks() as demo:
    gr.Markdown("# HROM-V1 Chatbot (Remote Inference)")
    
    with gr.Row():
        with gr.Column(scale=1):
            seed_slider = gr.Slider(minimum=-1, maximum=99999, value=-1, step=1, label="Seed (-1 for random)")
            temp_slider = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.05, label="Temperature")
            top_k_slider = gr.Slider(minimum=0, maximum=100, value=50, step=1, label="Top-K (0 for no Top-K)")
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(height=500, label="Chat")
            msg = gr.Textbox(label="Your Message", placeholder="Type your message here...")

    # token_state stores the *entire conversation history as token IDs*
    # It should be initialized with the <s> token.
    initial_tokens = [tokenizer.token_to_id("<s>")]
    token_state = gr.State(initial_tokens)
    
    # Generation parameters (seed, temperature, top_k) are held directly by the sliders above,
    # so no separate gr.State components are needed for them.

    # Chain actions: submit text -> process_message (yields updates) -> clear textbox
    msg.submit(
        process_message,
        [msg, chatbot, token_state, seed_slider, temp_slider, top_k_slider],
        [chatbot, token_state, seed_slider, temp_slider, top_k_slider], # Pass params back to update state if needed
        queue=True # Enable queue for streaming
    ).then(
        lambda: "", outputs=msg # Clear textbox
    )
    
    clear_btn = gr.Button("Clear Chat History")
    clear_btn.click(
        clear_history,
        outputs=[chatbot, token_state, seed_slider, temp_slider, top_k_slider],
        queue=False
    )

demo.queue().launch(debug=True) # .queue() is important for streaming updates