DonImages commited on
Commit
509d782
·
verified ·
1 Parent(s): 30d192f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -29
app.py CHANGED
@@ -3,7 +3,7 @@ import torch
3
  import os
4
  from diffusers import StableDiffusion3Pipeline
5
  from safetensors.torch import load_file
6
- from spaces import GPU
7
 
8
  # Access HF_TOKEN from environment variables
9
  hf_token = os.getenv("HF_TOKEN")
@@ -11,31 +11,34 @@ hf_token = os.getenv("HF_TOKEN")
11
  # Specify the pre-trained model ID
12
  model_id = "stabilityai/stable-diffusion-3.5-large"
13
 
14
- # Lazy pipeline initialization
15
  pipeline = None
16
 
17
- # Function for image generation
18
- @gr.GPU(duration=65)
19
- def generate_image(prompt): # Remove lora_file input
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  global pipeline
21
- if pipeline is None:
22
- try:
23
- pipeline = StableDiffusion3Pipeline.from_pretrained(
24
- model_id,
25
- use_auth_token=hf_token,
26
- torch_dtype=torch.float16,
27
- cache_dir="./model_cache"
28
- )
29
- except Exception as e:
30
- print(f"Error loading from cache: {e}")
31
- pipeline = StableDiffusion3Pipeline.from_pretrained(
32
- model_id,
33
- use_auth_token=hf_token,
34
- torch_dtype=torch.float16,
35
- local_files_only=False
36
- )
37
- pipeline.enable_model_cpu_offload()
38
- pipeline.enable_attention_slicing()
39
 
40
  # Load and apply LoRA (file is already in the Space)
41
  lora_filename = "lora_trained_model.safetensors" # Name of your LoRA file
@@ -47,10 +50,8 @@ def generate_image(prompt): # Remove lora_file input
47
  lora_weights = load_file(lora_path)
48
  text_encoder = pipeline.text_encoder
49
  text_encoder.load_state_dict(lora_weights, strict=False)
50
- else:
51
- return f"Error: LoRA file not found at {lora_path}"
52
- except Exception as e:
53
- return f"Error loading LoRA: {e}"
54
 
55
  try:
56
  image = pipeline(prompt).images[0]
@@ -59,15 +60,18 @@ def generate_image(prompt): # Remove lora_file input
59
  return f"Error generating image: {e}"
60
 
61
 
62
- # Create the Gradio interface (remove lora_upload)
63
  with gr.Blocks() as demo:
64
  prompt_input = gr.Textbox(label="Prompt")
65
  image_output = gr.Image(label="Generated Image")
66
  generate_button = gr.Button("Generate")
 
 
 
67
 
68
  generate_button.click(
69
  fn=generate_image,
70
- inputs=prompt_input, # Only prompt input now
71
  outputs=image_output,
72
  )
73
 
 
3
  import os
4
  from diffusers import StableDiffusion3Pipeline
5
  from safetensors.torch import load_file
6
+ from spaces import GPU # Import GPU if in HF Space, otherwise remove this line
7
 
8
  # Access HF_TOKEN from environment variables
9
  hf_token = os.getenv("HF_TOKEN")
 
11
  # Specify the pre-trained model ID
12
  model_id = "stabilityai/stable-diffusion-3.5-large"
13
 
14
+ # Initialize pipeline *outside* the function (but set to None initially)
15
  pipeline = None
16
 
17
+
18
+ # Function to load the Stable Diffusion pipeline (called only ONCE)
19
+ def load_pipeline():
20
+ global pipeline # Use the global keyword to modify the global variable
21
+ try:
22
+ pipeline = StableDiffusion3Pipeline.from_pretrained(
23
+ model_id,
24
+ use_auth_token=hf_token,
25
+ torch_dtype=torch.float16,
26
+ cache_dir="./model_cache"
27
+ )
28
+ except Exception as e:
29
+ print(f"Error loading model: {e}")
30
+ return f"Error loading model: {e}" # Return error message
31
+
32
+ pipeline.enable_model_cpu_offload()
33
+ pipeline.enable_attention_slicing()
34
+ return "Model loaded successfully" # Return success message
35
+
36
+ # Function for image generation (now decorated)
37
+ @GPU(duration=65) # Use GPU decorator (ONLY if in HF Space)
38
+ def generate_image(prompt):
39
  global pipeline
40
+ if pipeline is None: # Check if pipeline is loaded
41
+ return "Model not loaded. Please wait." # Return message if not loaded
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
  # Load and apply LoRA (file is already in the Space)
44
  lora_filename = "lora_trained_model.safetensors" # Name of your LoRA file
 
50
  lora_weights = load_file(lora_path)
51
  text_encoder = pipeline.text_encoder
52
  text_encoder.load_state_dict(lora_weights, strict=False)
53
+ except Exception as e:
54
+ return f"Error loading LoRA: {e}"
 
 
55
 
56
  try:
57
  image = pipeline(prompt).images[0]
 
60
  return f"Error generating image: {e}"
61
 
62
 
63
+ # Create the Gradio interface
64
  with gr.Blocks() as demo:
65
  prompt_input = gr.Textbox(label="Prompt")
66
  image_output = gr.Image(label="Generated Image")
67
  generate_button = gr.Button("Generate")
68
+ load_model_button = gr.Button("Load Model") # Button to load model
69
+
70
+ load_model_button.click(fn=load_pipeline, outputs=load_model_button) # Call load_pipeline
71
 
72
  generate_button.click(
73
  fn=generate_image,
74
+ inputs=prompt_input,
75
  outputs=image_output,
76
  )
77