My-AI-Projects committed on
Commit 64b241d · verified · 1 Parent(s): ab37886

Update app.py

Files changed (1):
  app.py  +30 -48
app.py CHANGED
@@ -1,65 +1,47 @@
  import gradio as gr
- import torch
- from diffusers import DiffusionPipeline  # Note: Change `FluxPipeline` to `DiffusionPipeline` if `FluxPipeline` is not correct
- from PIL import Image

- # Function to determine the device and handle model loading
- def setup_pipeline():
-     # Check for CUDA availability
-     device = "cuda" if torch.cuda.is_available() else "cpu"
-
-     # Load the diffusion model
-     try:
-         pipeline = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
-         if device == "cpu":
-             # If using CPU, ensure model is offloaded to avoid GPU-specific features
-             pipeline.enable_model_cpu_offload()
-         else:
-             # Move model to GPU
-             pipeline.to(device)
-     except Exception as e:
-         print(f"Error loading model: {e}")
-         raise e

-     return pipeline, device
-
- pipeline, device = setup_pipeline()
-
- def generate_image(prompt, guidance_scale=7.5, num_inference_steps=50):
-     # Generate an image based on the prompt
-     with torch.no_grad():
-         try:
-             images = pipeline(prompt, guidance_scale=guidance_scale, num_inference_steps=num_inference_steps).images
-         except Exception as e:
-             print(f"Error generating image: {e}")
-             raise e
-
-     # Assuming pipeline returns a list of images, just take the first one
-     img = images[0]
-
-     # Return the PIL image in a format suitable for Gradio
-     return img

- # Set up Gradio interface
  with gr.Blocks() as demo:
      gr.Markdown("# Text to Image Generation")
-
      with gr.Row():
          prompt = gr.Textbox(label="Prompt", placeholder="Enter a prompt here...")
-         guidance_scale = gr.Slider(minimum=1, maximum=15, step=0.1, value=7.5, label="Guidance Scale")
-         num_inference_steps = gr.Slider(minimum=1, maximum=100, step=1, value=50, label="Number of Inference Steps")
-
      with gr.Row():
          generate_button = gr.Button("Generate Image")
-
      result = gr.Image(label="Generated Image")
-
-     # Connect the function to the button
      generate_button.click(
          fn=generate_image,
-         inputs=[prompt, guidance_scale, num_inference_steps],
          outputs=result
      )

- # Launch the app
  demo.launch()
 
  import gradio as gr
+ from gradio_client import Client

+ # Initialize the client with the model endpoint
+ client = Client("black-forest-labs/FLUX.1-dev")

+ def generate_image(prompt, seed=0, randomize_seed=True, width=1024, height=1024, guidance_scale=3.5, num_inference_steps=28):
+     # Make the API request
+     result = client.predict(
+         prompt=prompt,
+         seed=seed,
+         randomize_seed=randomize_seed,
+         width=width,
+         height=height,
+         guidance_scale=guidance_scale,
+         num_inference_steps=num_inference_steps,
+         api_name="/infer"
+     )
+     return result

+ # Define the Gradio interface
  with gr.Blocks() as demo:
      gr.Markdown("# Text to Image Generation")
+
      with gr.Row():
          prompt = gr.Textbox(label="Prompt", placeholder="Enter a prompt here...")
+         seed = gr.Slider(minimum=0, maximum=100000, step=1, value=0, label="Seed")
+         randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
+         width = gr.Slider(minimum=256, maximum=2048, step=32, value=1024, label="Width")
+         height = gr.Slider(minimum=256, maximum=2048, step=32, value=1024, label="Height")
+         guidance_scale = gr.Slider(minimum=1, maximum=15, step=0.1, value=3.5, label="Guidance Scale")
+         num_inference_steps = gr.Slider(minimum=1, maximum=50, step=1, value=28, label="Number of Inference Steps")
+
      with gr.Row():
          generate_button = gr.Button("Generate Image")
+
      result = gr.Image(label="Generated Image")
+
+     # Define the button click action
      generate_button.click(
          fn=generate_image,
          inputs=[prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
          outputs=result
      )

+ # Launch the Gradio app
  demo.launch()
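
For quick testing outside the Blocks UI, the new remote call can be exercised directly with gradio_client. This is a minimal sketch, assuming the Space's /infer endpoint accepts these keyword arguments and that it may return the generated image together with the seed it used (in which case the result needs unpacking before it is shown in gr.Image); the prompt string is only an example.

from gradio_client import Client

client = Client("black-forest-labs/FLUX.1-dev")
result = client.predict(
    prompt="A watercolor fox in a snowy forest",  # example prompt
    seed=0,
    randomize_seed=True,
    width=1024,
    height=1024,
    guidance_scale=3.5,
    num_inference_steps=28,
    api_name="/infer",
)
# If the endpoint returns a tuple such as (image, seed), unpack it and pass
# only the image to the UI, e.g. `image, used_seed = result`.
print(result)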