erwold commited on
Commit
bc9137b
·
1 Parent(s): 1d4e763

Initial Commit

Browse files
Files changed (1) hide show
  1. app.py +70 -117
app.py CHANGED
@@ -12,7 +12,9 @@ import logging
12
  import sys
13
 
14
  import os
15
- os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512,expandable_segments:True'
 
 
16
 
17
  from qwen2_vl.modeling_qwen2_vl import Qwen2VLSimplifiedModel
18
 
@@ -57,123 +59,74 @@ class FluxInterface:
57
  if self.models is not None:
58
  return
59
 
60
- import gc
61
- torch.cuda.empty_cache()
62
- gc.collect()
63
-
64
  logger.info("Starting model loading...")
65
-
66
- try:
67
- # 1. 首先加载小型模型和tokenizer
68
- tokenizer = CLIPTokenizer.from_pretrained(self.MODEL_ID, subfolder="flux/tokenizer")
69
- tokenizer_two = T5TokenizerFast.from_pretrained(self.MODEL_ID, subfolder="flux/tokenizer_2")
70
- scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(self.MODEL_ID, subfolder="flux/scheduler", shift=1)
71
-
72
- # 2. 加载并优化CLIP text encoder
73
- text_encoder = CLIPTextModel.from_pretrained(
74
- self.MODEL_ID,
75
- subfolder="flux/text_encoder",
76
- torch_dtype=self.dtype,
77
- device_map="auto" # 让模型自动管理显存
78
- )
79
-
80
- # 3. 加载T5 encoder
81
- text_encoder_two = T5EncoderModel.from_pretrained(
82
- self.MODEL_ID,
83
- subfolder="flux/text_encoder_2",
84
- torch_dtype=self.dtype,
85
- device_map="auto"
86
- )
87
-
88
- # 清理一次显存
89
- torch.cuda.empty_cache()
90
- gc.collect()
91
-
92
- # 4. 加载VAE
93
- vae = AutoencoderKL.from_pretrained(
94
- self.MODEL_ID,
95
- subfolder="flux/vae",
96
- torch_dtype=self.dtype,
97
- device_map="auto"
98
- )
99
-
100
- # 5. 加载Transformer
101
- transformer = FluxTransformer2DModel.from_pretrained(
102
- self.MODEL_ID,
103
- subfolder="flux/transformer",
104
- torch_dtype=self.dtype,
105
- device_map="auto"
106
- )
107
-
108
- # 再次清理显存
109
- torch.cuda.empty_cache()
110
- gc.collect()
111
-
112
- # 6. 加载Qwen2VL
113
- qwen2vl = Qwen2VLSimplifiedModel.from_pretrained(
114
- self.MODEL_ID,
115
- subfolder="qwen2-vl",
116
- torch_dtype=self.dtype,
117
- device_map="auto"
118
- )
119
-
120
- # 7. 加载其他小组件
121
- connector = Qwen2Connector().to(self.dtype)
122
- connector_path = f"https://huggingface.co/{self.MODEL_ID}/resolve/main/qwen2-vl/connector.pt"
123
- connector_state = torch.hub.load_state_dict_from_url(connector_path, map_location='cpu')
124
- connector_state = {k: v.to(self.dtype) for k, v in connector_state.items()}
125
- connector.load_state_dict(connector_state)
126
- connector = connector.to(self.device)
127
-
128
- self.t5_context_embedder = nn.Linear(4096, 3072).to(self.dtype)
129
- t5_embedder_path = f"https://huggingface.co/{self.MODEL_ID}/resolve/main/qwen2-vl/t5_embedder.pt"
130
- t5_embedder_state = torch.hub.load_state_dict_from_url(t5_embedder_path, map_location='cpu')
131
- t5_embedder_state = {k: v.to(self.dtype) for k, v in t5_embedder_state.items()}
132
- self.t5_context_embedder.load_state_dict(t5_embedder_state)
133
- self.t5_context_embedder = self.t5_context_embedder.to(self.device)
134
-
135
- # 设置eval模式和关闭梯度
136
- for model in [text_encoder, text_encoder_two, vae, transformer, qwen2vl, connector, self.t5_context_embedder]:
137
- if hasattr(model, 'eval'):
138
- model.eval()
139
- if hasattr(model, 'requires_grad_'):
140
- model.requires_grad_(False)
141
-
142
- logger.info("Models loaded successfully")
143
-
144
- self.models = {
145
- 'tokenizer': tokenizer,
146
- 'text_encoder': text_encoder,
147
- 'text_encoder_two': text_encoder_two,
148
- 'tokenizer_two': tokenizer_two,
149
- 'vae': vae,
150
- 'transformer': transformer,
151
- 'scheduler': scheduler,
152
- 'qwen2vl': qwen2vl,
153
- 'connector': connector
154
- }
155
-
156
- # 初始化processor和pipeline
157
- self.qwen2vl_processor = AutoProcessor.from_pretrained(
158
- self.MODEL_ID,
159
- subfolder="qwen2-vl",
160
- min_pixels=256*28*28,
161
- max_pixels=256*28*28
162
- )
163
-
164
- self.pipeline = FluxPipeline(
165
- transformer=transformer,
166
- scheduler=scheduler,
167
- vae=vae,
168
- text_encoder=text_encoder,
169
- tokenizer=tokenizer,
170
- )
171
-
172
- except Exception as e:
173
- logger.error(f"Error loading models: {str(e)}")
174
- torch.cuda.empty_cache()
175
- gc.collect()
176
- raise
177
 
178
  def resize_image(self, img, max_pixels=1050000):
179
  if not isinstance(img, Image.Image):
 
12
  import sys
13
 
14
  import os
15
+ # 设置环境变量,强制禁用 accelerate 的显存管理
16
+ os.environ["ACCELERATE_USE_MEMORY_EFFICIENT_ATTENTION"] = "false"
17
+ os.environ["ACCELERATE_DISABLE_MEMORY_EFFICIENT_ATTENTION"] = "1"
18
 
19
  from qwen2_vl.modeling_qwen2_vl import Qwen2VLSimplifiedModel
20
 
 
59
  if self.models is not None:
60
  return
61
 
 
 
 
 
62
  logger.info("Starting model loading...")
63
+
64
+ # Load FLUX components
65
+ tokenizer = CLIPTokenizer.from_pretrained(self.MODEL_ID, subfolder="flux/tokenizer")
66
+ text_encoder = CLIPTextModel.from_pretrained(self.MODEL_ID, subfolder="flux/text_encoder").to(self.dtype).to(self.device)
67
+ text_encoder_two = T5EncoderModel.from_pretrained(self.MODEL_ID, subfolder="flux/text_encoder_2").to(self.dtype).to(self.device)
68
+ tokenizer_two = T5TokenizerFast.from_pretrained(self.MODEL_ID, subfolder="flux/tokenizer_2")
69
+
70
+ # Load VAE and transformer
71
+ vae = AutoencoderKL.from_pretrained(self.MODEL_ID, subfolder="flux/vae").to(self.dtype).to(self.device)
72
+ transformer = FluxTransformer2DModel.from_pretrained(self.MODEL_ID, subfolder="flux/transformer").to(self.dtype).to(self.device)
73
+ scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(self.MODEL_ID, subfolder="flux/scheduler", shift=1)
74
+
75
+ # Load Qwen2VL components
76
+ qwen2vl = Qwen2VLSimplifiedModel.from_pretrained(self.MODEL_ID, subfolder="qwen2-vl").to(self.dtype).to(self.device)
77
+
78
+ # Load connector
79
+ connector = Qwen2Connector().to(self.dtype).to(self.device)
80
+ connector_path = f"https://huggingface.co/{self.MODEL_ID}/resolve/main/qwen2-vl/connector.pt"
81
+ connector_state = torch.hub.load_state_dict_from_url(connector_path, map_location='cpu')
82
+ # Move state dict to dtype before loading
83
+ connector_state = {k: v.to(self.dtype) for k, v in connector_state.items()}
84
+ connector.load_state_dict(connector_state)
85
+ connector = connector.to(self.device)
86
+
87
+ # Load T5 embedder
88
+ self.t5_context_embedder = nn.Linear(4096, 3072).to(self.dtype).to(self.device)
89
+ t5_embedder_path = f"https://huggingface.co/{self.MODEL_ID}/resolve/main/qwen2-vl/t5_embedder.pt"
90
+ t5_embedder_state = torch.hub.load_state_dict_from_url(t5_embedder_path, map_location='cpu')
91
+ # Move state dict to dtype before loading
92
+ t5_embedder_state = {k: v.to(self.dtype) for k, v in t5_embedder_state.items()}
93
+ self.t5_context_embedder.load_state_dict(t5_embedder_state)
94
+ self.t5_context_embedder = self.t5_context_embedder.to(self.device)
95
+
96
+ # Set models to eval mode
97
+ for model in [text_encoder, text_encoder_two, vae, transformer, qwen2vl, connector, self.t5_context_embedder]:
98
+ model.requires_grad_(False)
99
+ model.eval()
100
+
101
+ logger.info("All models loaded successfully")
102
+
103
+ self.models = {
104
+ 'tokenizer': tokenizer,
105
+ 'text_encoder': text_encoder,
106
+ 'text_encoder_two': text_encoder_two,
107
+ 'tokenizer_two': tokenizer_two,
108
+ 'vae': vae,
109
+ 'transformer': transformer,
110
+ 'scheduler': scheduler,
111
+ 'qwen2vl': qwen2vl,
112
+ 'connector': connector
113
+ }
114
+
115
+ # Initialize processor and pipeline
116
+ self.qwen2vl_processor = AutoProcessor.from_pretrained(
117
+ self.MODEL_ID,
118
+ subfolder="qwen2-vl",
119
+ min_pixels=256*28*28,
120
+ max_pixels=256*28*28
121
+ )
122
+
123
+ self.pipeline = FluxPipeline(
124
+ transformer=transformer,
125
+ scheduler=scheduler,
126
+ vae=vae,
127
+ text_encoder=text_encoder,
128
+ tokenizer=tokenizer,
129
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
  def resize_image(self, img, max_pixels=1050000):
132
  if not isinstance(img, Image.Image):