DonImages committed on
Commit
20271b2
·
verified ·
1 Parent(s): 79d9e62

Update appCODE.py

Files changed (1)
  1. appCODE.py +60 -59
appCODE.py CHANGED
@@ -7,75 +7,76 @@ from diffusers import StableDiffusion3Pipeline
 from diffusers.loaders import SD3LoraLoaderMixin
 from safetensors.torch import load_file, save_file
 
-# Device selection
-device = "cuda" if torch.cuda.is_available() else "cpu"
-torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
-
-# Load Hugging Face token securely
-token = os.getenv("HF_TOKEN")
-
-# Model ID for SD 3.5 Large
-model_repo_id = "stabilityai/stable-diffusion-3.5-large"
-
-# Convert .pt to .safetensors if needed
-lora_pt_path = "lora_trained_model.pt"
-lora_safetensors_path = "lora_trained_model.safetensors"
-
-if os.path.exists(lora_pt_path) and not os.path.exists(lora_safetensors_path):
-    print("🔄 Converting LoRA .pt to .safetensors...")
-    lora_weights = torch.load(lora_pt_path, map_location="cpu")
-    save_file(lora_weights, lora_safetensors_path)
-    print(f"✅ LoRA saved as {lora_safetensors_path}")
-
-# Load Stable Diffusion pipeline
-pipeline = StableDiffusion3Pipeline.from_pretrained(
-    model_repo_id,
-    torch_dtype=torch_dtype,
-    use_safetensors=True,  # Use safetensors format if supported
-).to(device)
-
-# Load and fuse LoRA trained weights
-if os.path.exists(lora_safetensors_path):
-    try:
-        pipeline.load_lora_weights(".", weight_name="lora_trained_model.safetensors")  # Corrected loading method
-        pipeline.fuse_lora()  # Merges LoRA into the base model
-        print("✅ LoRA weights loaded and fused successfully!")
-    except Exception as e:
-        print(f"❌ Error loading LoRA: {e}")
-else:
-    print("⚠️ LoRA file not found! Running base model.")
-
-# Verify if LoRA is applied
-for name, param in pipeline.text_encoder.named_parameters():
-    if "lora" in name.lower():
-        print(f"✅ LoRA applied to: {name}, requires_grad={param.requires_grad}")
-
-# Ensure GPU allocation in Hugging Face Spaces
-@spaces.GPU(duration=65)
-def generate_image(prompt: str, seed: int = None):
-    """Generates an image using Stable Diffusion 3.5 with LoRA fine-tuning."""
-    if seed is None:
-        seed = random.randint(0, 100000)
-
-    # Create a generator with the seed
-    generator = torch.manual_seed(seed)
-
-    # Generate the image using the pipeline
-    image = pipeline(prompt, generator=generator).images[0]
-    return image
-
-# Gradio Interface
-with gr.Blocks() as demo:
-    gr.Markdown("# 🖼️ LoRA Fine-Tuned SD 3.5 Image Generator")
-
-    with gr.Row():
-        prompt_input = gr.Textbox(label="Enter Prompt", value="A woman in her 20s with expressive black eyes, graceful face, elegant body, standing on the beach at sunset. Photorealistic, highly detailed.")
-        seed_input = gr.Number(label="Seed (optional)", value=None)
-
-    generate_btn = gr.Button("Generate Image")
-    output_image = gr.Image(label="Generated Image")
-
-    generate_btn.click(generate_image, inputs=[prompt_input, seed_input], outputs=output_image)
-
-# Launch the Gradio app
-demo.launch()
+# Ensure GPU allocation for image generation (moved here)
+def main():
+    # Device selection
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    torch_dtype = torch.float16 if device == "cuda" else torch.float32
+
+    # Load Hugging Face token securely
+    token = os.getenv("HF_TOKEN")
+
+    # Model ID for SD 3.5 Large
+    model_repo_id = "stabilityai/stable-diffusion-3.5-large"
+
+    # Convert .pt to .safetensors if needed
+    lora_pt_path = "lora_trained_model.pt"
+    lora_safetensors_path = "lora_trained_model.safetensors"
+
+    if os.path.exists(lora_pt_path) and not os.path.exists(lora_safetensors_path):
+        print("🔄 Converting LoRA .pt to .safetensors...")
+        lora_weights = torch.load(lora_pt_path, map_location="cpu")
+        save_file(lora_weights, lora_safetensors_path)
+        print(f"✅ LoRA saved as {lora_safetensors_path}")
+
+    # Load Stable Diffusion 3.5 pipeline with optimized settings
+    pipeline = StableDiffusion3Pipeline.from_pretrained(
+        model_repo_id,
+        torch_dtype=torch_dtype,
+        use_safetensors=True,
+    ).to(device)
+
+    # Load and fuse LoRA weights (optimized method)
+    if os.path.exists(lora_safetensors_path):
+        try:
+            SD3LoraLoaderMixin.load_lora_weights(pipeline, lora_safetensors_path)
+            pipeline.fuse_lora()
+            print("✅ LoRA weights loaded and fused successfully!")
+        except Exception as e:
+            print(f"❌ Error loading LoRA: {e}")
+    else:
+        print("⚠️ LoRA file not found! Running base model.")
+
+    # Ensure LoRA is applied correctly
+    applied_lora = any("lora" in name.lower() for name, _ in pipeline.text_encoder.named_parameters())
+    print(f"✅ LoRA Applied: {applied_lora}")
+
+    # Image generation function with GPU decorator
+    @spaces.GPU(duration=65)
+    def generate_image(prompt: str, seed: int = None):
+        """Generates an image using Stable Diffusion 3.5 with LoRA fine-tuning."""
+        seed = seed or random.randint(0, 100000)
+        generator = torch.Generator(device).manual_seed(seed)
+        return pipeline(prompt, generator=generator).images[0]
+
+    # Gradio Interface
+    with gr.Blocks() as demo:
+        gr.Markdown("# 🖼️ LoRA Fine-Tuned SD 3.5 Image Generator")
+
+        with gr.Row():
+            prompt_input = gr.Textbox(
+                label="Enter Prompt",
+                value="A woman in her 20s with expressive black eyes, graceful face, elegant body, standing on the beach at sunset. Photorealistic, highly detailed."
+            )
+            seed_input = gr.Number(label="Seed (optional)", value=None)
+
+        generate_btn = gr.Button("Generate Image")
+        output_image = gr.Image(label="Generated Image")
+
+        generate_btn.click(generate_image, inputs=[prompt_input, seed_input], outputs=output_image)
+
+    # Launch Gradio app
+    demo.launch()
+
+if __name__ == "__main__":
+    main()
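
For reference, the .pt to .safetensors conversion that both versions rely on reduces to one standalone step. A minimal sketch, assuming lora_trained_model.pt holds a flat name-to-tensor state dict saved with torch.save (the filenames are the ones used in the diff):

import torch
from safetensors.torch import save_file

# Load the plain PyTorch checkpoint and re-serialize it as safetensors.
# Assumes the .pt file is a flat {parameter name: tensor} mapping.
state_dict = torch.load("lora_trained_model.pt", map_location="cpu")
save_file(state_dict, "lora_trained_model.safetensors")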
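
The seeding change can likewise be tried in isolation: the new code swaps the global torch.manual_seed call for a torch.Generator bound to the selected device, so the same seed reproduces the same image across calls. A minimal sketch, assuming the same device selection as the diff (the pipeline call itself is left commented out):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
# Device-bound generator: reusing the seed reproduces the same latents on this device.
generator = torch.Generator(device).manual_seed(42)
# image = pipeline(prompt, generator=generator).images[0]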