import gradio as gr
from huggingface_hub import InferenceClient
import os
import json
import base64
from PIL import Image
import io

# Load the default access token from environment variable at startup
# This will be used if no custom key is provided by the user.
ACCESS_TOKEN = os.getenv("HF_TOKEN")
print(f"Default HF_TOKEN from environment loaded: {'Present' if ACCESS_TOKEN else 'Not set'}")

# Function to encode image to base64
def encode_image(image_path):
    if not image_path:
        print("No image path provided")
        return None
    
    try:
        print(f"Encoding image from path: {image_path}")
        
        if isinstance(image_path, Image.Image):
            image = image_path
        else:
            image = Image.open(image_path)
        
        if image.mode == 'RGBA':
            image = image.convert('RGB')
        
        buffered = io.BytesIO()
        image.save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
        print("Image encoded successfully")
        return img_str
    except Exception as e:
        print(f"Error encoding image: {e}")
        return None

def respond(
    message,
    image_files,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    frequency_penalty,
    seed,
    provider,
    custom_api_key, # This is the value from the BYOK textbox
    custom_model,    
    model_search_term,
    selected_model
):
    print(f"Received message: {message}")
    print(f"Received {len(image_files) if image_files else 0} images")
    # print(f"History: {history}") # Can be very verbose
    print(f"System message: {system_message}")
    print(f"Max tokens: {max_tokens}, Temperature: {temperature}, Top-P: {top_p}")
    print(f"Frequency Penalty: {frequency_penalty}, Seed: {seed}")
    print(f"Selected provider: {provider}")         
    print(f"Custom API Key input field value (raw): '{custom_api_key[:10]}...' (masked if long)")
    print(f"Selected model (custom_model input field): {custom_model}")  
    print(f"Model search term: {model_search_term}")
    print(f"Selected model from radio: {selected_model}")

    token_to_use = None
    original_hf_token_env_value = os.environ.get("HF_TOKEN")
    env_hf_token_temporarily_modified = False

    if custom_api_key and custom_api_key.strip():
        token_to_use = custom_api_key.strip()
        print(f"USING CUSTOM API KEY (BYOK): '{token_to_use[:5]}...' (masked for security).")
        # Ensure the custom key is the one actually used:
        # Temporarily remove HF_TOKEN from os.environ if it exists,
        # to prevent any possibility of InferenceClient picking it up.
        if "HF_TOKEN" in os.environ:
            print(f"Temporarily unsetting HF_TOKEN from environment (was: {'Present' if os.environ.get('HF_TOKEN') else 'Not set'}) to prioritize custom key.")
            del os.environ["HF_TOKEN"]
            env_hf_token_temporarily_modified = True
    elif ACCESS_TOKEN: # Use default token from environment if no custom key
        token_to_use = ACCESS_TOKEN
        print(f"USING DEFAULT API KEY (HF_TOKEN from environment variable at script start): '{token_to_use[:5]}...' (masked for security).")
        # Ensure HF_TOKEN is set in the current env if it was loaded at start
        # This handles cases where it might have been unset by a previous call with a custom key
        if original_hf_token_env_value is not None:
            os.environ["HF_TOKEN"] = original_hf_token_env_value
        elif "HF_TOKEN" in os.environ: # If ACCESS_TOKEN was loaded but original_hf_token_env_value was None (e.g. set by other means)
             pass # Let it be whatever it is
    else:
        print("No custom API key provided AND no default HF_TOKEN was found in environment at script start.")
        print("InferenceClient will be initialized without an explicit token. May fail or use public access.")
        # token_to_use remains None
        # If HF_TOKEN was in env and we want to ensure it's not used when token_to_use is None:
        if "HF_TOKEN" in os.environ:
            print(f"Temporarily unsetting HF_TOKEN from environment (was: {'Present' if os.environ.get('HF_TOKEN') else 'Not set'}) as no valid key is chosen.")
            del os.environ["HF_TOKEN"]
            env_hf_token_temporarily_modified = True # Mark for restoration

    print(f"Final token being passed to InferenceClient: '{str(token_to_use)[:5]}...' (masked)" if token_to_use else "None")

    try:
        client = InferenceClient(token=token_to_use, provider=provider)
        print(f"Hugging Face Inference Client initialized with {provider} provider.")

        if seed == -1:
            seed = None

        user_content = []
        if message and message.strip():
            user_content.append({"type": "text", "text": message})
        
        if image_files:
            for img_path in image_files:
                if img_path:
                    encoded_image = encode_image(img_path)
                    if encoded_image:
                        user_content.append({
                            "type": "image_url",
                            "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}
                        })
        
        if not user_content: # If only images were sent and none encoded, or empty message
             if image_files: # If there were image files, it implies an image-only message
                 user_content = [{"type": "text", "text": ""}] # Send an empty text for context, or specific prompt
             else: # Truly empty input
                 yield "Error: Empty message content."
                 return


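        # For reference (assumed shape, matching the OpenAI-style content parts that
        # chat_completion accepts), user_content looks like:
        #   [{"type": "text", "text": "Describe this photo"},
        #    {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,<...>"}}]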
        messages = [{"role": "system", "content": system_message}]
        for val in history:
            user_part, assistant_part = val
            # Handle multimodal history if necessary (simplified for now)
            if isinstance(user_part, dict) and 'files' in user_part: # from MultimodalTextbox
                history_text = user_part.get("text", "")
                history_files = user_part.get("files", [])
                current_user_content_history = []
                if history_text:
                    current_user_content_history.append({"type": "text", "text": history_text})
                for h_img_path in history_files:
                    encoded_h_img = encode_image(h_img_path)
                    if encoded_h_img:
                        current_user_content_history.append({
                            "type": "image_url",
                            "image_url": {"url": f"data:image/jpeg;base64,{encoded_h_img}"}
                        })
                if current_user_content_history:
                     messages.append({"role": "user", "content": current_user_content_history})
            elif isinstance(user_part, str): # from simple text history
                 messages.append({"role": "user", "content": user_part})

            if assistant_part:
                messages.append({"role": "assistant", "content": assistant_part})

        messages.append({"role": "user", "content": user_content if len(user_content) > 1 or not isinstance(user_content[0], dict) or user_content[0].get("type") != "text" else user_content[0]["text"]})


        model_to_use = custom_model.strip() if custom_model.strip() else selected_model
        print(f"Model selected for inference: {model_to_use}")

        response_text = ""
        print(f"Sending request to {provider} with model {model_to_use}.")

        parameters = {
            "max_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "frequency_penalty": frequency_penalty,
        }
        if seed is not None:
            parameters["seed"] = seed

        stream = client.chat_completion(
            model=model_to_use,
            messages=messages,
            stream=True,
            **parameters
        )
        
        print("Streaming response: ", end="", flush=True)
        for chunk in stream:
            if hasattr(chunk, 'choices') and chunk.choices:
                delta = chunk.choices[0].delta
                if hasattr(delta, 'content') and delta.content:
                    token_chunk = delta.content
                    print(token_chunk, end="", flush=True)
                    response_text += token_chunk
                    yield response_text
        print("\nStream finished.")
    
    except Exception as e:
        error_message = f"Error during inference: {e}"
        print(error_message)
        # If there was already some response, append error. Otherwise, yield error.
        if 'response_text' in locals() and response_text:
             response_text += f"\n{error_message}"
             yield response_text
        else:
             yield error_message
    finally:
        # Restore HF_TOKEN in os.environ if it was temporarily removed/modified
        if env_hf_token_temporarily_modified:
            if original_hf_token_env_value is not None:
                os.environ["HF_TOKEN"] = original_hf_token_env_value
                print("Restored HF_TOKEN in environment from its original value.")
            else:
                # If it was unset and originally not present, ensure it remains unset
                if "HF_TOKEN" in os.environ: # Should not happen if original was None and we deleted
                    del os.environ["HF_TOKEN"]
                print("HF_TOKEN was originally not set and was temporarily removed; ensuring it remains not set in env.")
        print("Response generation attempt complete.")


def validate_provider(api_key, provider_choice):
    # This validation might need adjustment based on providers.
    # For now, it assumes any custom key might work with other providers.
    # If HF_TOKEN is the only one available (no custom key), restrict to hf-inference.
    if not api_key.strip() and provider_choice != "hf-inference" and ACCESS_TOKEN:
        gr.Warning("Default HF_TOKEN can only be used with 'hf-inference' provider. Switching to 'hf-inference'.")
        return gr.update(value="hf-inference")
    return gr.update(value=provider_choice)

with gr.Blocks(theme="Nymbo/Nymbo_Theme") as demo:
    chatbot = gr.Chatbot(
        height=600, 
        show_copy_button=True, 
        placeholder="Select a model and begin chatting. Supports multimodal inputs.",
        layout="panel",
        avatar_images=(None, "https://hf.co/front/assets/huggingface_logo.svg") # Bot avatar
    )
    
    msg = gr.MultimodalTextbox(
        placeholder="Type a message or upload images...",
        show_label=False,
        container=False,
        scale=12,
        file_types=["image"],
        file_count="multiple",
        sources=["upload"]
    )
    
    with gr.Accordion("Settings", open=False):
        system_message_box = gr.Textbox(
            value="You are a helpful AI assistant that can understand images and text.", 
            placeholder="You are a helpful assistant.",
            label="System Prompt"
        )
        
        with gr.Row():
            with gr.Column():
                max_tokens_slider = gr.Slider(minimum=1, maximum=4096, value=1024, step=1, label="Max tokens")
                temperature_slider = gr.Slider(minimum=0.0, maximum=2.0, value=0.7, step=0.01, label="Temperature") # Allow 0 for deterministic
                top_p_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.01, label="Top-P") # Allow 0
            with gr.Column():
                frequency_penalty_slider = gr.Slider(minimum=-2.0, maximum=2.0, value=0.0, step=0.1, label="Frequency Penalty")
                seed_slider = gr.Slider(minimum=-1, maximum=65535, value=-1, step=1, label="Seed (-1 for random)")
        
        providers_list = ["hf-inference", "cerebras", "together", "sambanova", "novita", "cohere", "fireworks-ai", "hyperbolic", "nebius"]
        provider_radio = gr.Radio(choices=providers_list, value="hf-inference", label="Inference Provider")
        
        byok_textbox = gr.Textbox(
            value="", label="BYOK (Bring Your Own Key)", 
            info="Enter your Hugging Face API key (or provider-specific key). Overrides default. If empty, uses Space's HF_TOKEN (if set) for 'hf-inference'.",
            placeholder="hf_... or provider_specific_key", type="password"
        )
        
        custom_model_box = gr.Textbox(
            value="", label="Custom Model ID", 
            info="(Optional) Provide a model ID (e.g., 'meta-llama/Llama-3-8B-Instruct'). Overrides featured model selection.",
            placeholder="org/model-name"
        )
        
        model_search_box = gr.Textbox(label="Filter Featured Models", placeholder="Search...", lines=1)
        
        models_list = [
            "meta-llama/Llama-3.2-11B-Vision-Instruct", "meta-llama/Llama-3.1-70B-Instruct", 
            "mistralai/Mistral-Nemo-Instruct-2407", "Qwen/Qwen2.5-72B-Instruct",
            "microsoft/Phi-3.5-mini-instruct", "NousResearch/Hermes-3-Llama-3.1-8B",
            # Add more or fetch dynamically if possible
        ]
        featured_model_radio = gr.Radio(
            label="Select a Featured Model", choices=models_list, 
            value="meta-llama/Llama-3.2-11B-Vision-Instruct", interactive=True
        )
        
        gr.Markdown("[All Text Gen Models](https://huggingface.co/models?pipeline_tag=text-generation&sort=trending) | [All Multimodal Models](https://huggingface.co/models?pipeline_tag=image-text-to-text&sort=trending)")

    # Chat history state (using chatbot component directly for history)

    def handle_user_message_submission(user_input_mmtb, chat_history_list):
        # user_input_mmtb is a dict: {"text": "...", "files": ["path1", "path2"]}
        text_content = user_input_mmtb.get("text", "")
        files = user_input_mmtb.get("files", [])
        
        # Construct the display for the user message in the chat
        # For Gradio Chatbot, user message can be a string or a tuple (text, filepath) or (None, filepath)
        # If multiple files, they need to be sent as separate messages or handled in display
        
        if not text_content and not files:
            return chat_history_list # Or raise an error/warning

        # Append user message to history.
        # The actual content for the API will be constructed in respond()
        # For display, we can show text and a placeholder for images, or actual images if supported well.
        # Let's pass the raw MultimodalTextbox output to history for now.
        chat_history_list.append([user_input_mmtb, None])
        return chat_history_list
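
    # Illustrative example (hypothetical paths/text, not produced by real input) of one
    # history entry right after submission, before the bot reply is filled in:
    #   [{"text": "What is in this image?", "files": ["/tmp/cat.png"]}, None]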

    def handle_bot_response_generation(
        chat_history_list, system_msg, max_tokens, temp, top_p, freq_pen, seed_val, 
        prov, api_key_val, cust_model_val, search_term_val, feat_model_val
    ):
        if not chat_history_list or chat_history_list[-1][0] is None:
            yield chat_history_list # Or an error message
            return

        # The last user message is chat_history_list[-1][0]
        # It's the dict from MultimodalTextbox: {"text": "...", "files": ["path1", ...]}
        last_user_input_mmtb = chat_history_list[-1][0]
        
        current_message_text = last_user_input_mmtb.get("text", "")
        current_image_files = last_user_input_mmtb.get("files", [])

        # Prepare history for the `respond` function (excluding the current turn's user message)
        api_history = []
        for user_msg_item, bot_msg_item in chat_history_list[:-1]:
            # Convert past user messages (which are MMTB dicts) to API format or simple strings
            past_user_text = user_msg_item.get("text", "")
            # For simplicity, not including past images in API history here, but could be added
            api_history.append((past_user_text, bot_msg_item))
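            # Sketch (commented out) of how past images could also be carried into the
            # API history, assuming respond() were extended to accept such entries:
            # past_user_files = user_msg_item.get("files", [])
            # api_history.append(({"text": past_user_text, "files": past_user_files}, bot_msg_item))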


        # Stream the response
        full_response = ""
        for stream_chunk in respond(
            message=current_message_text,
            image_files=current_image_files,
            history=api_history, # Pass the processed history
            system_message=system_msg,
            max_tokens=max_tokens,
            temperature=temp,
            top_p=top_p,
            frequency_penalty=freq_pen,
            seed=seed_val,
            provider=prov,
            custom_api_key=api_key_val,
            custom_model=cust_model_val,
            model_search_term=search_term_val, # Note: search_term is for UI filtering, not API
            selected_model=feat_model_val
        ):
            full_response = stream_chunk
            chat_history_list[-1][1] = full_response
            yield chat_history_list
            
    msg.submit(
        handle_user_message_submission,
        [msg, chatbot],
        [chatbot],
        queue=False
    ).then(
        handle_bot_response_generation,
        [chatbot, system_message_box, max_tokens_slider, temperature_slider, top_p_slider, 
         frequency_penalty_slider, seed_slider, provider_radio, byok_textbox, custom_model_box, 
         model_search_box, featured_model_radio],
        [chatbot]
    ).then(
        lambda: gr.update(value=None),  # Clears MultimodalTextbox: {"text": None, "files": None}
        [], # No inputs needed for this
        [msg]
    )
    
    def filter_models_ui(search_term):
        filtered = [m for m in models_list if search_term.lower() in m.lower()] if search_term else models_list
        return gr.update(choices=filtered, value=filtered[0] if filtered else None)

    model_search_box.change(fn=filter_models_ui, inputs=model_search_box, outputs=featured_model_radio)
    
    # No need for set_custom_model_from_radio if custom_model_box overrides featured_model_radio directly in respond()

    byok_textbox.change(fn=validate_provider, inputs=[byok_textbox, provider_radio], outputs=provider_radio)
    provider_radio.change(fn=validate_provider, inputs=[byok_textbox, provider_radio], outputs=provider_radio)

print("Gradio interface initialized.")

if __name__ == "__main__":
    print("Launching the demo application.")
    # For Spaces, share=True is often implied or handled by the Spaces platform
    # For local, share=True makes it public via Gradio link
    demo.queue().launch(show_api=False) # .queue() is good for handling multiple users / long tasks