AsmaAILab committed on
Commit 94bc625 · verified · 1 Parent(s): 88908fa

add app file to run model

Files changed (1):
  1. app.py +214 -0
app.py ADDED
@@ -0,0 +1,214 @@
+ import os
+ import torch
+ from PIL import Image
+ import numpy as np
+ import random
+ import gradio as gr
+ from gradio.themes import Soft
+
+ from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
+ from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
+ from transformers import AutoTokenizer, CLIPTextModel, CLIPFeatureExtractor
+ from transformers import DPTForDepthEstimation, DPTImageProcessor
+
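+ # Paths and runtime configuration. "controlnet" is a local directory holding
+ # the fine-tuned ControlNet weights; fp16 is used only when CUDA is available.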
+
+ stable_diffusion_base = "runwayml/stable-diffusion-v1-5"
+
+ finetune_controlnet_path = "controlnet"
+
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
+
+ pipeline = None
+ depth_estimator_model = None
+ depth_estimator_processor = None
+
+
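+ # Lazily build the DPT depth estimator once and cache it in module globals,
+ # so repeated Gradio calls don't reload the weights.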
+ def load_depth_estimator():
+     global depth_estimator_model, depth_estimator_processor
+     if depth_estimator_model is None:
+         model_name = "Intel/dpt-hybrid-midas"
+         depth_estimator_model = DPTForDepthEstimation.from_pretrained(model_name)
+         depth_estimator_processor = DPTImageProcessor.from_pretrained(model_name)
+         depth_estimator_model.to(DEVICE)
+         depth_estimator_model.eval()
+
+     return depth_estimator_model, depth_estimator_processor
+
+
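+ # Assemble the SD 1.5 + ControlNet pipeline from individually loaded components.
+ # Cached in a module global; raises RuntimeError if anything fails to load.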
+ def load_diffusion_pipeline():
+     global pipeline
+     if pipeline is None:
+         try:
+             if not os.path.exists(finetune_controlnet_path):
+                 raise FileNotFoundError(f"ControlNet model not found: {finetune_controlnet_path}")
+
+             # 1. Load individual components of the base Stable Diffusion pipeline from Hugging Face Hub
+             vae = AutoencoderKL.from_pretrained(stable_diffusion_base, subfolder="vae", torch_dtype=DTYPE)
+             tokenizer = AutoTokenizer.from_pretrained(stable_diffusion_base, subfolder="tokenizer")
+             text_encoder = CLIPTextModel.from_pretrained(stable_diffusion_base, subfolder="text_encoder", torch_dtype=DTYPE)
+             unet = UNet2DConditionModel.from_pretrained(stable_diffusion_base, subfolder="unet", torch_dtype=DTYPE)
+             scheduler = DDPMScheduler.from_pretrained(stable_diffusion_base, subfolder="scheduler")
+             feature_extractor = CLIPFeatureExtractor.from_pretrained(stable_diffusion_base, subfolder="feature_extractor")
+
+             controlnet = ControlNetModel.from_pretrained(finetune_controlnet_path, torch_dtype=DTYPE)
+             pipeline = StableDiffusionControlNetPipeline(
+                 vae=vae,
+                 text_encoder=text_encoder,
+                 tokenizer=tokenizer,
+                 unet=unet,
+                 controlnet=controlnet,  # Your fine-tuned ControlNet
+                 scheduler=scheduler,
+                 safety_checker=None,
+                 feature_extractor=feature_extractor,
+                 image_encoder=None,  # Explicitly set to None as it's not part of this setup
+                 requires_safety_checker=False,
+             )
+
+             pipeline.to(DEVICE)
+             if torch.cuda.is_available() and hasattr(pipeline, "enable_xformers_memory_efficient_attention"):
+                 try:
+                     pipeline.enable_xformers_memory_efficient_attention()
+                     print("xformers memory efficient attention enabled.")
+                 except Exception as e:
+                     print(f"Could not enable xformers: {e}")
+
+             load_depth_estimator()
+
+         except Exception as e:
+             print(f"Error loading pipeline: {e}")
+             pipeline = None
+             raise RuntimeError(f"Failed to load diffusion pipeline: {e}")
+     return pipeline
+
+
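+ # Run monocular depth estimation on the input photo and return an inverted,
+ # normalized depth map as an RGB PIL image for ControlNet conditioning.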
+ def estimate_depth(pil_image: Image.Image) -> Image.Image:
+     global depth_estimator_model, depth_estimator_processor
+     if depth_estimator_model is None or depth_estimator_processor is None:
+         try:
+             load_depth_estimator()
+         except RuntimeError as e:
+             raise RuntimeError(f"Depth estimator not loaded: {e}")
+
+     inputs = depth_estimator_processor(pil_image, return_tensors="pt")
+     inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
+
+     with torch.no_grad():
+         output = depth_estimator_model(**inputs)
+         predicted_depth = output.predicted_depth
+
+     depth_numpy = predicted_depth.squeeze().cpu().numpy()
+
+     # Normalize to [0, 1]; the epsilon guards against a constant depth map.
+     min_depth = depth_numpy.min()
+     max_depth = depth_numpy.max()
+     normalized_depth = (depth_numpy - min_depth) / (max_depth - min_depth + 1e-8)
+
+     # DPT predicts relative inverse depth (nearer points get larger values);
+     # invert so the encoding matches the convention the ControlNet was trained on.
+     inverted_normalized_depth = 1 - normalized_depth
+
+     depth_image_array = (inverted_normalized_depth * 255).astype(np.uint8)
+     depth_pil_image = Image.fromarray(depth_image_array).convert("RGB")
+
+     print("Depth estimation complete.")
+     return depth_pil_image
+
+
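+ # Gradio callback: estimate a depth map from the uploaded image, then run the
+ # ControlNet pipeline with it as the conditioning image.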
+ def generate_image_for_gradio(
+     prompt: str,
+     input_image_for_depth: Image.Image,
+     num_inference_steps: int = 25,
+     guidance_scale: float = 8.0,
+     seed: int = None,
+     resolution: int = 512,
+ ) -> Image.Image:
+     global pipeline
+     if pipeline is None:
+         try:
+             load_diffusion_pipeline()
+         except RuntimeError as e:
+             raise gr.Error(f"Model not loaded: {e}")
+
+     try:
+         depth_map_pil = estimate_depth(input_image_for_depth)
+     except Exception as e:
+         raise gr.Error(f"Error during depth estimation: {e}")
+
+     print(f"Generating image for prompt: '{prompt}'")
+
+     # gr.Number delivers floats, so coerce before PIL and the RNG see them.
+     resolution = int(resolution)
+     control_image = depth_map_pil.convert("RGB")
+     control_image = control_image.resize((resolution, resolution), Image.LANCZOS)
+
+     input_image_for_pipeline = [control_image]
+
+     if seed is None:
+         seed = random.randint(0, 100000)
+     seed = int(seed)
+     generator = torch.Generator(device=DEVICE).manual_seed(seed)
+
+     with torch.no_grad():
+         generated_images = pipeline(
+             prompt,
+             image=input_image_for_pipeline,
+             num_inference_steps=num_inference_steps,
+             guidance_scale=guidance_scale,
+             generator=generator,
+         ).images
+
+     print(f"Image generation complete (seed: {seed}).")
+     return generated_images[0]
+
+
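+ # Gradio UI wiring; the custom CSS keeps the upload icon and placeholder from
+ # stretching to fill the image input area.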
+ iface = gr.Interface(
+     fn=generate_image_for_gradio,
+     inputs=[
+         gr.Textbox(label="Prompt", value="a high-quality photo of a modern interior design"),
+         gr.Image(type="pil", label="Input Image (for Depth Estimation)"),
+         gr.Slider(minimum=10, maximum=100, value=25, step=1, label="Inference Steps"),
+         gr.Slider(minimum=1.0, maximum=20.0, value=8.0, step=0.5, label="Guidance Scale"),
+         gr.Number(label="Seed (optional, leave blank for random)", value=None),
+         gr.Number(label="Resolution", value=512, interactive=False),
+     ],
+     outputs=gr.Image(type="pil", label="Generated Image"),
+     title="Stable Diffusion ControlNet Depth Demo (with Depth Estimation)",
+     description="Upload an input image, and the app will estimate its depth map, then use it with your prompt to generate a new image. This allows for structural guidance from your input photo.",
+     allow_flagging="never",
+     live=False,
+     theme=Soft(),
+     css="""
+     /* Target the upload icon within the Image component */
+     .gr-image .icon-lg {
+         font-size: 2em !important; /* Adjust size as needed, e.g., 2em, 3em */
+         max-width: 50px; /* Max width to prevent it from filling the container */
+         max-height: 50px; /* Max height */
+     }
+     /* Target the image placeholder icon (if it's different) */
+     .gr-image .gr-image-placeholder {
+         max-width: 100px; /* Adjust size as needed */
+         max-height: 100px;
+         object-fit: contain; /* Ensures the icon scales down without distortion */
+     }
+     /* General styling for the image input area to ensure it has space */
+     .gr-image-container {
+         min-height: 200px; /* Give the image input area a minimum height */
+         display: flex;
+         align-items: center;
+         justify-content: center;
+     }
+     """,
+ )
+
+
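+ # Warm-load the pipeline at import time so the first request doesn't pay the
+ # model-loading cost; a load failure then surfaces immediately at startup.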
+ load_diffusion_pipeline()
+
+
+ if __name__ == "__main__":
+     iface.launch()