jbilcke-hf HF staff commited on
Commit
baee908
·
verified ·
1 Parent(s): f510634

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +17 -6
handler.py CHANGED
@@ -75,13 +75,29 @@ class EndpointHandler:
75
  path: Path to model weights
76
  """
77
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
 
 
 
 
 
78
 
79
- # Initialize HunyuanVideo pipeline
80
  self.pipeline = HunyuanVideoPipeline.from_pretrained(
81
  path,
 
82
  torch_dtype=torch.float16,
83
  ).to(self.device)
84
 
 
 
 
 
 
 
85
  # Initialize text encoders in float16
86
  self.pipeline.text_encoder = self.pipeline.text_encoder.half()
87
  self.pipeline.text_encoder_2 = self.pipeline.text_encoder_2.half()
@@ -92,11 +108,6 @@ class EndpointHandler:
92
  # Initialize VAE in float16
93
  self.pipeline.vae = self.pipeline.vae.half()
94
 
95
- # Initialize Enhance-A-Video
96
- inject_enhance_for_hunyuanvideo(self.pipeline.transformer)
97
- set_enhance_weight(4.0) # default weight
98
- enable_enhance()
99
-
100
  # Initialize Varnish for post-processing
101
  self.varnish = Varnish(
102
  device=self.device,
 
75
  path: Path to model weights
76
  """
77
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
78
+
79
+
80
+ # Initialize transformer with Enhance-A-Video injection first
81
+ transformer = HunyuanVideoTransformer3DModel.from_pretrained(
82
+ path,
83
+ subfolder="transformer",
84
+ torch_dtype=torch.bfloat16
85
+ )
86
+ inject_enhance_for_hunyuanvideo(transformer)
87
 
88
+ # Initialize HunyuanVideo pipeline with the enhanced transformer
89
  self.pipeline = HunyuanVideoPipeline.from_pretrained(
90
  path,
91
+ transformer=transformer,
92
  torch_dtype=torch.float16,
93
  ).to(self.device)
94
 
95
+ # Initialize HunyuanVideo pipeline
96
+ self.pipeline = HunyuanVideoPipeline.from_pretrained(
97
+ path,
98
+ torch_dtype=torch.float16,
99
+ ).to(self.device)
100
+
101
  # Initialize text encoders in float16
102
  self.pipeline.text_encoder = self.pipeline.text_encoder.half()
103
  self.pipeline.text_encoder_2 = self.pipeline.text_encoder_2.half()
 
108
  # Initialize VAE in float16
109
  self.pipeline.vae = self.pipeline.vae.half()
110
 
 
 
 
 
 
111
  # Initialize Varnish for post-processing
112
  self.varnish = Varnish(
113
  device=self.device,