Update app.py
Browse files
app.py
CHANGED
@@ -1,37 +1,44 @@
|
|
1 |
-
|
2 |
-
import
|
|
|
3 |
import os
|
4 |
-
|
5 |
-
|
6 |
-
#
|
7 |
-
|
8 |
-
|
9 |
-
#
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
#
|
24 |
-
|
25 |
-
|
26 |
-
@app.post("/modify-prompt")
|
27 |
-
async def modify_prompt(prompt: str):
|
28 |
-
global lora_weights
|
29 |
-
if lora_weights is None:
|
30 |
-
raise HTTPException(status_code=500, detail="LoRA weights not loaded.")
|
31 |
|
32 |
-
#
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from diffusers import StableDiffusion3Pipeline
|
3 |
+
from huggingface_hub import login
|
4 |
import os
|
5 |
+
import gradio as gr
|
6 |
+
|
7 |
+
# Retrieve the token from the environment variable
|
8 |
+
token = os.getenv("HF_TOKEN") # Hugging Face token from the secret
|
9 |
+
if token:
|
10 |
+
login(token=token) # Log in with the retrieved token
|
11 |
+
else:
|
12 |
+
raise ValueError("Hugging Face token not found. Please set it as a repository secret in the Space settings.")
|
13 |
+
|
14 |
+
# Load the Stable Diffusion 3.5 model
|
15 |
+
model_id = "stabilityai/stable-diffusion-3.5-large"
|
16 |
+
pipe = StableDiffusion3Pipeline.from_pretrained(model_id) # Removed torch_dtype argument
|
17 |
+
pipe.to("cpu") # Ensuring it runs on CPU
|
18 |
+
|
19 |
+
# Define the path to the LoRA model
|
20 |
+
lora_model_path = "./lora_model.pth" # Assuming the file is saved locally
|
21 |
+
|
22 |
+
# Custom method to load and apply LoRA weights to the Stable Diffusion pipeline
|
23 |
+
def load_lora_model(pipe, lora_model_path):
|
24 |
+
# Load the LoRA weights
|
25 |
+
lora_weights = torch.load(lora_model_path, map_location="cpu")
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
|
27 |
+
# Apply weights to the UNet submodule
|
28 |
+
for name, param in pipe.unet.named_parameters(): # Accessing unet parameters
|
29 |
+
if name in lora_weights:
|
30 |
+
param.data += lora_weights[name]
|
31 |
+
|
32 |
+
return pipe
|
33 |
+
|
34 |
+
# Load and apply the LoRA model weights
|
35 |
+
pipe = load_lora_model(pipe, lora_model_path)
|
36 |
+
|
37 |
+
# Function to generate an image from a text prompt
|
38 |
+
def generate_image(prompt):
|
39 |
+
image = pipe(prompt).images[0]
|
40 |
+
return image
|
41 |
+
|
42 |
+
# Gradio interface
|
43 |
+
iface = gr.Interface(fn=generate_image, inputs="text", outputs="image")
|
44 |
+
iface.launch()
|