NightRaven109 committed on
Commit bfecb5b · verified · 1 Parent(s): b22f2c5

Update app.py

Files changed (1):
  1. app.py +80 -40
app.py CHANGED
@@ -117,26 +117,34 @@ def process_image(
         validation_image = validation_image.resize((validation_image.size[0]//8*8, validation_image.size[1]//8*8))
         width, height = validation_image.size
 
+        # Move pipeline to GPU and set to eval mode
+        pipeline.to(accelerator.device)
+        pipeline.unet.eval()
+        pipeline.controlnet.eval()
+        pipeline.vae.eval()
+        pipeline.text_encoder.eval()
+
         # Generate image
-        inference_time, output = pipeline(
-            args.t_max,
-            args.t_min,
-            args.tile_diffusion,
-            args.tile_diffusion_size,
-            args.tile_diffusion_stride,
-            args.added_prompt,
-            validation_image,
-            num_inference_steps=args.num_inference_steps,
-            generator=generator,
-            height=height,
-            width=width,
-            guidance_scale=args.guidance_scale,
-            negative_prompt=args.negative_prompt,
-            conditioning_scale=args.conditioning_scale,
-            start_steps=args.start_steps,
-            start_point=args.start_point,
-            use_vae_encode_condition=args.use_vae_encode_condition,
-        )
+        with torch.no_grad():
+            inference_time, output = pipeline(
+                args.t_max,
+                args.t_min,
+                args.tile_diffusion,
+                args.tile_diffusion_size,
+                args.tile_diffusion_stride,
+                args.added_prompt,
+                validation_image,
+                num_inference_steps=args.num_inference_steps,
+                generator=generator,
+                height=height,
+                width=width,
+                guidance_scale=args.guidance_scale,
+                negative_prompt=args.negative_prompt,
+                conditioning_scale=args.conditioning_scale,
+                start_steps=args.start_steps,
+                start_point=args.start_point,
+                use_vae_encode_condition=args.use_vae_encode_condition,
+            )
 
         image = output.images[0]
 
@@ -149,30 +157,62 @@ def process_image(
         if resize_flag:
            image = image.resize((ori_width*args.upscale, ori_height*args.upscale))
 
+        # Move pipeline back to CPU to free up GPU memory
+        pipeline.to("cpu")
+        torch.cuda.empty_cache()
+
         return image
 
     except Exception as e:
         print(f"Error processing image: {str(e)}")
         return None
 
-# Create Gradio interface
-iface = gr.Interface(
-    fn=process_image,
-    inputs=[
-        gr.Image(label="Input Image"),
-        gr.Textbox(label="Prompt", value="clean, high-resolution, 8k"),
-        gr.Textbox(label="Negative Prompt", value="blurry, dotted, noise, raster lines, unclear, lowres, over-smoothed"),
-        gr.Slider(minimum=1.0, maximum=20.0, value=1.0, label="Guidance Scale"),
-        gr.Slider(minimum=0.1, maximum=2.0, value=1.0, label="Conditioning Scale"),
-        gr.Slider(minimum=1, maximum=50, value=20, step=1, label="Number of Steps"),
-        gr.Number(label="Seed", value=42),
-        gr.Slider(minimum=1, maximum=4, value=2, step=1, label="Upscale Factor"),
-        gr.Radio(["none", "wavelet", "adain"], label="Color Fix Method", value="adain"),
-    ],
-    outputs=gr.Image(label="Generated Image"),
-    title="Controllable Conditional Super-Resolution",
-    description="Upload an image to enhance its resolution using CCSR."
-)
-
-if __name__ == "__main__":
-    iface.launch()
+# Also update the initialize_models function:
+@spaces.GPU
+def initialize_models():
+    global pipeline, generator, accelerator
+
+    try:
+        # Download model repository
+        model_path = snapshot_download(
+            repo_id="NightRaven109/CCSRModels",
+            token=os.environ['Read2']
+        )
+
+        # Set up default arguments
+        args = Args(
+            pretrained_model_path=os.path.join(model_path, "stable-diffusion-2-1-base"),
+            controlnet_model_path=os.path.join(model_path, "Controlnet"),
+            vae_model_path=os.path.join(model_path, "vae"),
+            mixed_precision="fp16",
+            tile_vae=False,
+            sample_method="ddpm",
+            vae_encoder_tile_size=1024,
+            vae_decoder_tile_size=224
+        )
+
+        # Initialize accelerator
+        accelerator = Accelerator(
+            mixed_precision=args.mixed_precision,
+        )
+
+        # Load pipeline
+        pipeline = load_pipeline(args, accelerator, enable_xformers_memory_efficient_attention=False)
+
+        # Set pipeline to eval mode
+        pipeline.unet.eval()
+        pipeline.controlnet.eval()
+        pipeline.vae.eval()
+        pipeline.text_encoder.eval()
+
+        # Move to CPU initially to save memory
+        pipeline.to("cpu")
+
+        # Initialize generator
+        generator = torch.Generator(device=accelerator.device)
+
+        return True
+
+    except Exception as e:
+        print(f"Error initializing models: {str(e)}")
+        return False
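The new initialize_models builds an Args(...) object from keyword arguments and later reads fields such as args.mixed_precision, but the Args class itself is not shown in this diff. A minimal sketch of what such a container could look like, assuming Args is a plain keyword-to-attribute holder defined elsewhere in app.py (the class body below is hypothetical, not part of the commit):

class Args:
    """Hypothetical stand-in for the Args container used by initialize_models."""
    def __init__(self, **kwargs):
        # Expose every keyword argument as an attribute, e.g. args.mixed_precision.
        for name, value in kwargs.items():
            setattr(self, name, value)

# Example: args.mixed_precision == "fp16", args.tile_vae is False
args = Args(mixed_precision="fp16", tile_vae=False)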
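Because the module-level gr.Interface block is removed by this commit, the demo presumably rebuilds the interface after initialize_models() has run. A minimal sketch of that wiring, reusing the control definitions from the removed block; the startup guard and error handling here are assumptions for illustration, not code from the commit:

import gradio as gr

if __name__ == "__main__":
    # Load models once at startup; process_image then moves the pipeline
    # onto the GPU only for the duration of each request.
    if not initialize_models():
        raise RuntimeError("Model initialization failed")

    iface = gr.Interface(
        fn=process_image,
        inputs=[
            gr.Image(label="Input Image"),
            gr.Textbox(label="Prompt", value="clean, high-resolution, 8k"),
            gr.Textbox(label="Negative Prompt", value="blurry, dotted, noise, raster lines, unclear, lowres, over-smoothed"),
            gr.Slider(minimum=1.0, maximum=20.0, value=1.0, label="Guidance Scale"),
            gr.Slider(minimum=0.1, maximum=2.0, value=1.0, label="Conditioning Scale"),
            gr.Slider(minimum=1, maximum=50, value=20, step=1, label="Number of Steps"),
            gr.Number(label="Seed", value=42),
            gr.Slider(minimum=1, maximum=4, value=2, step=1, label="Upscale Factor"),
            gr.Radio(["none", "wavelet", "adain"], label="Color Fix Method", value="adain"),
        ],
        outputs=gr.Image(label="Generated Image"),
        title="Controllable Conditional Super-Resolution",
        description="Upload an image to enhance its resolution using CCSR.",
    )
    iface.launch()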