import gradio as gr
from openai import OpenAI
import os

ACCESS_TOKEN = os.getenv("HF_TOKEN")
if not ACCESS_TOKEN:
    print("Warning: HF_TOKEN environment variable not set. Authentication might fail.")
else:
    print("Access token loaded.")

# Base URLs for different providers
HF_INFERENCE_BASE_URL = "https://api-inference.huggingface.co/v1/"
CEREBRAS_ROUTER_BASE_URL = "https://router.huggingface.co/cerebras/v1/" # Use base URL for OpenAI client

# Default provider
DEFAULT_PROVIDER = "hf-inference"

# --- Main Respond Function ---
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    frequency_penalty,
    seed,
    custom_model,
    inference_provider  # Inference backend: "hf-inference" or "cerebras"
):

    print(f"--- New Request ---")
    print(f"Selected Inference Provider: {inference_provider}")
    print(f"Received message: {message}")
    # print(f"History: {history}") # Can be verbose
    print(f"System message: {system_message}")
    print(f"Max tokens: {max_tokens}, Temperature: {temperature}, Top-P: {top_p}")
    print(f"Frequency Penalty: {frequency_penalty}, Seed: {seed}")
    print(f"Selected model (custom_model): {custom_model}")

    # Determine the base URL based on the selected provider
    if inference_provider == "cerebras":
        base_url = CEREBRAS_ROUTER_BASE_URL
        print(f"Using Cerebras Router endpoint: {base_url}")
    else: # Default to hf-inference
        base_url = HF_INFERENCE_BASE_URL
        print(f"Using HF Inference API endpoint: {base_url}")

    # Initialize the OpenAI client dynamically for each request
    try:
        client = OpenAI(
            base_url=base_url,
            api_key=ACCESS_TOKEN,
        )
        print("OpenAI client initialized for the request.")
    except Exception as e:
        print(f"Error initializing OpenAI client: {e}")
        yield f"Error: Could not initialize API client for provider {inference_provider}. Check token and endpoint."
        return
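
    # Note: the OpenAI SDK is used here only as a generic client for OpenAI-compatible
    # endpoints; the Hugging Face token is passed as the api_key for both providers.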

    # Convert seed to None if -1 (meaning random)
    if seed == -1:
        seed = None

    messages = [{"role": "system", "content": system_message}]
    # print("Initial messages array constructed.") # Less verbose logging

    # Add conversation history to the context
    for val in history:
        user_part, assistant_part = val[0], val[1]
        if user_part: messages.append({"role": "user", "content": user_part})
        if assistant_part: messages.append({"role": "assistant", "content": assistant_part})

    # Append the latest user message
    messages.append({"role": "user", "content": message})
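    # Illustrative shape of the resulting context:
    #   [{"role": "system", "content": system_message},
    #    {"role": "user", ...}, {"role": "assistant", ...}, ...,
    #    {"role": "user", "content": message}]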
    # print("Full message context prepared.") # Less verbose logging

    # If user provided a model, use that; otherwise, fall back to a default model
    # Ensure a default model is always set if custom_model is empty
    model_to_use = custom_model.strip() if custom_model.strip() else "meta-llama/Llama-3.3-70B-Instruct"
    print(f"Model selected for inference: {model_to_use}")

    # Start streaming response
    response = ""
    print(f"Sending request to {inference_provider} via {base_url}...")

    try:
        stream = client.chat.completions.create(
            model=model_to_use,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
            frequency_penalty=frequency_penalty,
            seed=seed,
            messages=messages,
        )
        for message_chunk in stream:
            # Some providers emit keep-alive chunks with no choices; skip them to avoid an IndexError.
            if not message_chunk.choices:
                continue
            token_text = message_chunk.choices[0].delta.content
            # Handle potential None or empty tokens gracefully
            if token_text:
                # print(f"Received token: {token_text}") # Very verbose
                response += token_text
                yield response
            # Handle potential finish reason if needed (e.g., length)
            # finish_reason = message_chunk.choices[0].finish_reason
            # if finish_reason:
            #     print(f"Stream finished with reason: {finish_reason}")

    except Exception as e:
        print(f"Error during API call to {inference_provider}: {e}")
        yield f"Error: API call failed. Details: {str(e)}"
        return # Stop generation on error

    print("Completed response generation.")

# --- GRADIO UI Elements ---

chatbot = gr.Chatbot(height=600, show_copy_button=True, placeholder="Select a model and provider, then begin chatting", layout="panel")
print("Chatbot interface created.")

# Components passed to the ChatInterface as additional inputs (see the accordions below)
system_message_box = gr.Textbox(value="You are a helpful assistant.", label="System Prompt")
max_tokens_slider = gr.Slider(minimum=1, maximum=4096, value=1024, step=1, label="Max new tokens")
temperature_slider = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature")
top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-P")
frequency_penalty_slider = gr.Slider(minimum=-2.0, maximum=2.0, value=0.0, step=0.1, label="Frequency Penalty")
seed_slider = gr.Slider(minimum=-1, maximum=65535, value=-1, step=1, label="Seed (-1 for random)")
custom_model_box = gr.Textbox(
    value="",
    label="Custom Model Path",
    info="(Optional) Provide a Hugging Face model path. Overrides featured model selection.",
    placeholder="meta-llama/Llama-3.3-70B-Instruct"
)

# New UI Element for Provider Selection (will be placed in Accordion)
inference_provider_radio = gr.Radio(
    choices=["hf-inference", "cerebras"],
    value=DEFAULT_PROVIDER,
    label="Inference Provider",
    info=f"Select the backend API. Default: {DEFAULT_PROVIDER}"
)
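# Sketch (assumption): another OpenAI-compatible provider could be added by extending the
# choices above and adding a matching base-URL branch in respond(), e.g.
#   elif inference_provider == "some-provider":
#       base_url = "https://router.huggingface.co/some-provider/v1/"  # illustrative URL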
print("Inference provider radio button created.")


# --- Gradio Chat Interface Definition ---
demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[
        # Order matters: must match the 'respond' function signature
        system_message_box,
        max_tokens_slider,
        temperature_slider,
        top_p_slider,
        frequency_penalty_slider,
        seed_slider,
        custom_model_box,
        inference_provider_radio, # Added the new input
    ],
    fill_height=True,
    chatbot=chatbot,
    theme="Nymbo/Nymbo_Theme",
    title="Multi-Provider Chat Hub",
    description="Chat with various models using different inference backends (HF Inference API or Cerebras via HF Router)."
)
print("ChatInterface object created.")

# --- Add Accordions for Settings within the Demo context ---
with demo:
    # Model Selection Accordion (existing logic)
    with gr.Accordion("Model Selection", open=False):
        model_search_box = gr.Textbox(label="Filter Featured Models", placeholder="Search...", lines=1)
        print("Model search box created.")

        # Featured models list (representative subset)
        models_list = [
            "meta-llama/Llama-3.3-70B-Instruct", "meta-llama/Llama-3.1-70B-Instruct", "meta-llama/Llama-3.1-8B-Instruct",
            "NousResearch/Hermes-3-Llama-3.1-8B", "mistralai/Mistral-Nemo-Instruct-2407", "mistralai/Mixtral-8x7B-Instruct-v0.1",
            "mistralai/Mistral-7B-Instruct-v0.3", "Qwen/Qwen3-32B", "microsoft/Phi-3.5-mini-instruct",
            # Add the rest of your models here...
        ]
        print("Models list initialized.")

        featured_model_radio = gr.Radio(
            label="Select a Featured Model",
            choices=models_list,
            value="meta-llama/Llama-3.3-70B-Instruct", # Default featured model
            interactive=True
        )
        print("Featured models radio button created.")

        def filter_models(search_term, current_value):
            print(f"Filtering models with search term: {search_term}")
            filtered = [m for m in models_list if search_term.lower() in m.lower()]
            # Keep the current selection if it survives the filter; otherwise fall back
            # to the first match (or None when nothing matches).
            if current_value in filtered:
                new_value = current_value
            elif filtered:
                new_value = filtered[0]
            else:
                new_value = None
            print(f"Filtered models: {filtered}")
            return gr.update(choices=filtered, value=new_value)


        def set_custom_model_from_radio(selected_model):
            """Updates the Custom Model text box when a featured model is selected."""
            print(f"Featured model selected: {selected_model}")
            return selected_model # Directly return the selected model name

        model_search_box.change(fn=filter_models, inputs=[model_search_box, featured_model_radio], outputs=featured_model_radio)
        featured_model_radio.change(fn=set_custom_model_from_radio, inputs=featured_model_radio, outputs=custom_model_box)
        print("Model selection events linked.")

    # Advanced Settings Accordion (New)
    with gr.Accordion("Advanced Settings", open=False):
        # Place the provider selection and parameter sliders here
        gr.Markdown("Configure inference parameters and select the backend provider.")
        # Add the UI elements defined earlier into this accordion
        system_message_box.render() # Render the existing system prompt box so it stays linked to respond()
        inference_provider_radio.render() # Render the provider radio here
        max_tokens_slider.render()
        temperature_slider.render()
        top_p_slider.render()
        frequency_penalty_slider.render()
        seed_slider.render()
        print("Advanced settings accordion created with provider selection and parameters.")


print("Gradio interface fully initialized.")

if __name__ == "__main__":
    print("Launching the demo application.")
    demo.launch(show_api=False)