Nymbo commited on
Commit
254adbf
·
verified ·
1 Parent(s): 3bc16e6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +139 -10
app.py CHANGED
@@ -17,13 +17,14 @@ headers = {"Authorization": f"Bearer {API_TOKEN}"}
17
  # Timeout for requests
18
  timeout = 100
19
 
20
- def query(prompt, model, custom_lora, is_negative=False, steps=35, cfg_scale=7, sampler="DPM++ 2M Karras", seed=-1, strength=0.7, width=1024, height=1024, custom_api_key=""):
21
  # Debug log to indicate function start
22
  print("Starting query function...")
23
  # Print the parameters for debugging purposes
24
  print(f"Prompt: {prompt}")
25
  print(f"Model: {model}")
26
  print(f"Custom LoRA: {custom_lora}")
 
27
  print(f"Parameters - Steps: {steps}, CFG Scale: {cfg_scale}, Seed: {seed}, Strength: {strength}, Width: {width}, Height: {height}")
28
  print(f"Custom API Key provided: {bool(custom_api_key.strip())}") # Log whether a custom key was provided without printing the key
29
 
@@ -41,7 +42,12 @@ def query(prompt, model, custom_lora, is_negative=False, steps=35, cfg_scale=7,
41
  print("USING CUSTOM API KEY: BYOK token provided by user is being used for authentication")
42
  API_TOKEN = custom_api_key.strip()
43
  else:
44
- # Randomly select an API token from available options to distribute the load
 
 
 
 
 
45
  API_TOKEN = random.choice([os.getenv("HF_READ_TOKEN"), os.getenv("HF_READ_TOKEN_2"), os.getenv("HF_READ_TOKEN_3"), os.getenv("HF_READ_TOKEN_4"), os.getenv("HF_READ_TOKEN_5")])
46
  print("USING DEFAULT API KEY: Random environment variable token is being used for authentication")
47
 
@@ -275,6 +281,31 @@ def query(prompt, model, custom_lora, is_negative=False, steps=35, cfg_scale=7,
275
  if model == 'Stable Diffusion 3 Medium':
276
  API_URL = "https://api-inference.huggingface.co/models/stabilityai/stable-diffusion-3-medium-diffusers"
277
  prompt = f"A, {prompt}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
278
  if model == 'Duchaiten Real3D NSFW XL':
279
  API_URL = "https://api-inference.huggingface.co/models/stablediffusionapi/duchaiten-real3d-nsfw-xl"
280
  if model == 'Pixel Art XL':
@@ -351,8 +382,45 @@ def query(prompt, model, custom_lora, is_negative=False, steps=35, cfg_scale=7,
351
  }
352
  print(f"Payload: {json.dumps(payload, indent=2)}") # Debug log
353
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
354
  # Make a request to the API to generate the image
355
  try:
 
 
 
 
 
 
 
356
  response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
357
  print(f"Response status code: {response.status_code}") # Debug log
358
  except requests.exceptions.RequestException as e:
@@ -378,8 +446,21 @@ def query(prompt, model, custom_lora, is_negative=False, steps=35, cfg_scale=7,
378
  raise gr.Error(f"{response.status_code}: An unexpected error occurred.")
379
 
380
  try:
381
- # Attempt to read the image from the response content
382
- image_bytes = response.content
 
 
 
 
 
 
 
 
 
 
 
 
 
383
  image = Image.open(io.BytesIO(image_bytes))
384
  print(f'Generation {key} completed! ({prompt})') # Debug log
385
  return image
@@ -537,16 +618,56 @@ with gr.Blocks(theme='Nymbo/Nymbo_Theme_5') as dalle:
537
  with gr.Row():
538
  # Textbox for specifying elements to exclude from the image
539
  negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="What should not be in the image", value="(deformed, distorted, disfigured), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, (mutated hands and fingers), disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation, misspellings, typos", lines=3, elem_id="negative-prompt-text-input")
 
540
  with gr.Row():
541
- # New: BYOK (Bring Your Own Key) textbox
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
542
  byok_textbox = gr.Textbox(
543
  value="",
544
  label="BYOK (Bring Your Own Key)",
545
- info="Enter a custom Hugging Face API key here. When provided, this key will be used instead of the default keys.",
546
- placeholder="Enter your Hugging Face API token",
547
  type="password", # Hide the API key for security
548
  elem_id="byok-input"
549
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
550
  with gr.Row():
551
  # Slider for selecting the image width
552
  width = gr.Slider(label="Width", value=1024, minimum=64, maximum=1216, step=32)
@@ -632,8 +753,16 @@ with gr.Blocks(theme='Nymbo/Nymbo_Theme_5') as dalle:
632
  ## Negative Prompt
633
  ###### This box is for telling the AI what you don't want in your images. Think of it as a way to avoid certain elements. For instance, if you don't want blurry images or extra limbs showing up, this is where you'd mention it.
634
 
 
 
 
 
 
 
 
 
635
  ## BYOK (Bring Your Own Key)
636
- ###### This allows you to use your own Hugging Face API key instead of the default keys. Enter your key here for direct access to models using your account's permissions and rate limits.
637
 
638
  ## Width & Height
639
  ###### These sliders allow you to specify the resolution of your image. Default value is 1024x1024, and maximum output is 1216x1216.
@@ -664,8 +793,8 @@ with gr.Blocks(theme='Nymbo/Nymbo_Theme_5') as dalle:
664
  with gr.Row():
665
  image_output = gr.Image(type="pil", label="Image Output", elem_id="gallery")
666
 
667
- # Set up button click event to call the query function with the added BYOK parameter
668
- text_button.click(query, inputs=[text_prompt, model, custom_lora, negative_prompt, steps, cfg, method, seed, strength, width, height, byok_textbox], outputs=image_output)
669
 
670
  print("Launching Gradio interface...") # Debug log
671
  # Launch the Gradio interface without showing the API or sharing externally
 
17
  # Timeout for requests
18
  timeout = 100
19
 
20
+ def query(prompt, model, custom_lora, is_negative=False, steps=35, cfg_scale=7, sampler="DPM++ 2M Karras", seed=-1, strength=0.7, width=1024, height=1024, provider="hf-inference", custom_api_key=""):
21
  # Debug log to indicate function start
22
  print("Starting query function...")
23
  # Print the parameters for debugging purposes
24
  print(f"Prompt: {prompt}")
25
  print(f"Model: {model}")
26
  print(f"Custom LoRA: {custom_lora}")
27
+ print(f"Provider: {provider}")
28
  print(f"Parameters - Steps: {steps}, CFG Scale: {cfg_scale}, Seed: {seed}, Strength: {strength}, Width: {width}, Height: {height}")
29
  print(f"Custom API Key provided: {bool(custom_api_key.strip())}") # Log whether a custom key was provided without printing the key
30
 
 
42
  print("USING CUSTOM API KEY: BYOK token provided by user is being used for authentication")
43
  API_TOKEN = custom_api_key.strip()
44
  else:
45
+ # If no custom key, check if provider is "hf-inference"
46
+ if provider != "hf-inference":
47
+ print(f"ERROR: Custom API key is required for {provider} provider")
48
+ raise gr.Error(f"A custom API key is required when using the {provider} provider. Please enter your key in the BYOK field.")
49
+
50
+ # For hf-inference, randomly select an API token from available options to distribute the load
51
  API_TOKEN = random.choice([os.getenv("HF_READ_TOKEN"), os.getenv("HF_READ_TOKEN_2"), os.getenv("HF_READ_TOKEN_3"), os.getenv("HF_READ_TOKEN_4"), os.getenv("HF_READ_TOKEN_5")])
52
  print("USING DEFAULT API KEY: Random environment variable token is being used for authentication")
53
 
 
281
  if model == 'Stable Diffusion 3 Medium':
282
  API_URL = "https://api-inference.huggingface.co/models/stabilityai/stable-diffusion-3-medium-diffusers"
283
  prompt = f"A, {prompt}"
284
+
285
+ # Adjust URL for different providers if needed
286
+ base_url = "https://api-inference.huggingface.co"
287
+ if provider == "fal-ai":
288
+ base_url = "https://api.fal.ai/hf"
289
+ elif provider == "nebius":
290
+ base_url = "https://api.nebius.ai/hf"
291
+ elif provider == "replicate":
292
+ base_url = "https://api.replicate.com/v1/models"
293
+ # Replicate has a different endpoint structure
294
+ if "https://api-inference.huggingface.co/models/" in API_URL:
295
+ model_id = API_URL.replace("https://api-inference.huggingface.co/models/", "")
296
+ API_URL = f"{base_url}/{model_id}"
297
+ elif provider == "together":
298
+ base_url = "https://api.together.xyz/inference"
299
+ if "https://api-inference.huggingface.co/models/" in API_URL:
300
+ model_id = API_URL.replace("https://api-inference.huggingface.co/models/", "")
301
+ API_URL = f"{base_url}/{model_id}"
302
+
303
+ # Only update URL if we're not using a custom URL format (like for replicate)
304
+ if provider != "replicate" and provider != "together" and "https://api-inference.huggingface.co/models/" in API_URL:
305
+ model_id = API_URL.replace("https://api-inference.huggingface.co/models/", "")
306
+ API_URL = f"{base_url}/models/{model_id}"
307
+
308
+ print(f"API URL set to: {API_URL}") # Debug log
309
  if model == 'Duchaiten Real3D NSFW XL':
310
  API_URL = "https://api-inference.huggingface.co/models/stablediffusionapi/duchaiten-real3d-nsfw-xl"
311
  if model == 'Pixel Art XL':
 
382
  }
383
  print(f"Payload: {json.dumps(payload, indent=2)}") # Debug log
384
 
385
+ # Adjust payload for different providers if needed
386
+ if provider == "replicate" or provider == "together":
387
+ # These providers might have different API formats
388
+ if provider == "replicate":
389
+ payload = {
390
+ "version": model_id,
391
+ "input": {
392
+ "prompt": prompt,
393
+ "negative_prompt": is_negative,
394
+ "num_inference_steps": steps,
395
+ "guidance_scale": cfg_scale,
396
+ "seed": seed if seed != -1 else random.randint(1, 1000000000),
397
+ "strength": strength,
398
+ "width": width,
399
+ "height": height
400
+ }
401
+ }
402
+ elif provider == "together":
403
+ payload = {
404
+ "model": model_id,
405
+ "prompt": prompt,
406
+ "negative_prompt": is_negative,
407
+ "num_inference_steps": steps,
408
+ "guidance_scale": cfg_scale,
409
+ "seed": seed if seed != -1 else random.randint(1, 1000000000),
410
+ "strength": strength,
411
+ "width": width,
412
+ "height": height
413
+ }
414
+
415
  # Make a request to the API to generate the image
416
  try:
417
+ # Log which provider we're using
418
+ print(f"Sending request to {provider} provider")
419
+
420
+ # Add provider to headers for HF inference if needed
421
+ if provider != "hf-inference" and "api-inference.huggingface.co" in API_URL:
422
+ headers["X-Provider"] = provider
423
+
424
  response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
425
  print(f"Response status code: {response.status_code}") # Debug log
426
  except requests.exceptions.RequestException as e:
 
446
  raise gr.Error(f"{response.status_code}: An unexpected error occurred.")
447
 
448
  try:
449
+ # Handle different response formats based on provider
450
+ if provider == "replicate":
451
+ # Replicate might return a URL to the image rather than the image itself
452
+ result = response.json()
453
+ if "output" in result and isinstance(result["output"], list) and len(result["output"]) > 0:
454
+ image_url = result["output"][0]
455
+ # Fetch the image from the URL
456
+ image_response = requests.get(image_url, timeout=timeout)
457
+ image_bytes = image_response.content
458
+ else:
459
+ raise Exception(f"Unexpected Replicate response format: {result}")
460
+ else:
461
+ # Standard response with image content
462
+ image_bytes = response.content
463
+
464
  image = Image.open(io.BytesIO(image_bytes))
465
  print(f'Generation {key} completed! ({prompt})') # Debug log
466
  return image
 
618
  with gr.Row():
619
  # Textbox for specifying elements to exclude from the image
620
  negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="What should not be in the image", value="(deformed, distorted, disfigured), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, (mutated hands and fingers), disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation, misspellings, typos", lines=3, elem_id="negative-prompt-text-input")
621
+
622
  with gr.Row():
623
+ # Provider selection radio buttons
624
+ providers_list = [
625
+ "hf-inference", # Default Hugging Face Inference
626
+ "fal-ai", # Fal.ai
627
+ "nebius", # Nebius
628
+ "replicate", # Replicate
629
+ "together", # Together AI
630
+ ]
631
+
632
+ provider_radio = gr.Radio(
633
+ choices=providers_list,
634
+ value="hf-inference",
635
+ label="Inference Provider",
636
+ info="Select which provider to use for image generation. Providers other than HF Inference require a custom API key.",
637
+ elem_id="provider-radio"
638
+ )
639
+
640
+ with gr.Row():
641
+ # BYOK (Bring Your Own Key) textbox
642
  byok_textbox = gr.Textbox(
643
  value="",
644
  label="BYOK (Bring Your Own Key)",
645
+ info="Enter a custom API key here. Required for all providers except HF Inference.",
646
+ placeholder="Enter your API token",
647
  type="password", # Hide the API key for security
648
  elem_id="byok-input"
649
  )
650
+
651
+ # Function to validate provider selection based on BYOK
652
+ def validate_provider(api_key, provider):
653
+ # If no custom API key is provided, only "hf-inference" can be used
654
+ if not api_key.strip() and provider != "hf-inference":
655
+ return gr.update(value="hf-inference")
656
+ return gr.update(value=provider)
657
+
658
+ # Connect the BYOK textbox to validate provider selection
659
+ byok_textbox.change(
660
+ fn=validate_provider,
661
+ inputs=[byok_textbox, provider_radio],
662
+ outputs=provider_radio
663
+ )
664
+
665
+ # Also validate provider when the radio changes to ensure consistency
666
+ provider_radio.change(
667
+ fn=validate_provider,
668
+ inputs=[byok_textbox, provider_radio],
669
+ outputs=provider_radio
670
+ )
671
  with gr.Row():
672
  # Slider for selecting the image width
673
  width = gr.Slider(label="Width", value=1024, minimum=64, maximum=1216, step=32)
 
753
  ## Negative Prompt
754
  ###### This box is for telling the AI what you don't want in your images. Think of it as a way to avoid certain elements. For instance, if you don't want blurry images or extra limbs showing up, this is where you'd mention it.
755
 
756
+ ## Inference Provider
757
+ ###### Select which AI provider to use for generating images. Different providers may have different capabilities, pricing, and performance characteristics:
758
+ ###### - HF Inference: Default Hugging Face inference API (uses application's default keys)
759
+ ###### - Fal AI: Optimized for low latency and high throughput (requires your own API key)
760
+ ###### - Nebius: Cloud provider with enterprise-grade infrastructure (requires your own API key)
761
+ ###### - Replicate: Wide variety of models with flexible deployment (requires your own API key)
762
+ ###### - Together: High-performance inference service (requires your own API key)
763
+
764
  ## BYOK (Bring Your Own Key)
765
+ ###### This allows you to use your own API key instead of the default keys. Enter your key here for direct access to models using your account's permissions and rate limits. A custom key is required for all providers except HF Inference.
766
 
767
  ## Width & Height
768
  ###### These sliders allow you to specify the resolution of your image. Default value is 1024x1024, and maximum output is 1216x1216.
 
793
  with gr.Row():
794
  image_output = gr.Image(type="pil", label="Image Output", elem_id="gallery")
795
 
796
+ # Set up button click event to call the query function with provider and BYOK parameters
797
+ text_button.click(query, inputs=[text_prompt, model, custom_lora, negative_prompt, steps, cfg, method, seed, strength, width, height, provider_radio, byok_textbox], outputs=image_output)
798
 
799
  print("Launching Gradio interface...") # Debug log
800
  # Launch the Gradio interface without showing the API or sharing externally