DonImages committed on
Commit
02a3a52
·
verified ·
1 Parent(s): dfab3a9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -8
app.py CHANGED
@@ -3,10 +3,10 @@ import torch
3
  import os
4
  from diffusers import StableDiffusion3Pipeline
5
  from safetensors.torch import load_file
6
- from spaces import GPU # Remove this line if NOT in a HF Space
7
 
8
  # 1. Define model ID and HF_TOKEN (at the VERY beginning)
9
- model_id = "stabilityai/stable-diffusion-3.5-large" # Correct model ID for SD 3.5 Large
10
  hf_token = os.getenv("HF_TOKEN") # For private models (set in HF Space settings)
11
 
12
  # 2. Initialize pipeline (to None initially)
@@ -14,12 +14,18 @@ pipeline = None
14
 
15
  # 3. Load Stable Diffusion and LoRA (before Gradio)
16
  try:
17
- pipeline = StableDiffusion3Pipeline.from_pretrained(
18
- model_id,
19
- use_auth_token=hf_token,
20
- torch_dtype=torch.float16, # Use float16 for memory efficiency
21
- cache_dir="./model_cache" # For caching
22
- )
 
 
 
 
 
 
23
 
24
  lora_filename = "lora_trained_model.safetensors" # EXACT filename of your LoRA
25
  lora_path = os.path.join("./", lora_filename)
 
3
  import os
4
  from diffusers import StableDiffusion3Pipeline
5
  from safetensors.torch import load_file
6
+ from spaces import GPU # Remove if not in HF Space
7
 
8
  # 1. Define model ID and HF_TOKEN (at the VERY beginning)
9
+ model_id = "stabilityai/stable-diffusion-3.5-large" # Or your preferred model ID
10
  hf_token = os.getenv("HF_TOKEN") # For private models (set in HF Space settings)
11
 
12
  # 2. Initialize pipeline (to None initially)
 
14
 
15
  # 3. Load Stable Diffusion and LoRA (before Gradio)
16
  try:
17
+ if hf_token: # check if the token exists, if not, then do not pass the token
18
+ pipeline = StableDiffusion3Pipeline.from_pretrained(
19
+ model_id,
20
+ torch_dtype=torch.float16,
21
+ cache_dir="./model_cache" # For caching
22
+ )
23
+ else:
24
+ pipeline = StableDiffusion3Pipeline.from_pretrained(
25
+ model_id,
26
+ torch_dtype=torch.float16,
27
+ cache_dir="./model_cache" # For caching
28
+ )
29
 
30
  lora_filename = "lora_trained_model.safetensors" # EXACT filename of your LoRA
31
  lora_path = os.path.join("./", lora_filename)