erwold committed on
Commit 2d66916 · 1 Parent(s): 29fa1d0

Initial Commit

Files changed (1)
  1. app.py +80 -28
app.py CHANGED
@@ -11,6 +11,9 @@ import math
import logging
import sys

+ import os
+ os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512,expandable_segments:True'
+
from qwen2_vl.modeling_qwen2_vl import Qwen2VLSimplifiedModel

# Set up logging
@@ -51,49 +54,92 @@ class FluxInterface:
self.MODEL_ID = "Djrango/Qwen2vl-Flux"

def load_models(self):
- if self.models is not None:
- return
+ if self.models is not None:
+ return

- logger.info("Starting model loading...")
-
- # Load FLUX components
+ import gc
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ logger.info("Starting model loading...")
+
+ try:
+ # 1. First load the small models and tokenizers
tokenizer = CLIPTokenizer.from_pretrained(self.MODEL_ID, subfolder="flux/tokenizer")
- text_encoder = CLIPTextModel.from_pretrained(self.MODEL_ID, subfolder="flux/text_encoder").to(self.dtype).to(self.device)
- text_encoder_two = T5EncoderModel.from_pretrained(self.MODEL_ID, subfolder="flux/text_encoder_2").to(self.dtype).to(self.device)
tokenizer_two = T5TokenizerFast.from_pretrained(self.MODEL_ID, subfolder="flux/tokenizer_2")
-
- # Load VAE and transformer
- vae = AutoencoderKL.from_pretrained(self.MODEL_ID, subfolder="flux/vae").to(self.dtype).to(self.device)
- transformer = FluxTransformer2DModel.from_pretrained(self.MODEL_ID, subfolder="flux/transformer").to(self.dtype).to(self.device)
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(self.MODEL_ID, subfolder="flux/scheduler", shift=1)

- # Load Qwen2VL components
- qwen2vl = Qwen2VLSimplifiedModel.from_pretrained(self.MODEL_ID, subfolder="qwen2-vl").to(self.dtype).to(self.device)
+ # 2. Load and optimize the CLIP text encoder
+ text_encoder = CLIPTextModel.from_pretrained(
+ self.MODEL_ID,
+ subfolder="flux/text_encoder",
+ torch_dtype=self.dtype,
+ device_map="auto" # let the model manage GPU memory automatically
+ )
+
+ # 3. Load the T5 encoder
+ text_encoder_two = T5EncoderModel.from_pretrained(
+ self.MODEL_ID,
+ subfolder="flux/text_encoder_2",
+ torch_dtype=self.dtype,
+ device_map="auto"
+ )
+
+ # Clear GPU memory once
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ # 4. Load the VAE
+ vae = AutoencoderKL.from_pretrained(
+ self.MODEL_ID,
+ subfolder="flux/vae",
+ torch_dtype=self.dtype,
+ device_map="auto"
+ )
+
+ # 5. Load the transformer
+ transformer = FluxTransformer2DModel.from_pretrained(
+ self.MODEL_ID,
+ subfolder="flux/transformer",
+ torch_dtype=self.dtype,
+ device_map="auto"
+ )
+
+ # Clear GPU memory again
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ # 6. Load Qwen2VL
+ qwen2vl = Qwen2VLSimplifiedModel.from_pretrained(
+ self.MODEL_ID,
+ subfolder="qwen2-vl",
+ torch_dtype=self.dtype,
+ device_map="auto"
+ )

- # Load connector
- connector = Qwen2Connector().to(self.dtype).to(self.device)
+ # 7. Load the remaining small components
+ connector = Qwen2Connector().to(self.dtype)
connector_path = f"https://huggingface.co/{self.MODEL_ID}/resolve/main/qwen2-vl/connector.pt"
connector_state = torch.hub.load_state_dict_from_url(connector_path, map_location='cpu')
- # Move state dict to dtype before loading
connector_state = {k: v.to(self.dtype) for k, v in connector_state.items()}
connector.load_state_dict(connector_state)
connector = connector.to(self.device)
-
- # Load T5 embedder
- self.t5_context_embedder = nn.Linear(4096, 3072).to(self.dtype).to(self.device)
+
+ self.t5_context_embedder = nn.Linear(4096, 3072).to(self.dtype)
t5_embedder_path = f"https://huggingface.co/{self.MODEL_ID}/resolve/main/qwen2-vl/t5_embedder.pt"
t5_embedder_state = torch.hub.load_state_dict_from_url(t5_embedder_path, map_location='cpu')
- # Move state dict to dtype before loading
t5_embedder_state = {k: v.to(self.dtype) for k, v in t5_embedder_state.items()}
self.t5_context_embedder.load_state_dict(t5_embedder_state)
self.t5_context_embedder = self.t5_context_embedder.to(self.device)

- # Set models to eval mode
+ # Set eval mode and disable gradients
for model in [text_encoder, text_encoder_two, vae, transformer, qwen2vl, connector, self.t5_context_embedder]:
- model.requires_grad_(False)
- model.eval()
+ if hasattr(model, 'eval'):
+ model.eval()
+ if hasattr(model, 'requires_grad_'):
+ model.requires_grad_(False)

- logger.info("All models loaded successfully")
+ logger.info("Models loaded successfully")

self.models = {
'tokenizer': tokenizer,
@@ -107,11 +153,11 @@ class FluxInterface:
'connector': connector
}

- # Initialize processor and pipeline
+ # Initialize processor and pipeline
self.qwen2vl_processor = AutoProcessor.from_pretrained(
- self.MODEL_ID,
+ self.MODEL_ID,
subfolder="qwen2-vl",
- min_pixels=256*28*28,
+ min_pixels=256*28*28,
max_pixels=256*28*28
)

@@ -121,7 +167,13 @@ class FluxInterface:
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
- )
+ )
+
+ except Exception as e:
+ logger.error(f"Error loading models: {str(e)}")
+ torch.cuda.empty_cache()
+ gc.collect()
+ raise

def resize_image(self, img, max_pixels=1050000):
if not isinstance(img, Image.Image):
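
For reference, the low-memory loading pattern that the new load_models() follows can be reduced to a short standalone sketch: set the CUDA allocator config before the first allocation, load each component with torch_dtype and device_map="auto", release cached blocks between stages, and clean up on failure. The load_component helper and the bfloat16 dtype below are illustrative assumptions, not code from app.py.

import gc
import os

# Must be set before torch makes its first CUDA allocation, which is why the
# commit places it at the top of app.py.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512,expandable_segments:True'

import torch
from transformers import CLIPTextModel, T5EncoderModel

MODEL_ID = "Djrango/Qwen2vl-Flux"
DTYPE = torch.bfloat16  # assumption; app.py uses self.dtype

def load_component(cls, subfolder):
    # device_map="auto" lets accelerate place the weights instead of doing one
    # large .to(device) copy; eval() and requires_grad_(False) match the
    # inference-only setup in the commit.
    model = cls.from_pretrained(MODEL_ID, subfolder=subfolder,
                                torch_dtype=DTYPE, device_map="auto")
    model.requires_grad_(False)
    model.eval()
    # Release cached allocator blocks before loading the next component.
    torch.cuda.empty_cache()
    gc.collect()
    return model

try:
    text_encoder = load_component(CLIPTextModel, "flux/text_encoder")
    text_encoder_two = load_component(T5EncoderModel, "flux/text_encoder_2")
except Exception:
    # Free whatever was allocated before re-raising, as load_models() now does.
    torch.cuda.empty_cache()
    gc.collect()
    raise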