tight-inversion committed on
Commit
e818c21
·
1 Parent(s): 9499d13

Move models to GPU inside generate_image

Browse files
Files changed (1) hide show
  1. app.py +11 -16
app.py CHANGED
@@ -20,21 +20,11 @@ from pulid.pipeline_flux import PuLIDPipeline
20
  from pulid.utils import resize_numpy_image_long, seed_everything
21
 
22
 
23
- # def get_models(name: str, device: torch.device, offload: bool, fp8: bool):
24
- # t5 = load_t5(device, max_length=128)
25
- # clip = load_clip(device)
26
- # model = load_flow_model(name, device="cpu" if offload else device)
27
- # model.eval()
28
- # ae = load_ae(name, device=device)
29
- # return model, ae, t5, clip
30
- class Tmp:
31
- def __init__(self):
32
- self.max_length = 128
33
-
34
  def get_models(name: str, device: torch.device, offload: bool, fp8: bool):
35
- t5 = Tmp()
36
- clip = None
37
- model = None
 
38
  ae = load_ae(name, device=device)
39
  return model, ae, t5, clip
40
 
@@ -50,8 +40,8 @@ class FluxGenerator:
50
  offload=self.offload,
51
  fp8=args.fp8,
52
  )
53
- # self.pulid_model = PuLIDPipeline(self.model, device='cuda', weight_dtype=torch.bfloat16)
54
- # self.pulid_model.load_pretrain(args.pretrained_model)
55
 
56
  @spaces.GPU(duration=30)
57
  @torch.inference_mode()
@@ -82,6 +72,10 @@ class FluxGenerator:
82
  """
83
  Core function that performs the image generation.
84
  """
 
 
 
 
85
  self.t5.max_length = max_sequence_length
86
 
87
  # If seed == -1, random
@@ -266,6 +260,7 @@ class FluxGenerator:
266
  )
267
 
268
  # Offload flux model, load auto-decoder
 
269
  if self.offload:
270
  self.model.cpu()
271
  torch.cuda.empty_cache()
 
20
  from pulid.utils import resize_numpy_image_long, seed_everything
21
 
22
 
 
 
 
 
 
 
 
 
 
 
 
23
  def get_models(name: str, device: torch.device, offload: bool, fp8: bool):
24
+ t5 = load_t5(device, max_length=128)
25
+ clip = load_clip(device)
26
+ model = load_flow_model(name, device="cpu" if offload else device)
27
+ model.eval()
28
  ae = load_ae(name, device=device)
29
  return model, ae, t5, clip
30
 
 
40
  offload=self.offload,
41
  fp8=args.fp8,
42
  )
43
+ self.pulid_model = PuLIDPipeline(self.model, device='cuda', weight_dtype=torch.bfloat16)
44
+ self.pulid_model.load_pretrain(args.pretrained_model)
45
 
46
  @spaces.GPU(duration=30)
47
  @torch.inference_mode()
 
72
  """
73
  Core function that performs the image generation.
74
  """
75
+ self.t5.to(self.device)
76
+ self.clip_model.to(self.device)
77
+ self.ae.to(self.device)
78
+ self.model.to(self.device)
79
  self.t5.max_length = max_sequence_length
80
 
81
  # If seed == -1, random
 
260
  )
261
 
262
  # Offload flux model, load auto-decoder
263
+ self.ae.decoder.to(self.device)
264
  if self.offload:
265
  self.model.cpu()
266
  torch.cuda.empty_cache()