DonImages commited on
Commit
6085215
·
verified ·
1 Parent(s): e86cc65

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -35
app.py CHANGED
@@ -1,37 +1,44 @@
1
- from fastapi import FastAPI, HTTPException
2
- import base64
 
3
  import os
4
- from contextlib import asynccontextmanager
5
-
6
- # Global variable to hold the LoRA weights
7
- lora_weights = None
8
-
9
- # Lifespan context manager to handle the application startup
10
- @asynccontextmanager
11
- async def lifespan(app: FastAPI):
12
- global lora_weights
13
- lora_path = "./lora_file.pth" # Ensure the correct file name
14
- if os.path.exists(lora_path):
15
- with open(lora_path, "rb") as f:
16
- lora_weights = base64.b64encode(f.read()).decode("utf-8")
17
- print("LoRA weights loaded and preprocessed successfully.")
18
- else:
19
- raise HTTPException(status_code=500, detail="LoRA file not found.")
20
- yield
21
- # Cleanup if necessary (but not required in this case)
22
-
23
- # Initialize FastAPI app with lifespan context
24
- app = FastAPI(lifespan=lifespan)
25
-
26
- @app.post("/modify-prompt")
27
- async def modify_prompt(prompt: str):
28
- global lora_weights
29
- if lora_weights is None:
30
- raise HTTPException(status_code=500, detail="LoRA weights not loaded.")
31
 
32
- # Combine prompt with preprocessed LoRA data
33
- extended_prompt = {
34
- "prompt": prompt,
35
- "lora": lora_weights
36
- }
37
- return extended_prompt
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from diffusers import StableDiffusion3Pipeline
3
+ from huggingface_hub import login
4
  import os
5
+ import gradio as gr
6
+
7
+ # Retrieve the token from the environment variable
8
+ token = os.getenv("HF_TOKEN") # Hugging Face token from the secret
9
+ if token:
10
+ login(token=token) # Log in with the retrieved token
11
+ else:
12
+ raise ValueError("Hugging Face token not found. Please set it as a repository secret in the Space settings.")
13
+
14
+ # Load the Stable Diffusion 3.5 model
15
+ model_id = "stabilityai/stable-diffusion-3.5-large"
16
+ pipe = StableDiffusion3Pipeline.from_pretrained(model_id) # Removed torch_dtype argument
17
+ pipe.to("cpu") # Ensuring it runs on CPU
18
+
19
+ # Define the path to the LoRA model
20
+ lora_model_path = "./lora_model.pth" # Assuming the file is saved locally
21
+
22
+ # Custom method to load and apply LoRA weights to the Stable Diffusion pipeline
23
+ def load_lora_model(pipe, lora_model_path):
24
+ # Load the LoRA weights
25
+ lora_weights = torch.load(lora_model_path, map_location="cpu")
 
 
 
 
 
 
26
 
27
+ # Apply weights to the UNet submodule
28
+ for name, param in pipe.unet.named_parameters(): # Accessing unet parameters
29
+ if name in lora_weights:
30
+ param.data += lora_weights[name]
31
+
32
+ return pipe
33
+
34
+ # Load and apply the LoRA model weights
35
+ pipe = load_lora_model(pipe, lora_model_path)
36
+
37
+ # Function to generate an image from a text prompt
38
+ def generate_image(prompt):
39
+ image = pipe(prompt).images[0]
40
+ return image
41
+
42
+ # Gradio interface
43
+ iface = gr.Interface(fn=generate_image, inputs="text", outputs="image")
44
+ iface.launch()