erwold committed
Commit 0f04459 · 1 Parent(s): 76678b6

Initial Commit

Files changed (1)
  1. app.py +42 -27
app.py CHANGED
@@ -91,34 +91,36 @@ class FluxInterface:
         tokenizer_two = T5TokenizerFast.from_pretrained(
             os.path.join(MODEL_CACHE_DIR, "flux/tokenizer_2"))
 
-        # 2. Initially load the large models on the CPU
+        # 2. Load the large models on the CPU, but keep bfloat16 precision
         vae = AutoencoderKL.from_pretrained(
             os.path.join(MODEL_CACHE_DIR, "flux/vae")
-        ).to(torch.float32).cpu()
+        ).to(self.dtype).cpu()
 
         transformer = FluxTransformer2DModel.from_pretrained(
             os.path.join(MODEL_CACHE_DIR, "flux/transformer")
-        ).to(torch.float32).cpu()
+        ).to(self.dtype).cpu()
 
         scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
-            os.path.join(MODEL_CACHE_DIR, "flux/scheduler"),
+            os.path.join(MODEL_CACHE_DIR, "flux/scheduler"),
             shift=1
         )
 
-        # 3. Initially load Qwen2VL on the CPU
+        # 3. Load Qwen2VL on the CPU, keep bfloat16
         qwen2vl = Qwen2VLSimplifiedModel.from_pretrained(
             os.path.join(MODEL_CACHE_DIR, "qwen2-vl")
-        ).to(torch.float32).cpu()
+        ).to(self.dtype).cpu()
 
-        # 4. Load the connector and embedder on the CPU
-        connector = Qwen2Connector().to(torch.float32).cpu()
+        # 4. Load the connector and embedder, keep bfloat16
+        connector = Qwen2Connector().to(self.dtype).cpu()
         connector_path = os.path.join(MODEL_CACHE_DIR, "qwen2-vl/connector.pt")
         connector_state = torch.load(connector_path, map_location='cpu')
+        connector_state = {k: v.to(self.dtype) for k, v in connector_state.items()}
         connector.load_state_dict(connector_state)
 
-        self.t5_context_embedder = nn.Linear(4096, 3072).to(torch.float32).cpu()
+        self.t5_context_embedder = nn.Linear(4096, 3072).to(self.dtype).cpu()
         t5_embedder_path = os.path.join(MODEL_CACHE_DIR, "qwen2-vl/t5_embedder.pt")
         t5_embedder_state = torch.load(t5_embedder_path, map_location='cpu')
+        t5_embedder_state = {k: v.to(self.dtype) for k, v in t5_embedder_state.items()}
         self.t5_context_embedder.load_state_dict(t5_embedder_state)
 
         # 5. Set all models to eval mode
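The substantive change in this hunk is the precision policy: instead of materializing everything in float32 on the CPU, every module and every checkpoint tensor is cast to `self.dtype` (bfloat16) up front, roughly halving host RAM and avoiding a mixed-dtype state dict. A minimal sketch of the checkpoint-casting pattern, with a hypothetical `embedder.pt` path standing in for the real checkpoints:

import torch
import torch.nn as nn

dtype = torch.bfloat16  # plays the role of self.dtype in the diff

# Build the module directly in the target dtype, parked on the CPU.
embedder = nn.Linear(4096, 3072).to(dtype).cpu()

# Cast every checkpoint tensor to the same dtype *before* load_state_dict,
# so no float32 copy of the weights ever reaches the module.
state = torch.load("embedder.pt", map_location="cpu")  # hypothetical path
state = {k: v.to(dtype) for k, v in state.items()}
embedder.load_state_dict(state)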
@@ -141,7 +143,6 @@ class FluxInterface:
             'connector': connector
         }
 
-        # Initialize processor and pipeline
         self.qwen2vl_processor = AutoProcessor.from_pretrained(
             self.MODEL_ID,
             subfolder="qwen2-vl",
@@ -160,15 +161,17 @@ class FluxInterface:
     def move_to_device(self, model, device):
         """Helper function to move model to specified device"""
         if hasattr(model, 'to'):
-            return model.to(device)
+            return model.to(self.dtype).to(device)
         return model
 
     def process_image(self, image):
         """Process image with Qwen2VL model"""
         try:
             # 1. Move the Qwen2VL-related models to the GPU
-            self.models['qwen2vl'] = self.move_to_device(self.models['qwen2vl'], self.device)
-            self.models['connector'] = self.move_to_device(self.models['connector'], self.device)
+            logger.info("Moving Qwen2VL models to GPU...")
+            self.models['qwen2vl'] = self.models['qwen2vl'].to(self.device)
+            self.models['connector'] = self.models['connector'].to(self.device)
+            logger.info("Qwen2VL models moved to GPU")
 
             message = [
                 {
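Two things change in this hunk: `move_to_device` now re-asserts the dtype on every move, so a model can never drift back to float32 after a device hop, and `process_image` pins the Qwen2VL encoder and connector on the GPU only for the duration of the call. A self-contained sketch of that helper under the same assumptions:

import torch
import torch.nn as nn

def move_to_device(model, device, dtype=torch.bfloat16):
    """Move `model` to `device`, normalizing its dtype on the way."""
    if hasattr(model, "to"):  # tokenizers/processors have no .to()
        return model.to(dtype).to(device)
    return model

# Usage: borrow the GPU just for one forward pass.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = move_to_device(nn.Linear(8, 8), device)
out = model(torch.randn(1, 8, dtype=torch.bfloat16, device=device))
model = move_to_device(model, "cpu")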
@@ -200,10 +203,12 @@ class FluxInterface:
             # Save the result on the CPU
             result = (image_hidden_state.cpu(), image_grid_thw)
 
-            # 2. Move the Qwen2VL-related models back to the CPU to free VRAM
-            self.models['qwen2vl'] = self.move_to_device(self.models['qwen2vl'], 'cpu')
-            self.models['connector'] = self.move_to_device(self.models['connector'], 'cpu')
+            # 2. Move the Qwen2VL-related models back to the CPU
+            logger.info("Moving Qwen2VL models back to CPU...")
+            self.models['qwen2vl'] = self.models['qwen2vl'].cpu()
+            self.models['connector'] = self.models['connector'].cpu()
             torch.cuda.empty_cache()
+            logger.info("Qwen2VL models moved to CPU and GPU cache cleared")
 
             return result
 
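The move-to-GPU / run / move-back-to-CPU choreography now appears in every compute path, always followed by `torch.cuda.empty_cache()` to return the freed blocks to the allocator. Not part of the diff, but the same offloading technique can be factored into a context manager, sketched here:

import contextlib
import torch

@contextlib.contextmanager
def on_device(models, names, device):
    """Temporarily host models[name] on `device`, then park them back on the CPU."""
    try:
        for name in names:
            models[name] = models[name].to(device)
        yield
    finally:
        for name in names:
            models[name] = models[name].cpu()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # release cached VRAM back to the driver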
@@ -242,8 +247,8 @@ class FluxInterface:
         ).to(self.device)
 
         prompt_embeds = self.models['text_encoder_two'](text_inputs.input_ids)[0]
-        prompt_embeds = prompt_embeds.to(dtype=self.dtype, device=self.device)
-        prompt_embeds = self.t5_context_embedder(prompt_embeds)
+        prompt_embeds = self.t5_context_embedder.to(self.device)(prompt_embeds)
+        self.t5_context_embedder = self.t5_context_embedder.cpu()
 
         return prompt_embeds
 
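The rewritten T5 path relies on `nn.Module.to()` mutating the module in place and returning `self`: the embedder is pulled onto the GPU, applied, and parked back on the CPU around a single matmul. An isolated sketch of that idiom:

import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
embedder = nn.Linear(4096, 3072, dtype=torch.bfloat16)  # stand-in for t5_context_embedder

x = torch.randn(1, 512, 4096, dtype=torch.bfloat16, device=device)
y = embedder.to(device)(x)   # .to() moves the module in place and returns it
embedder = embedder.cpu()    # park the weights back in system RAM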
@@ -261,9 +266,9 @@ class FluxInterface:
             text_inputs.input_ids,
             output_hidden_states=False
         )
-        pooled_prompt_embeds = prompt_embeds.pooler_output.to(self.dtype)
-
-        return pooled_prompt_embeds
+        pooled_prompt_embeds = prompt_embeds.pooler_output
+        return pooled_prompt_embeds
+
 
     def generate(self, input_image, prompt="", guidance_scale=3.5,
                  num_inference_steps=28, num_images=2, seed=None, aspect_ratio="1:1"):
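The explicit `.to(self.dtype)` on the pooled output is dropped because it is now redundant: once the text encoder's weights live in bfloat16, its outputs already come back in that dtype. A two-line check of that fact:

import torch
import torch.nn as nn

encoder = nn.Linear(16, 16).to(torch.bfloat16)  # stand-in for the CLIP text encoder
out = encoder(torch.randn(1, 16, dtype=torch.bfloat16))
print(out.dtype)  # torch.bfloat16 -- no trailing cast needed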
@@ -275,27 +280,36 @@ class FluxInterface:
 
         if seed is not None:
             torch.manual_seed(seed)
+            logger.info(f"Set random seed to: {seed}")
 
         # 1. Process the input image with Qwen2VL
+        logger.info("Processing input image with Qwen2VL...")
         qwen2_hidden_state, image_grid_thw = self.process_image(input_image)
+        logger.info("Image processing completed")
 
         # 2. Compute the text embeddings
+        logger.info("Computing text embeddings...")
         pooled_prompt_embeds = self.compute_text_embeddings("")
         t5_prompt_embeds = self.compute_t5_text_embeddings(prompt)
+        logger.info("Text embeddings computed")
 
         # 3. Move the Transformer and VAE to the GPU
-        self.models['transformer'] = self.move_to_device(self.models['transformer'], self.device)
-        self.models['vae'] = self.move_to_device(self.models['vae'], self.device)
+        logger.info("Moving Transformer and VAE to GPU...")
+        self.models['transformer'] = self.models['transformer'].to(self.device)
+        self.models['vae'] = self.models['vae'].to(self.device)
 
-        # Update the models in the pipeline
+        # Update the model references in the pipeline
         self.pipeline.transformer = self.models['transformer']
         self.pipeline.vae = self.models['vae']
+        logger.info("Models moved to GPU")
 
         # Get the output dimensions
         width, height = ASPECT_RATIOS[aspect_ratio]
+        logger.info(f"Using dimensions: {width}x{height}")
 
         # 4. Generate the images
         try:
+            logger.info("Starting image generation...")
             output_images = self.pipeline(
                 prompt_embeds=qwen2_hidden_state.to(self.device).repeat(num_images, 1, 1),
                 pooled_prompt_embeds=pooled_prompt_embeds,
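One line worth unpacking is the conditioning batch: `repeat(num_images, 1, 1)` tiles the single [1, seq_len, dim] Qwen2VL hidden state along the batch dimension so the pipeline decodes `num_images` variations in one pass. A sketch of the shape arithmetic (sizes are illustrative):

import torch

num_images = 2
hidden = torch.randn(1, 729, 3072)          # [batch, seq_len, dim]
batched = hidden.repeat(num_images, 1, 1)   # tile along the batch axis
print(batched.shape)                        # torch.Size([2, 729, 3072])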
@@ -305,11 +319,14 @@ class FluxInterface:
                 height=height,
                 width=width,
             ).images
+            logger.info("Image generation completed")
 
             # 5. Move the Transformer and VAE back to the CPU
-            self.models['transformer'] = self.move_to_device(self.models['transformer'], 'cpu')
-            self.models['vae'] = self.move_to_device(self.models['vae'], 'cpu')
+            logger.info("Moving models back to CPU...")
+            self.models['transformer'] = self.models['transformer'].cpu()
+            self.models['vae'] = self.models['vae'].cpu()
             torch.cuda.empty_cache()
+            logger.info("Models moved to CPU and GPU cache cleared")
 
             return output_images
 
@@ -323,8 +340,6 @@ class FluxInterface:
 # Initialize the interface
 interface = FluxInterface()
 
-# Apply the GPU decorator directly to the outermost handler function
-@spaces.GPU(duration=300)
 def process_request(input_image, prompt="", guidance_scale=3.5, num_inference_steps=28, num_images=2, seed=None, aspect_ratio="1:1"):
     """Main processing function; handles user requests directly"""
     try:
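For context on the deletion above: `@spaces.GPU` is the Hugging Face ZeroGPU hook that allocates a GPU only while the decorated function runs, with `duration` capping the allocation in seconds. A minimal sketch of how it would reattach to a handler like this one:

import spaces

@spaces.GPU(duration=300)  # reserve a ZeroGPU device for up to 300 s per call
def process_request(input_image, prompt=""):
    return interface.generate(input_image, prompt)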
 