DonImages commited on
Commit
dfab3a9
·
verified ·
1 Parent(s): dc25832

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -24
app.py CHANGED
@@ -3,59 +3,56 @@ import torch
3
  import os
4
  from diffusers import StableDiffusion3Pipeline
5
  from safetensors.torch import load_file
6
- from spaces import GPU # Remove if not in HF Space
7
 
8
- # ... (HF_TOKEN, model_id - same as before)
 
 
9
 
10
- pipeline = None # Global pipeline variable
 
11
 
12
- # Load Stable Diffusion and LoRA *immediately* (before Gradio)
13
  try:
14
  pipeline = StableDiffusion3Pipeline.from_pretrained(
15
  model_id,
16
  use_auth_token=hf_token,
17
- torch_dtype=torch.float16,
18
- cache_dir="./model_cache"
19
  )
20
- pipeline.enable_model_cpu_offload()
21
- pipeline.enable_attention_slicing()
22
 
23
- lora_filename = "lora_trained_model.safetensors"
24
  lora_path = os.path.join("./", lora_filename)
25
- print(f"Loading LoRA from: {lora_path}")
26
 
27
  if os.path.exists(lora_path):
28
  lora_weights = load_file(lora_path)
29
  text_encoder = pipeline.text_encoder
30
  text_encoder.load_state_dict(lora_weights, strict=False)
31
- print("LoRA loaded successfully!") # Confirmation message
32
  else:
33
- print(f"Error: LoRA file not found at {lora_path}")
34
- exit() # Exit if LoRA is not found
35
-
36
 
37
  print("Stable Diffusion model loaded successfully!")
38
 
39
  except Exception as e:
40
- print(f"Error loading Stable Diffusion or LoRA: {e}")
41
- exit() # Exit if there's an error
42
-
43
 
44
- # Function for image generation (now much simpler)
45
- @GPU(duration=65) # Use GPU decorator (ONLY if in HF Space)
46
  def generate_image(prompt):
47
  global pipeline
48
- if pipeline is None: # This should never happen now
49
- return "Error: Stable Diffusion model not loaded!"
50
 
51
  try:
52
- image = pipeline(prompt).images[0]
53
  return image
54
  except Exception as e:
55
  return f"Error generating image: {e}"
56
 
57
-
58
- # Create the Gradio interface (no "Load Model" button needed)
59
  with gr.Blocks() as demo:
60
  prompt_input = gr.Textbox(label="Prompt")
61
  image_output = gr.Image(label="Generated Image")
 
3
  import os
4
  from diffusers import StableDiffusion3Pipeline
5
  from safetensors.torch import load_file
6
+ from spaces import GPU # Remove this line if NOT in a HF Space
7
 
8
+ # 1. Define model ID and HF_TOKEN (at the VERY beginning)
9
+ model_id = "stabilityai/stable-diffusion-3.5-large" # Correct model ID for SD 3.5 Large
10
+ hf_token = os.getenv("HF_TOKEN") # For private models (set in HF Space settings)
11
 
12
+ # 2. Initialize pipeline (to None initially)
13
+ pipeline = None
14
 
15
+ # 3. Load Stable Diffusion and LoRA (before Gradio)
16
  try:
17
  pipeline = StableDiffusion3Pipeline.from_pretrained(
18
  model_id,
19
  use_auth_token=hf_token,
20
+ torch_dtype=torch.float16, # Use float16 for memory efficiency
21
+ cache_dir="./model_cache" # For caching
22
  )
 
 
23
 
24
+ lora_filename = "lora_trained_model.safetensors" # EXACT filename of your LoRA
25
  lora_path = os.path.join("./", lora_filename)
 
26
 
27
  if os.path.exists(lora_path):
28
  lora_weights = load_file(lora_path)
29
  text_encoder = pipeline.text_encoder
30
  text_encoder.load_state_dict(lora_weights, strict=False)
31
+ print(f"LoRA loaded successfully from: {lora_path}")
32
  else:
33
+ print(f"Error: LoRA file not found at: {lora_path}")
34
+ exit() # Stop if LoRA is not found
 
35
 
36
  print("Stable Diffusion model loaded successfully!")
37
 
38
  except Exception as e:
39
+ print(f"Error loading model or LoRA: {e}")
40
+ exit() # Stop if model loading fails
 
41
 
42
+ # 4. Image generation function (now decorated)
43
+ @GPU(duration=65) # ONLY if in a HF Space, remove if not
44
  def generate_image(prompt):
45
  global pipeline
46
+ if pipeline is None: # Should not happen, but good to check
47
+ return "Error: Model not loaded!"
48
 
49
  try:
50
+ image = pipeline(prompt).images[0] # Access the first image from the list
51
  return image
52
  except Exception as e:
53
  return f"Error generating image: {e}"
54
 
55
+ # 5. Gradio interface
 
56
  with gr.Blocks() as demo:
57
  prompt_input = gr.Textbox(label="Prompt")
58
  image_output = gr.Image(label="Generated Image")