DonImages committed
Commit c93b55a · verified · 1 Parent(s): 23d76b5

Update app.py

Files changed (1)
  1. app.py +50 -23
app.py CHANGED
@@ -1,49 +1,76 @@
 import gradio as gr
 import torch
 import os
-from diffusers import StableDiffusion3Pipeline
+import random
+import numpy as np
+from diffusers import DiffusionPipeline
 from safetensors.torch import load_file
 from spaces import GPU  # Remove if not in HF Space
 
-# 1. Define model ID and HF_TOKEN (at the VERY beginning)
-model_id = "stabilityai/stable-diffusion-3.5-large"  # Or your preferred model ID
-hf_token = os.getenv("HF_TOKEN")  # For private models (set in HF Space settings)
+# 1. Model and LoRA loading (before Gradio)
+device = "cuda" if torch.cuda.is_available() else "cpu"
+torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+token = os.getenv("HF_TOKEN")
+model_repo_id = "stabilityai/stable-diffusion-3.5-large"
 
-# 2. Initialize pipeline (to None initially)
-pipeline = None
-
-# 3. Load Stable Diffusion and LoRA (before Gradio)
 try:
-    if hf_token:  # Check whether the token exists; if not, do not pass it
-        pipeline = StableDiffusion3Pipeline.from_pretrained(
-            model_id,
-            torch_dtype=torch.float16,
-            cache_dir="./model_cache",  # For caching
-        )
-    else:
-        pipeline = StableDiffusion3Pipeline.from_pretrained(
-            model_id,
-            torch_dtype=torch.float16,
-            cache_dir="./model_cache",  # For caching
-        )
+    pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype, use_auth_token=token)  # No need to check whether the token exists; diffusers handles a None token
+    pipe = pipe.to(device)
 
     lora_filename = "lora_trained_model.safetensors"  # EXACT filename of your LoRA
     lora_path = os.path.join("./", lora_filename)
 
     if os.path.exists(lora_path):
         lora_weights = load_file(lora_path)
-        text_encoder = pipeline.text_encoder
+        text_encoder = pipe.text_encoder
         text_encoder.load_state_dict(lora_weights, strict=False)
         print(f"LoRA loaded successfully from: {lora_path}")
     else:
         print(f"Error: LoRA file not found at: {lora_path}")
         exit()  # Stop if the LoRA is not found
 
-    print("Stable Diffusion model loaded successfully!")
+    print("Stable Diffusion model and LoRA loaded successfully!")
 
 except Exception as e:
     print(f"Error loading model or LoRA: {e}")
-    exit()  # Stop if model loading fails
+    exit()
+
+
+MAX_SEED = np.iinfo(np.int32).max
+MAX_IMAGE_SIZE = 1024
+
+@GPU(duration=65)  # Only if in HF Space
+def infer(
+    prompt,
+    negative_prompt="",
+    seed=42,
+    randomize_seed=False,
+    width=1024,
+    height=1024,
+    guidance_scale=4.5,
+    num_inference_steps=40,
+    progress=gr.Progress(track_tqdm=True),
+):
+    if randomize_seed:
+        seed = random.randint(0, MAX_SEED)
+    generator = torch.Generator().manual_seed(seed)
+
+    try:
+        image = pipe(
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            guidance_scale=guidance_scale,
+            num_inference_steps=num_inference_steps,
+            width=width,
+            height=height,
+            generator=generator,
+        ).images[0]
+        return image, seed
+    except Exception as e:
+        print(f"Error during image generation: {e}")  # Print the error for debugging
+        return f"Error: {e}", seed  # Return the error to the Gradio interface
+
+# ... (rest of your Gradio code - examples, CSS, etc. - same as before)
 
 # 4. Image generation function (now decorated)
 @GPU(duration=65)  # Only if in HF Space
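
The hunk ends with `# ... (rest of your Gradio code)`, so the UI wiring is not part of this commit. Purely as a hypothetical sketch of how `infer`'s parameters and its `(image, seed)` return value could be connected to a Blocks layout (every component name below is invented, not taken from the repo):

# Hypothetical UI wiring for infer(); component names are illustrative only.
with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    negative_prompt = gr.Textbox(label="Negative prompt")
    seed = gr.Slider(0, MAX_SEED, value=42, step=1, label="Seed")
    randomize_seed = gr.Checkbox(value=False, label="Randomize seed")
    width = gr.Slider(256, MAX_IMAGE_SIZE, value=1024, step=64, label="Width")
    height = gr.Slider(256, MAX_IMAGE_SIZE, value=1024, step=64, label="Height")
    guidance_scale = gr.Slider(0.0, 10.0, value=4.5, step=0.1, label="Guidance scale")
    num_inference_steps = gr.Slider(1, 100, value=40, step=1, label="Steps")
    run_button = gr.Button("Generate")
    result = gr.Image(label="Result")
    used_seed = gr.Number(label="Seed used")

    # progress=gr.Progress(...) is injected by Gradio, so it is not an input here.
    run_button.click(
        infer,
        inputs=[prompt, negative_prompt, seed, randomize_seed,
                width, height, guidance_scale, num_inference_steps],
        outputs=[result, used_seed],
    )

demo.launch()

One caveat carried over from the commit: on failure `infer` returns the string `f"Error: {e}"` into what would be an image output, which Gradio will reject; raising `gr.Error(str(e))` is the conventional way to surface the message instead.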
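
A side note on the LoRA step, which the diff leaves unchanged: loading a safetensors state dict directly into `pipe.text_encoder` with `strict=False` silently drops every key that does not match the encoder, so a UNet-only or PEFT-format LoRA would print "loaded successfully" while changing nothing. If the file is a standard diffusers-format LoRA, the pipeline's built-in loader is the more robust route; a minimal sketch under that assumption:

# Sketch: using diffusers' own LoRA loader instead of load_state_dict.
# Assumes lora_trained_model.safetensors is a diffusers/PEFT-format LoRA;
# a raw text-encoder state dict would still need the manual path above.
pipe.load_lora_weights(".", weight_name="lora_trained_model.safetensors")
pipe.fuse_lora()  # optional: merge the LoRA weights for slightly faster inference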