import gradio as gr
from huggingface_hub import InferenceClient
import os
import base64
from PIL import Image
import io
from smolagents.mcp_client import MCPClient
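# Assumed dependencies: gradio, huggingface_hub, smolagents (provides MCPClient),
# and pillow, e.g.: pip install gradio huggingface_hub smolagents pillow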

# Global variables for MCP Client and TTS tool
mcp_client = None
tts_tool = None

# Access token from environment
ACCESS_TOKEN = os.getenv("HF_TOKEN")
print("Access token loaded.")

# Function to encode image to base64
def encode_image(image_path):
    if not image_path:
        print("No image path provided")
        return None
    
    try:
        print(f"Encoding image from path: {image_path}")
        if isinstance(image_path, Image.Image):
            image = image_path
        else:
            image = Image.open(image_path)
        
        if image.mode == 'RGBA':
            image = image.convert('RGB')
        
        buffered = io.BytesIO()
        image.save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
        print("Image encoded successfully")
        return img_str
    except Exception as e:
        print(f"Error encoding image: {e}")
        return None
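
# Example (illustrative): encode_image("photo.png") yields a base64 string that
# respond() wraps into a data URI such as "data:image/jpeg;base64,/9j/4AAQ...".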

# Initialize MCP Client at startup
def init_mcp_client():
    global mcp_client, tts_tool
    try:
        mcp_client = MCPClient({"url": "https://fdaudens-kokoro-mcp.hf.space/gradio_api/mcp/sse"})
        tools = mcp_client.get_tools()
        tts_tool = next((tool for tool in tools if tool.name == "text_to_audio"), None)
        if tts_tool:
            print("Successfully connected to Kokoro TTS tool")
        else:
            print("TTS tool not found")
    except Exception as e:
        print(f"Error initializing MCP Client: {e}")

def respond(
    message,
    image_files,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    frequency_penalty,
    seed,
    provider,
    custom_api_key,
    custom_model,    
    model_search_term,
    selected_model
):
    print(f"Received message: {message}")
    print(f"Received {len(image_files) if image_files else 0} images")
    print(f"History: {history}")
    print(f"System message: {system_message}")
    print(f"Max tokens: {max_tokens}, Temperature: {temperature}, Top-P: {top_p}")
    print(f"Frequency Penalty: {frequency_penalty}, Seed: {seed}")
    print(f"Selected provider: {provider}")         
    print(f"Custom API Key provided: {bool(custom_api_key.strip())}")
    print(f"Selected model (custom_model): {custom_model}")  
    print(f"Model search term: {model_search_term}")
    print(f"Selected model from radio: {selected_model}")

    token_to_use = custom_api_key if custom_api_key.strip() != "" else ACCESS_TOKEN
    
    if custom_api_key.strip() != "":
        print("USING CUSTOM API KEY: BYOK token provided by user is being used for authentication")
    else:
        print("USING DEFAULT API KEY: Environment variable HF_TOKEN is being used for authentication")
    
    client = InferenceClient(token=token_to_use, provider=provider)
    print(f"Hugging Face Inference Client initialized with {provider} provider.")

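    # A seed of -1 in the UI means "random"; drop it so the provider picks one.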
    if seed == -1:
        seed = None

    if image_files and len(image_files) > 0:
        user_content = []
        if message and message.strip():
            user_content.append({"type": "text", "text": message})
        
        for img in image_files:
            if img is not None:
                try:
                    encoded_image = encode_image(img)
                    if encoded_image:
                        user_content.append({
                            "type": "image_url",
                            "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}
                        })
                except Exception as e:
                    print(f"Error encoding image: {e}")
    else:
        user_content = message
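
    # At this point user_content is either a plain string (text-only) or a list
    # of typed parts in the OpenAI-style multimodal format, e.g. (illustrative):
    #   [{"type": "text", "text": "What is in this photo?"},
    #    {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,..."}}]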

    messages = [{"role": "system", "content": system_message}]
    print("Initial messages array constructed.")

    for val in history:
        user_part = val[0]
        assistant_part = val[1]
        if user_part:
            if isinstance(user_part, tuple) and len(user_part) == 2:
                history_content = []
                if user_part[0]:
                    history_content.append({"type": "text", "text": user_part[0]})
                
                for img in user_part[1]:
                    if img:
                        try:
                            encoded_img = encode_image(img)
                            if encoded_img:
                                history_content.append({
                                    "type": "image_url",
                                    "image_url": {"url": f"data:image/jpeg;base64,{encoded_img}"}
                                })
                        except Exception as e:
                            print(f"Error encoding history image: {e}")
                
                messages.append({"role": "user", "content": history_content})
            else:
                messages.append({"role": "user", "content": user_part})
            print(f"Added user message to context (type: {type(user_part)})")
        
        if assistant_part:
            messages.append({"role": "assistant", "content": assistant_part})
            print(f"Added assistant message to context: {assistant_part}")

    messages.append({"role": "user", "content": user_content})
    print(f"Latest user message appended (content type: {type(user_content)})")

    model_to_use = custom_model.strip() if custom_model.strip() != "" else selected_model
    print(f"Model selected for inference: {model_to_use}")

    response = ""
    print(f"Sending request to {provider} provider.")

    parameters = {
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "frequency_penalty": frequency_penalty,
    }
    
    if seed is not None:
        parameters["seed"] = seed

    try:
        stream = client.chat_completion(
            model=model_to_use,
            messages=messages,
            stream=True,
            **parameters
        )
        
        print("Received tokens: ", end="", flush=True)
        
        for chunk in stream:
            if hasattr(chunk, 'choices') and len(chunk.choices) > 0:
                if hasattr(chunk.choices[0], 'delta') and hasattr(chunk.choices[0].delta, 'content'):
                    token_text = chunk.choices[0].delta.content
                    if token_text:
                        print(token_text, end="", flush=True)
                        response += token_text
                        yield response
        
        print()
    except Exception as e:
        print(f"Error during inference: {e}")
        response += f"\nError: {str(e)}"
        yield response

    print("Completed response generation.")

# Function to generate audio from the last bot response
def generate_audio(history):
    if not history or len(history) == 0:
        print("No history available for audio generation")
        return None
    last_message = history[-1][1]  # Bot's response
    if not last_message or not isinstance(last_message, str):
        print("Last message is empty or not a string")
        return None
    if tts_tool:
        try:
            # Call the TTS tool directly, expecting (sample_rate, audio_array)
            result = tts_tool(text=last_message, speed=1.0)
            if result and len(result) == 2:
                sample_rate, audio_data = result
                print("Audio generated successfully")
                return (sample_rate, audio_data)
            else:
                print("TTS tool returned invalid result")
                return None
        except Exception as e:
            print(f"Error generating audio: {e}")
            return None
    else:
        print("TTS tool not available")
        return None

def validate_provider(api_key, provider):
    if not api_key.strip() and provider != "hf-inference":
        return gr.update(value="hf-inference")
    return gr.update(value=provider)
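
# Example (illustrative): validate_provider("", "cerebras") resets the radio to
# "hf-inference", since third-party providers require the user's own API key.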

# Gradio UI
with gr.Blocks(theme="Nymbo/Nymbo_Theme") as demo:
    chatbot = gr.Chatbot(
        height=600, 
        show_copy_button=True, 
        placeholder="Select a model and begin chatting. Now supports multiple inference providers and multimodal inputs",
        layout="panel"
    )
    print("Chatbot interface created.")
    
    msg = gr.MultimodalTextbox(
        placeholder="Type a message or upload images...",
        show_label=False,
        container=False,
        scale=12,
        file_types=["image"],
        file_count="multiple",
        sources=["upload"]
    )

    # Audio generation components
    with gr.Row():
        generate_audio_btn = gr.Button("Generate Audio from Last Response")
        audio_output = gr.Audio(label="Generated Audio", type="numpy")

    with gr.Accordion("Settings", open=False):
        system_message_box = gr.Textbox(
            value="You are a helpful AI assistant that can understand images and text.", 
            placeholder="You are a helpful assistant.",
            label="System Prompt"
        )
        
        with gr.Row():
            with gr.Column():
                max_tokens_slider = gr.Slider(minimum=1, maximum=4096, value=512, step=1, label="Max tokens")
                temperature_slider = gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature")
                top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-P")
            with gr.Column():
                frequency_penalty_slider = gr.Slider(minimum=-2.0, maximum=2.0, value=0.0, step=0.1, label="Frequency Penalty")
                seed_slider = gr.Slider(minimum=-1, maximum=65535, value=-1, step=1, label="Seed (-1 for random)")
        
        providers_list = [
            "hf-inference", "cerebras", "together", "sambanova", "novita", "cohere", "fireworks-ai", "hyperbolic", "nebius"
        ]
        
        provider_radio = gr.Radio(choices=providers_list, value="hf-inference", label="Inference Provider")
        byok_textbox = gr.Textbox(value="", label="BYOK (Bring Your Own Key)", info="Enter a custom Hugging Face API key here.", placeholder="Enter your Hugging Face API token", type="password")
        custom_model_box = gr.Textbox(value="", label="Custom Model", info="(Optional) Provide a custom Hugging Face model path.", placeholder="meta-llama/Llama-3.3-70B-Instruct")
        model_search_box = gr.Textbox(label="Filter Models", placeholder="Search for a featured model...", lines=1)
        
        models_list = [
            "meta-llama/Llama-3.2-11B-Vision-Instruct", "meta-llama/Llama-3.3-70B-Instruct", "meta-llama/Llama-3.1-70B-Instruct",
            "meta-llama/Llama-3.0-70B-Instruct", "meta-llama/Llama-3.2-3B-Instruct", "meta-llama/Llama-3.2-1B-Instruct",
            "meta-llama/Llama-3.1-8B-Instruct", "NousResearch/Hermes-3-Llama-3.1-8B", "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
            "mistralai/Mistral-Nemo-Instruct-2407", "mistralai/Mixtral-8x7B-Instruct-v0.1", "mistralai/Mistral-7B-Instruct-v0.3",
            "mistralai/Mistral-7B-Instruct-v0.2", "Qwen/Qwen3-235B-A22B", "Qwen/Qwen3-32B", "Qwen/Qwen2.5-72B-Instruct",
            "Qwen/Qwen2.5-3B-Instruct", "Qwen/Qwen2.5-0.5B-Instruct", "Qwen/QwQ-32B", "Qwen/Qwen2.5-Coder-32B-Instruct",
            "microsoft/Phi-3.5-mini-instruct", "microsoft/Phi-3-mini-128k-instruct", "microsoft/Phi-3-mini-4k-instruct"
        ]

        featured_model_radio = gr.Radio(label="Select a model below", choices=models_list, value="meta-llama/Llama-3.2-11B-Vision-Instruct", interactive=True)
        gr.Markdown("[View all Text-to-Text models](https://huggingface.co/models?inference_provider=all&pipeline_tag=text-generation&sort=trending) | [View all multimodal models](https://huggingface.co/models?inference_provider=all&pipeline_tag=image-text-to-text&sort=trending)")

    chat_history = gr.State([])
    
    def filter_models(search_term):
        print(f"Filtering models with search term: {search_term}")
        filtered = [m for m in models_list if search_term.lower() in m.lower()]
        print(f"Filtered models: {filtered}")
        return gr.update(choices=filtered)

    def set_custom_model_from_radio(selected):
        print(f"Featured model selected: {selected}")
        return selected

    def user(user_message, history):
        print(f"User message received: {user_message}")
        if not user_message or (not user_message.get("text") and not user_message.get("files")):
            print("Empty message, skipping")
            return history
        
        text_content = user_message.get("text", "").strip()
        files = user_message.get("files", [])
        
        print(f"Text content: {text_content}")
        print(f"Files: {files}")
        
        if not text_content and not files:
            print("No content to display")
            return history
        
        if files and len(files) > 0:
            if text_content:
                print(f"Adding text message: {text_content}")
                history.append([text_content, None])
            
            for file_path in files:
                if file_path and isinstance(file_path, str):
                    print(f"Adding image: {file_path}")
                    history.append([f"![Image]({file_path})", None])
            
            return history
        else:
            print(f"Adding text-only message: {text_content}")
            history.append([text_content, None])
            return history
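
    # Note: user() stores each uploaded image as a markdown row ("![Image](path)")
    # in the chat history; bot() below detects these rows and re-expands them into
    # multimodal API content via respond().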
    
    def bot(history, system_msg, max_tokens, temperature, top_p, freq_penalty, seed, provider, api_key, custom_model, search_term, selected_model):
        if not history or len(history) == 0:
            print("No history to process")
            return history
        
        user_message = history[-1][0]
        print(f"Processing user message: {user_message}")
        
        is_image = False
        image_path = None
        text_content = user_message
        
        if isinstance(user_message, str) and user_message.startswith("![Image]("):
            is_image = True
            image_path = user_message.replace("![Image](", "").replace(")", "")
            print(f"Image detected: {image_path}")
            text_content = ""
        
        text_context = ""
        if is_image and len(history) > 1:
            prev_message = history[-2][0]
            if isinstance(prev_message, str) and not prev_message.startswith("![Image]("):
                text_context = prev_message
                print(f"Using text context from previous message: {text_context}")
        
        history[-1][1] = ""
        
        if is_image:
            for response in respond(
                text_context, [image_path], history[:-1], system_msg, max_tokens, temperature, top_p,
                freq_penalty, seed, provider, api_key, custom_model, search_term, selected_model
            ):
                history[-1][1] = response
                yield history
        else:
            for response in respond(
                text_content, None, history[:-1], system_msg, max_tokens, temperature, top_p,
                freq_penalty, seed, provider, api_key, custom_model, search_term, selected_model
            ):
                history[-1][1] = response
                yield history

    msg.submit(user, [msg, chatbot], [chatbot], queue=False).then(
        bot, [chatbot, system_message_box, max_tokens_slider, temperature_slider, top_p_slider,
              frequency_penalty_slider, seed_slider, provider_radio, byok_textbox, custom_model_box,
              model_search_box, featured_model_radio], [chatbot]
    ).then(lambda: {"text": "", "files": []}, None, [msg])
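
    # The submit chain above appends the user turn, streams the bot reply into
    # the last history slot, then clears the multimodal textbox; queue=False on
    # the first step makes the user's message echo immediately.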
    
    model_search_box.change(fn=filter_models, inputs=model_search_box, outputs=featured_model_radio)
    print("Model search box change event linked.")

    featured_model_radio.change(fn=set_custom_model_from_radio, inputs=featured_model_radio, outputs=custom_model_box)
    print("Featured model radio button change event linked.")
    
    byok_textbox.change(fn=validate_provider, inputs=[byok_textbox, provider_radio], outputs=provider_radio)
    print("BYOK textbox change event linked.")

    provider_radio.change(fn=validate_provider, inputs=[byok_textbox, provider_radio], outputs=provider_radio)
    print("Provider radio button change event linked.")

    # Event handler for audio generation
    generate_audio_btn.click(fn=generate_audio, inputs=[chatbot], outputs=[audio_output])

    # Initialize MCP Client on app load
    demo.load(init_mcp_client)

print("Gradio interface initialized.")

if __name__ == "__main__":
    print("Launching the demo application.")
    try:
        demo.launch(show_api=True)
    finally:
        if mcp_client:
            mcp_client.disconnect()
            print("MCP Client disconnected.")