DonImages commited on
Commit
dc25832
·
verified ·
1 Parent(s): 4a6aac0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -42
app.py CHANGED
@@ -3,55 +3,50 @@ import torch
3
  import os
4
  from diffusers import StableDiffusion3Pipeline
5
  from safetensors.torch import load_file
6
- from spaces import GPU # Import GPU if in HF Space, otherwise remove this line
7
 
8
- # Access HF_TOKEN from environment variables
9
- hf_token = os.getenv("HF_TOKEN")
10
 
11
- # Specify the pre-trained model ID
12
- model_id = "stabilityai/stable-diffusion-3.5-large"
13
-
14
- # Initialize pipeline *outside* the function (but set to None initially)
15
- pipeline = None
16
-
17
-
18
- # Function to load the Stable Diffusion pipeline (called only ONCE)
19
- def load_pipeline():
20
- global pipeline
21
- try:
22
- pipeline = StableDiffusion3Pipeline.from_pretrained(
23
- model_id,
24
- use_auth_token=hf_token,
25
- torch_dtype=torch.float16,
26
- cache_dir="./model_cache"
27
- )
28
- except Exception as e:
29
- print(f"Error loading model: {e}")
30
- return f"Error loading model: {e}"
31
 
 
 
 
 
 
 
 
 
32
  pipeline.enable_model_cpu_offload()
33
  pipeline.enable_attention_slicing()
34
- return "Model loaded successfully"
35
-
36
- @GPU(duration=65) # Use GPU decorator (ONLY if in HF Space)
37
- def generate_image(prompt):
38
- global pipeline
39
- if pipeline is None:
40
- return "Model not loaded. Please wait."
41
 
42
  lora_filename = "lora_trained_model.safetensors"
43
  lora_path = os.path.join("./", lora_filename)
44
  print(f"Loading LoRA from: {lora_path}")
45
 
46
- try:
47
- if os.path.exists(lora_path):
48
- lora_weights = load_file(lora_path)
49
- text_encoder = pipeline.text_encoder
50
- text_encoder.load_state_dict(lora_weights, strict=False)
51
- else:
52
- return f"Error: LoRA file not found at {lora_path}"
53
- except Exception as e:
54
- return f"Error loading LoRA: {e}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
  try:
57
  image = pipeline(prompt).images[0]
@@ -60,13 +55,11 @@ def generate_image(prompt):
60
  return f"Error generating image: {e}"
61
 
62
 
 
63
  with gr.Blocks() as demo:
64
  prompt_input = gr.Textbox(label="Prompt")
65
  image_output = gr.Image(label="Generated Image")
66
  generate_button = gr.Button("Generate")
67
- load_model_button = gr.Button("Load Model")
68
-
69
- load_model_button.click(fn=load_pipeline, outputs=load_model_button)
70
 
71
  generate_button.click(
72
  fn=generate_image,
 
3
  import os
4
  from diffusers import StableDiffusion3Pipeline
5
  from safetensors.torch import load_file
6
+ from spaces import GPU # Remove if not in HF Space
7
 
8
+ # ... (HF_TOKEN, model_id - same as before)
 
9
 
10
+ pipeline = None # Global pipeline variable
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
+ # Load Stable Diffusion and LoRA *immediately* (before Gradio)
13
+ try:
14
+ pipeline = StableDiffusion3Pipeline.from_pretrained(
15
+ model_id,
16
+ use_auth_token=hf_token,
17
+ torch_dtype=torch.float16,
18
+ cache_dir="./model_cache"
19
+ )
20
  pipeline.enable_model_cpu_offload()
21
  pipeline.enable_attention_slicing()
 
 
 
 
 
 
 
22
 
23
  lora_filename = "lora_trained_model.safetensors"
24
  lora_path = os.path.join("./", lora_filename)
25
  print(f"Loading LoRA from: {lora_path}")
26
 
27
+ if os.path.exists(lora_path):
28
+ lora_weights = load_file(lora_path)
29
+ text_encoder = pipeline.text_encoder
30
+ text_encoder.load_state_dict(lora_weights, strict=False)
31
+ print("LoRA loaded successfully!") # Confirmation message
32
+ else:
33
+ print(f"Error: LoRA file not found at {lora_path}")
34
+ exit() # Exit if LoRA is not found
35
+
36
+
37
+ print("Stable Diffusion model loaded successfully!")
38
+
39
+ except Exception as e:
40
+ print(f"Error loading Stable Diffusion or LoRA: {e}")
41
+ exit() # Exit if there's an error
42
+
43
+
44
+ # Function for image generation (now much simpler)
45
+ @GPU(duration=65) # Use GPU decorator (ONLY if in HF Space)
46
+ def generate_image(prompt):
47
+ global pipeline
48
+ if pipeline is None: # This should never happen now
49
+ return "Error: Stable Diffusion model not loaded!"
50
 
51
  try:
52
  image = pipeline(prompt).images[0]
 
55
  return f"Error generating image: {e}"
56
 
57
 
58
+ # Create the Gradio interface (no "Load Model" button needed)
59
  with gr.Blocks() as demo:
60
  prompt_input = gr.Textbox(label="Prompt")
61
  image_output = gr.Image(label="Generated Image")
62
  generate_button = gr.Button("Generate")
 
 
 
63
 
64
  generate_button.click(
65
  fn=generate_image,