DonImages commited on
Commit
82198c8
·
verified ·
1 Parent(s): 03b91b0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -15
app.py CHANGED
@@ -4,18 +4,6 @@ import torch
4
  from diffusers import StableDiffusion3Pipeline
5
  import spaces
6
  import random
7
- from peft import PeftModel, get_peft_model
8
-
9
- # Ensure GPU allocation in Hugging Face Spaces
10
- @spaces.GPU(duration=65)
11
- def generate_image(prompt: str, seed: int = None):
12
- """Generates an image using Stable Diffusion 3.5 with LoRA fine-tuning."""
13
- if seed is None:
14
- seed = random.randint(0, 100000)
15
- generator = torch.manual_seed(seed)
16
-
17
- image = pipeline(prompt, generator=generator).images[0]
18
- return image
19
 
20
  # Device selection
21
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -27,22 +15,36 @@ token = os.getenv("HF_TOKEN")
27
  # Model ID for SD 3.5 Large
28
  model_repo_id = "stabilityai/stable-diffusion-3.5-large"
29
 
30
- # Load Stable Diffusion pipeline
31
  pipeline = StableDiffusion3Pipeline.from_pretrained(
32
  model_repo_id,
33
  torch_dtype=torch_dtype,
34
  use_safetensors=True, # Use safetensors format if supported
35
  ).to(device)
36
 
37
- # Load the LoRA trained weights
38
  lora_path = "lora_trained_model.pt" # Ensure this file is uploaded in the Space
39
  if os.path.exists(lora_path):
40
  lora_state_dict = torch.load(lora_path, map_location=device, weights_only=True)
41
- pipeline = PeftModel.from_pretrained(pipeline, lora_path)
42
  print("✅ LoRA weights loaded successfully!")
43
  else:
44
  print("⚠️ LoRA file not found! Running base model.")
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  # Gradio Interface
47
  with gr.Blocks() as demo:
48
  gr.Markdown("# 🖼️ LoRA Fine-Tuned SD 3.5 Image Generator")
 
4
  from diffusers import StableDiffusion3Pipeline
5
  import spaces
6
  import random
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  # Device selection
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
15
  # Model ID for SD 3.5 Large
16
  model_repo_id = "stabilityai/stable-diffusion-3.5-large"
17
 
18
+ # Load Stable Diffusion pipeline once at the start
19
  pipeline = StableDiffusion3Pipeline.from_pretrained(
20
  model_repo_id,
21
  torch_dtype=torch_dtype,
22
  use_safetensors=True, # Use safetensors format if supported
23
  ).to(device)
24
 
25
+ # Load the LoRA trained weights once at the start
26
  lora_path = "lora_trained_model.pt" # Ensure this file is uploaded in the Space
27
  if os.path.exists(lora_path):
28
  lora_state_dict = torch.load(lora_path, map_location=device, weights_only=True)
29
+ pipeline.load_lora_weights(lora_state_dict) # Load LoRA weights into the pipeline
30
  print("✅ LoRA weights loaded successfully!")
31
  else:
32
  print("⚠️ LoRA file not found! Running base model.")
33
 
34
+ # Ensure GPU allocation in Hugging Face Spaces
35
+ @spaces.GPU(duration=65)
36
+ def generate_image(prompt: str, seed: int = None):
37
+ """Generates an image using Stable Diffusion 3.5 with LoRA fine-tuning."""
38
+ if seed is None:
39
+ seed = random.randint(0, 100000)
40
+
41
+ # Create a generator with the seed
42
+ generator = torch.manual_seed(seed)
43
+
44
+ # Generate the image using the pipeline
45
+ image = pipeline(prompt, generator=generator).images[0]
46
+ return image
47
+
48
  # Gradio Interface
49
  with gr.Blocks() as demo:
50
  gr.Markdown("# 🖼️ LoRA Fine-Tuned SD 3.5 Image Generator")