Update app.py
Browse files
app.py
CHANGED
@@ -4,42 +4,55 @@ from huggingface_hub import login
|
|
4 |
import os
|
5 |
import gradio as gr
|
6 |
|
7 |
-
# Retrieve
|
8 |
-
token = os.getenv("HF_TOKEN")
|
9 |
if token:
|
10 |
-
login(token=token)
|
11 |
else:
|
12 |
-
raise ValueError("Hugging Face token not found. Please set it as a repository secret
|
13 |
|
14 |
# Load the Stable Diffusion 3.5 model
|
15 |
model_id = "stabilityai/stable-diffusion-3.5-large"
|
16 |
-
pipe = StableDiffusion3Pipeline.from_pretrained(
|
17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
# Define the path to the LoRA model
|
20 |
-
lora_model_path = "
|
21 |
|
22 |
-
#
|
23 |
def load_lora_model(pipe, lora_model_path):
|
24 |
-
# Load the LoRA weights (assuming it's a PyTorch .pth file)
|
25 |
lora_weights = torch.load(lora_model_path, map_location="cuda")
|
|
|
|
|
26 |
|
27 |
-
# Modify this section based on how LoRA is intended to interact with your Stable Diffusion model
|
28 |
-
# Here, we just load the weights into the model's parameters (this is a conceptual approach)
|
29 |
-
for name, param in pipe.named_parameters():
|
30 |
-
if name in lora_weights:
|
31 |
-
param.data += lora_weights[name] # Apply LoRA weights to the parameters
|
32 |
-
|
33 |
-
return pipe # Return the updated model
|
34 |
-
|
35 |
-
# Load and apply the LoRA model weights
|
36 |
pipe = load_lora_model(pipe, lora_model_path)
|
37 |
|
38 |
-
#
|
39 |
-
def generate_image(prompt):
|
40 |
-
|
|
|
41 |
return image
|
42 |
|
43 |
# Gradio interface
|
44 |
-
iface = gr.Interface(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
iface.launch()
|
|
|
4 |
import os
|
5 |
import gradio as gr
|
6 |
|
7 |
+
# Retrieve Hugging Face token
|
8 |
+
token = os.getenv("HF_TOKEN")
|
9 |
if token:
|
10 |
+
login(token=token)
|
11 |
else:
|
12 |
+
raise ValueError("Hugging Face token not found. Please set it as a repository secret.")
|
13 |
|
14 |
# Load the Stable Diffusion 3.5 model
|
15 |
model_id = "stabilityai/stable-diffusion-3.5-large"
|
16 |
+
pipe = StableDiffusion3Pipeline.from_pretrained(
|
17 |
+
model_id,
|
18 |
+
torch_dtype=torch.float16,
|
19 |
+
revision="fp16",
|
20 |
+
low_cpu_mem_usage=True,
|
21 |
+
device_map="auto" # Automatically allocate model components to available devices
|
22 |
+
)
|
23 |
+
|
24 |
+
# Enable attention slicing for reduced memory usage
|
25 |
+
pipe.enable_attention_slicing()
|
26 |
|
27 |
# Define the path to the LoRA model
|
28 |
+
lora_model_path = "./lora_model.pth"
|
29 |
|
30 |
+
# Load and apply the LoRA weights
|
31 |
def load_lora_model(pipe, lora_model_path):
|
|
|
32 |
lora_weights = torch.load(lora_model_path, map_location="cuda")
|
33 |
+
pipe.unet.load_attn_procs(lora_weights)
|
34 |
+
return pipe
|
35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
pipe = load_lora_model(pipe, lora_model_path)
|
37 |
|
38 |
+
# Generate image function
|
39 |
+
def generate_image(prompt, steps, scale):
|
40 |
+
with torch.inference_mode(): # Avoid gradient computation for inference
|
41 |
+
image = pipe(prompt, num_inference_steps=steps, guidance_scale=scale).images[0]
|
42 |
return image
|
43 |
|
44 |
# Gradio interface
|
45 |
+
iface = gr.Interface(
|
46 |
+
fn=generate_image,
|
47 |
+
inputs=[
|
48 |
+
gr.Textbox(label="Enter your prompt"),
|
49 |
+
gr.Slider(10, 50, step=1, value=30, label="Number of Inference Steps"),
|
50 |
+
gr.Slider(1.0, 20.0, step=0.5, value=7.5, label="Guidance Scale"),
|
51 |
+
],
|
52 |
+
outputs="image",
|
53 |
+
title="Optimized Stable Diffusion with LoRA",
|
54 |
+
description="Generate images using Stable Diffusion 3.5 with optimized memory usage."
|
55 |
+
)
|
56 |
+
|
57 |
+
# Launch the Gradio interface
|
58 |
iface.launch()
|