erwold committed on
Commit
17d135f
·
1 Parent(s): 9217369

Initial Commit

Browse files
Files changed (1) hide show
  1. app.py +26 -17
app.py CHANGED
@@ -52,38 +52,47 @@ class FluxInterface:
52
  if self.models is not None:
53
  return
54
 
 
 
55
  # Load FLUX components
56
  tokenizer = CLIPTokenizer.from_pretrained(self.MODEL_ID, subfolder="flux/tokenizer")
57
- text_encoder = CLIPTextModel.from_pretrained(self.MODEL_ID, subfolder="flux/text_encoder")
58
- text_encoder_two = T5EncoderModel.from_pretrained(self.MODEL_ID, subfolder="flux/text_encoder_2")
59
  tokenizer_two = T5TokenizerFast.from_pretrained(self.MODEL_ID, subfolder="flux/tokenizer_2")
60
 
61
- # Load VAE and transformer from flux folder
62
- vae = AutoencoderKL.from_pretrained(self.MODEL_ID, subfolder="flux/vae")
63
- transformer = FluxTransformer2DModel.from_pretrained(self.MODEL_ID, subfolder="flux/transformer")
64
  scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(self.MODEL_ID, subfolder="flux/scheduler", shift=1)
65
 
66
- # Load Qwen2VL components from qwen2-vl folder
67
- qwen2vl = Qwen2VLSimplifiedModel.from_pretrained(self.MODEL_ID, subfolder="qwen2-vl")
68
 
69
- # Load connector and t5 embedder from qwen2-vl folder
70
- connector = Qwen2Connector()
71
  connector_path = f"https://huggingface.co/{self.MODEL_ID}/resolve/main/qwen2-vl/connector.pt"
72
- connector_state = torch.hub.load_state_dict_from_url(connector_path, map_location=self.device)
 
 
73
  connector.load_state_dict(connector_state)
 
74
 
75
  # Load T5 embedder
76
- self.t5_context_embedder = nn.Linear(4096, 3072)
77
  t5_embedder_path = f"https://huggingface.co/{self.MODEL_ID}/resolve/main/qwen2-vl/t5_embedder.pt"
78
- t5_embedder_state = torch.hub.load_state_dict_from_url(t5_embedder_path, map_location=self.device)
 
 
79
  self.t5_context_embedder.load_state_dict(t5_embedder_state)
 
80
 
81
- # Move models to device and set dtype
82
- models = [text_encoder, text_encoder_two, vae, transformer, qwen2vl, connector, self.t5_context_embedder]
83
- for model in models:
84
- model.to(self.device).to(self.dtype)
85
  model.eval()
86
 
 
 
87
  self.models = {
88
  'tokenizer': tokenizer,
89
  'text_encoder': text_encoder,
@@ -110,7 +119,7 @@ class FluxInterface:
110
  vae=vae,
111
  text_encoder=text_encoder,
112
  tokenizer=tokenizer,
113
- )
114
 
115
  def resize_image(self, img, max_pixels=1050000):
116
  if not isinstance(img, Image.Image):
 
52
  if self.models is not None:
53
  return
54
 
55
+ logger.info("Starting model loading...")
56
+
57
  # Load FLUX components
58
  tokenizer = CLIPTokenizer.from_pretrained(self.MODEL_ID, subfolder="flux/tokenizer")
59
+ text_encoder = CLIPTextModel.from_pretrained(self.MODEL_ID, subfolder="flux/text_encoder").to(self.dtype).to(self.device)
60
+ text_encoder_two = T5EncoderModel.from_pretrained(self.MODEL_ID, subfolder="flux/text_encoder_2").to(self.dtype).to(self.device)
61
  tokenizer_two = T5TokenizerFast.from_pretrained(self.MODEL_ID, subfolder="flux/tokenizer_2")
62
 
63
+ # Load VAE and transformer
64
+ vae = AutoencoderKL.from_pretrained(self.MODEL_ID, subfolder="flux/vae").to(self.dtype).to(self.device)
65
+ transformer = FluxTransformer2DModel.from_pretrained(self.MODEL_ID, subfolder="flux/transformer").to(self.dtype).to(self.device)
66
  scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(self.MODEL_ID, subfolder="flux/scheduler", shift=1)
67
 
68
+ # Load Qwen2VL components
69
+ qwen2vl = Qwen2VLSimplifiedModel.from_pretrained(self.MODEL_ID, subfolder="qwen2-vl").to(self.dtype).to(self.device)
70
 
71
+ # Load connector
72
+ connector = Qwen2Connector().to(self.dtype).to(self.device)
73
  connector_path = f"https://huggingface.co/{self.MODEL_ID}/resolve/main/qwen2-vl/connector.pt"
74
+ connector_state = torch.hub.load_state_dict_from_url(connector_path, map_location='cpu')
75
+ # Move state dict to dtype before loading
76
+ connector_state = {k: v.to(self.dtype) for k, v in connector_state.items()}
77
  connector.load_state_dict(connector_state)
78
+ connector = connector.to(self.device)
79
 
80
  # Load T5 embedder
81
+ self.t5_context_embedder = nn.Linear(4096, 3072).to(self.dtype).to(self.device)
82
  t5_embedder_path = f"https://huggingface.co/{self.MODEL_ID}/resolve/main/qwen2-vl/t5_embedder.pt"
83
+ t5_embedder_state = torch.hub.load_state_dict_from_url(t5_embedder_path, map_location='cpu')
84
+ # Move state dict to dtype before loading
85
+ t5_embedder_state = {k: v.to(self.dtype) for k, v in t5_embedder_state.items()}
86
  self.t5_context_embedder.load_state_dict(t5_embedder_state)
87
+ self.t5_context_embedder = self.t5_context_embedder.to(self.device)
88
 
89
+ # Set models to eval mode
90
+ for model in [text_encoder, text_encoder_two, vae, transformer, qwen2vl, connector, self.t5_context_embedder]:
91
+ model.requires_grad_(False)
 
92
  model.eval()
93
 
94
+ logger.info("All models loaded successfully")
95
+
96
  self.models = {
97
  'tokenizer': tokenizer,
98
  'text_encoder': text_encoder,
 
119
  vae=vae,
120
  text_encoder=text_encoder,
121
  tokenizer=tokenizer,
122
+ )
123
 
124
  def resize_image(self, img, max_pixels=1050000):
125
  if not isinstance(img, Image.Image):