My-AI-Projects commited on
Commit
04b10bd
Β·
verified Β·
1 Parent(s): d415aed

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -9
app.py CHANGED
@@ -1,22 +1,38 @@
1
  import gradio as gr
2
  import torch
3
- from diffusers import FluxPipeline
4
  from PIL import Image
5
 
6
- # Load the diffusion model
7
- pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
8
- pipeline.enable_model_cpu_offload() #save some VRAM by offloading the model to CPU. Remove this if you have enough GPU power
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
 
10
 
11
- # Set the model to the appropriate device
12
- device = "cuda" if torch.cuda.is_available() else "cpu"
13
- pipeline.to(device)
14
 
15
  def generate_image(prompt, guidance_scale=7.5, num_inference_steps=50):
16
  # Generate an image based on the prompt
17
  with torch.no_grad():
18
- # Generate images
19
- images = pipeline(prompt, guidance_scale=guidance_scale, num_inference_steps=num_inference_steps).images
 
 
 
20
 
21
  # Assuming pipeline returns a list of images, just take the first one
22
  img = images[0]
 
1
  import gradio as gr
2
  import torch
3
+ from diffusers import DiffusionPipeline # Note: Change `FluxPipeline` to `DiffusionPipeline` if `FluxPipeline` is not correct
4
  from PIL import Image
5
 
6
+ # Function to determine the device and handle model loading
7
+ def setup_pipeline():
8
+ # Check for CUDA availability
9
+ device = "cuda" if torch.cuda.is_available() else "cpu"
10
+
11
+ # Load the diffusion model
12
+ try:
13
+ pipeline = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
14
+ if device == "cpu":
15
+ # If using CPU, ensure model is offloaded to avoid GPU-specific features
16
+ pipeline.enable_model_cpu_offload()
17
+ else:
18
+ # Move model to GPU
19
+ pipeline.to(device)
20
+ except Exception as e:
21
+ print(f"Error loading model: {e}")
22
+ raise e
23
 
24
+ return pipeline, device
25
 
26
+ pipeline, device = setup_pipeline()
 
 
27
 
28
  def generate_image(prompt, guidance_scale=7.5, num_inference_steps=50):
29
  # Generate an image based on the prompt
30
  with torch.no_grad():
31
+ try:
32
+ images = pipeline(prompt, guidance_scale=guidance_scale, num_inference_steps=num_inference_steps).images
33
+ except Exception as e:
34
+ print(f"Error generating image: {e}")
35
+ raise e
36
 
37
  # Assuming pipeline returns a list of images, just take the first one
38
  img = images[0]