NightRaven109 committed on
Commit
303d638
·
verified ·
1 Parent(s): 77f9404

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -61
app.py CHANGED
@@ -9,10 +9,15 @@ from test_ccsr_tile import load_pipeline
9
  import argparse
10
  from accelerate import Accelerator
11
 
12
- # Initialize global variables
13
- pipeline = None
14
- generator = None
15
- accelerator = None
 
 
 
 
 
16
 
17
  class Args:
18
  def __init__(self, **kwargs):
@@ -20,10 +25,12 @@ class Args:
20
 
21
  @spaces.GPU
22
  def initialize_models():
23
- global pipeline, generator, accelerator
 
 
24
 
25
  try:
26
- # Download model repository
27
  model_path = snapshot_download(
28
  repo_id="NightRaven109/CCSRModels",
29
  token=os.environ['Read2']
@@ -42,24 +49,28 @@ def initialize_models():
42
  )
43
 
44
  # Initialize accelerator
45
- accelerator = Accelerator(
46
  mixed_precision=args.mixed_precision,
47
  )
48
 
49
  # Load pipeline
50
- pipeline = load_pipeline(args, accelerator, enable_xformers_memory_efficient_attention=False)
 
51
 
52
- # Ensure all models are in eval mode
53
- pipeline.unet.eval()
54
- pipeline.controlnet.eval()
55
- pipeline.vae.eval()
56
- pipeline.text_encoder.eval()
57
 
58
- # Move pipeline to CUDA
59
- pipeline = pipeline.to("cuda")
60
 
61
  # Initialize generator
62
- generator = torch.Generator("cuda")
 
 
 
63
 
64
  return True
65
 
@@ -67,6 +78,7 @@ def initialize_models():
67
  print(f"Error initializing models: {str(e)}")
68
  return False
69
 
 
70
  @spaces.GPU(processing_timeout=180)
71
  def process_image(
72
  input_image,
@@ -79,15 +91,13 @@ def process_image(
79
  upscale_factor=4,
80
  color_fix_method="adain"
81
  ):
82
- global pipeline, generator, accelerator
83
-
84
- try:
85
- # Initialize models if not already done
86
- if pipeline is None:
87
- if not initialize_models():
88
- return None
89
 
90
- # Create args object with all necessary parameters
 
91
  args = Args(
92
  added_prompt=prompt,
93
  negative_prompt=negative_prompt,
@@ -105,13 +115,13 @@ def process_image(
105
  tile_diffusion_stride=None,
106
  start_steps=999,
107
  start_point='lr',
108
- use_vae_encode_condition=True, # Changed to True
109
  sample_times=1
110
  )
111
 
112
  # Set seed if provided
113
  if seed is not None:
114
- generator.manual_seed(seed)
115
 
116
  # Process input image
117
  validation_image = Image.fromarray(input_image)
@@ -128,42 +138,26 @@ def process_image(
128
  validation_image = validation_image.resize((validation_image.size[0]//8*8, validation_image.size[1]//8*8))
129
  width, height = validation_image.size
130
 
131
- # Ensure pipeline is on CUDA and in eval mode
132
- pipeline = pipeline.to("cuda")
133
- pipeline.unet.eval()
134
- pipeline.controlnet.eval()
135
- pipeline.vae.eval()
136
- pipeline.text_encoder.eval()
137
-
138
  # Generate image
139
- with torch.no_grad():
140
- try:
141
- # First encode the image with VAE
142
- image_tensor = pipeline.image_processor.preprocess(validation_image)
143
- image_tensor = image_tensor.unsqueeze(0).to(device="cuda", dtype=torch.float32)
144
-
145
- inference_time, output = pipeline(
146
- args.t_max,
147
- args.t_min,
148
- args.tile_diffusion,
149
- args.tile_diffusion_size,
150
- args.tile_diffusion_stride,
151
- args.added_prompt,
152
- validation_image,
153
- num_inference_steps=args.num_inference_steps,
154
- generator=generator,
155
- height=height,
156
- width=width,
157
- guidance_scale=args.guidance_scale,
158
- negative_prompt=args.negative_prompt,
159
- conditioning_scale=args.conditioning_scale,
160
- start_steps=args.start_steps,
161
- start_point=args.start_point,
162
- use_vae_encode_condition=True, # Set to True
163
- )
164
- except Exception as e:
165
- print(f"Pipeline execution error: {str(e)}")
166
- raise
167
 
168
  image = output.images[0]
169
 
 
9
  import argparse
10
  from accelerate import Accelerator
11
 
12
# Global variables
class ModelContainer:
    """Holds the lazily-created model state shared across app requests.

    All slots start empty; ``initialize_models()`` is expected to populate
    ``pipeline``, ``generator`` and ``accelerator`` exactly once and then
    flip ``is_initialized`` to True so later calls can skip the setup.
    """

    def __init__(self):
        # Nothing is loaded at construction time — creation is deferred
        # until the first GPU-backed call needs the models.
        for slot in ("pipeline", "generator", "accelerator"):
            setattr(self, slot, None)
        self.is_initialized = False

model_container = ModelContainer()
21
 
22
  class Args:
23
  def __init__(self, **kwargs):
 
25
 
26
  @spaces.GPU
27
  def initialize_models():
28
+ """Initialize models only if they haven't been initialized yet"""
29
+ if model_container.is_initialized:
30
+ return True
31
 
32
  try:
33
+ # Download model repository (only once)
34
  model_path = snapshot_download(
35
  repo_id="NightRaven109/CCSRModels",
36
  token=os.environ['Read2']
 
49
  )
50
 
51
  # Initialize accelerator
52
+ model_container.accelerator = Accelerator(
53
  mixed_precision=args.mixed_precision,
54
  )
55
 
56
  # Load pipeline
57
+ model_container.pipeline = load_pipeline(args, model_container.accelerator,
58
+ enable_xformers_memory_efficient_attention=False)
59
 
60
+ # Set models to eval mode
61
+ model_container.pipeline.unet.eval()
62
+ model_container.pipeline.controlnet.eval()
63
+ model_container.pipeline.vae.eval()
64
+ model_container.pipeline.text_encoder.eval()
65
 
66
+ # Move pipeline to CUDA and set to eval mode once
67
+ model_container.pipeline = model_container.pipeline.to("cuda")
68
 
69
  # Initialize generator
70
+ model_container.generator = torch.Generator("cuda")
71
+
72
+ # Set initialization flag
73
+ model_container.is_initialized = True
74
 
75
  return True
76
 
 
78
  print(f"Error initializing models: {str(e)}")
79
  return False
80
 
81
+ @torch.no_grad() # Add no_grad decorator for inference
82
  @spaces.GPU(processing_timeout=180)
83
  def process_image(
84
  input_image,
 
91
  upscale_factor=4,
92
  color_fix_method="adain"
93
  ):
94
+ # Initialize models if not already done
95
+ if not model_container.is_initialized:
96
+ if not initialize_models():
97
+ return None
 
 
 
98
 
99
+ try:
100
+ # Create args object
101
  args = Args(
102
  added_prompt=prompt,
103
  negative_prompt=negative_prompt,
 
115
  tile_diffusion_stride=None,
116
  start_steps=999,
117
  start_point='lr',
118
+ use_vae_encode_condition=True,
119
  sample_times=1
120
  )
121
 
122
  # Set seed if provided
123
  if seed is not None:
124
+ model_container.generator.manual_seed(seed)
125
 
126
  # Process input image
127
  validation_image = Image.fromarray(input_image)
 
138
  validation_image = validation_image.resize((validation_image.size[0]//8*8, validation_image.size[1]//8*8))
139
  width, height = validation_image.size
140
 
 
 
 
 
 
 
 
141
  # Generate image
142
+ inference_time, output = model_container.pipeline(
143
+ args.t_max,
144
+ args.t_min,
145
+ args.tile_diffusion,
146
+ args.tile_diffusion_size,
147
+ args.tile_diffusion_stride,
148
+ args.added_prompt,
149
+ validation_image,
150
+ num_inference_steps=args.num_inference_steps,
151
+ generator=model_container.generator,
152
+ height=height,
153
+ width=width,
154
+ guidance_scale=args.guidance_scale,
155
+ negative_prompt=args.negative_prompt,
156
+ conditioning_scale=args.conditioning_scale,
157
+ start_steps=args.start_steps,
158
+ start_point=args.start_point,
159
+ use_vae_encode_condition=True,
160
+ )
 
 
 
 
 
 
 
 
 
161
 
162
  image = output.images[0]
163