from dataclasses import dataclass
from typing import Dict, Any, Optional
import base64
import asyncio
import logging
import random
import traceback
import torch
import os
import gc

# note: there is no HunyuanImageToVideoPipeline yet in Diffusers
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel, FasterCacheConfig
from diffusers.hooks import apply_enhance_a_video, EnhanceAVideoConfig
from varnish import Varnish
from varnish.utils import is_truthy, process_input_image

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Check environment variable for pipeline support
support_image_prompt = is_truthy(os.getenv("SUPPORT_INPUT_IMAGE_PROMPT"))
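
# Hugging Face Hub token used when loading gated or private LoRA weights below.
# The environment variable names checked here are assumptions; adjust them to
# match your deployment.
hf_token = os.getenv("HF_API_TOKEN") or os.getenv("HF_TOKEN")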

@dataclass
class GenerationConfig:
    """Configuration for video generation"""
    # Content settings
    prompt: str
    negative_prompt: str = ""
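
    # Input image settings (image-to-video mode only). This field is read later
    # as config.input_image_quality; the default value is an assumption.
    input_image_quality: int = 100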
    
    # Model settings
    num_frames: int = 49  # Should be 4k + 1 format
    height: int = 320
    width: int = 576
    num_inference_steps: int = 50
    guidance_scale: float = 7.0
    
    # Reproducibility
    seed: int = -1
    
    # Varnish post-processing settings
    fps: int = 30
    double_num_frames: bool = False
    super_resolution: bool = False
    grain_amount: float = 0.0
    quality: int = 18  # CRF scale (0-51, lower is better)
    
    # Audio settings
    enable_audio: bool = False
    audio_prompt: str = ""
    audio_negative_prompt: str = "voices, voice, talking, speaking, speech"

    # TeaCache settings
    enable_teacache: bool = False
    teacache_threshold: float = 0.15 # values: 0 (original), 0.1 (1.6x speedup), 0.15 (2.1x speedup)

    
    # Enhance-A-Video settings
    enable_enhance_a_video: bool = False
    enhance_a_video_weight: float = 5.0

    # LoRA settings
    lora_model_name: str = ""  # HuggingFace repo ID or path to LoRA model
    lora_model_weight_file: str = ""  # Specific weight file to load from the LoRA model
    lora_model_trigger: str = ""  # Optional trigger word to prepend to the prompt

    def validate_and_adjust(self) -> 'GenerationConfig':
        """Validate and adjust parameters"""
        # Ensure num_frames follows 4k + 1 format
        k = (self.num_frames - 1) // 4
        self.num_frames = (k * 4) + 1
        
        # Set random seed if not specified
        if self.seed == -1:
            self.seed = random.randint(0, 2**32 - 1)
            
        return self
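
    # Illustrative example (comments only):
    #   GenerationConfig(prompt="a cat", num_frames=50, seed=-1).validate_and_adjust()
    #   snaps num_frames down to 49 (= 4*12 + 1) and replaces seed=-1 with a
    #   random 32-bit seed.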

class EndpointHandler:
    """Handles video generation requests using HunyuanVideo and Varnish"""
    
    def __init__(self, path: str = ""):
        """Initialize handler with models
        
        Args:
            path: Path to model weights
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

          
        # Initialize transformer with Enhance-A-Video injection first
        transformer = HunyuanVideoTransformer3DModel.from_pretrained(
            path,
            subfolder="transformer",
            torch_dtype=torch.bfloat16
        )
        
        if support_image_prompt:
            raise Exception("Please use a version of Diffusers that supports HunyuanImageToVideoPipeline")
            # # Initialize image-to-video pipeline
            # self.image_to_video = HunyuanImageToVideoPipeline.from_pretrained(
            #     path,
            #     transformer=transformer,
            #     torch_dtype=torch.float16,
            # ).to(self.device)
            #
            # # Initialize components in appropriate precision
            # self.image_to_video.text_encoder = self.image_to_video.text_encoder.half()
            # self.image_to_video.text_encoder_2 = self.image_to_video.text_encoder_2.half()
            # self.image_to_video.transformer = self.image_to_video.transformer.to(torch.bfloat16)
            # self.image_to_video.vae = self.image_to_video.vae.half()
                        
            # apply_enhance_a_video(self.image_to_video.transformer, EnhanceAVideoConfig(
            #     weight=config.enhance_a_video_weight if config.enable_enhance_a_video else 0.0,
            #     num_frames_callback=lambda: (config.num_frames - 1),
            #     _attention_type=1
            # ))
        else:
            # Initialize text-to-video pipeline
            self.text_to_video = HunyuanVideoPipeline.from_pretrained(
                path,
                transformer=transformer,
                torch_dtype=torch.float16,
            ).to(self.device)

            # Initialize components in appropriate precision
            self.text_to_video.text_encoder = self.text_to_video.text_encoder.half()
            self.text_to_video.text_encoder_2 = self.text_to_video.text_encoder_2.half()
            self.text_to_video.transformer = self.text_to_video.transformer.to(torch.bfloat16)
            self.text_to_video.vae = self.text_to_video.vae.half()


            # apply_enhance_a_video(self.text_to_video.transformer, EnhanceAVideoConfig(
            #     # weight=config.enhance_a_video_weight if config.enable_enhance_a_video else 0.0,
            #     weight=config.enhance_a_video_weight,
            #     num_frames_callback=lambda: (config.num_frames - 1),
            #     _attention_type=1
            # ))
            
            # FasterCache configuration. Note that the enable_cache() call further
            # below is commented out, so this config is prepared but not applied.
            # The values come from:
            # https://github.com/huggingface/diffusers/pull/10163/files#diff-777f4ee62cb325371233a450e0f6cc0ba357a3fade2ec2dea912260b4f8d08ceR67-R74
            faster_cache_config = FasterCacheConfig(
                current_timestep_callback=lambda: self.text_to_video.current_timestep,
                spatial_attention_block_skip_range=2,
                spatial_attention_timestep_skip_range=(-1, 901),
                unconditional_batch_skip_range=2,
                attention_weight_callback=lambda _: 0.5,
                is_guidance_distilled=True,

                # do we need to uncomment those?
                #unconditional_batch_timestep_skip_range=(-1, 901),
                #tensor_format="BFCHW",
            )
            
            #self.text_to_video.transformer.enable_cache(faster_cache_config)
 

        # Initialize LoRA tracking
        self._current_lora_model = None

        # Initialize Varnish for post-processing
        self.varnish = Varnish(
            device=self.device,
            model_base_dir="/repository/varnish"
        )

    async def process_frames(
        self,
        frames: torch.Tensor,
        config: GenerationConfig
    ) -> tuple[str, dict]:
        """Post-process generated frames using Varnish
        
        Args:
            frames: Generated video frames tensor
            config: Generation configuration
            
        Returns:
            Tuple of (video data URI, metadata dictionary)
        """
        try:
            # Process video with Varnish
            result = await self.varnish(
                input_data=frames,
                fps=config.fps,
                double_num_frames=config.double_num_frames,
                super_resolution=config.super_resolution,
                grain_amount=config.grain_amount,
                enable_audio=config.enable_audio,
                audio_prompt=config.audio_prompt,
                audio_negative_prompt=config.audio_negative_prompt
            )
            
            # Convert to data URI
            video_uri = await result.write(type="data-uri", quality=config.quality)
            
            # Collect metadata
            metadata = {
                "width": result.metadata.width,
                "height": result.metadata.height,
                "num_frames": result.metadata.frame_count,
                "fps": result.metadata.fps,
                "duration": result.metadata.duration,
                "seed": config.seed,
                "enable_teacache": config.enable_teacache,
                "teacache_threshold": config.teacache_threshold if config.enable_teacache else 0,
                "enable_enhance_a_video": config.enable_enhance_a_video,
                "enhance_a_video_weight": config.enhance_a_video_weight if config.enable_enhance_a_video else 0,
            }
            
            return video_uri, metadata
            
        except Exception as e:
            logger.error(f"Error in process_frames: {str(e)}")
            raise RuntimeError(f"Failed to process frames: {str(e)}")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Process video generation requests
        
        Args:
            data: Request data containing:
                - inputs (str): Prompt for video generation
                - parameters (dict): Generation parameters
                
        Returns:
            Dictionary containing:
                - video: Base64 encoded MP4 data URI
                - content-type: MIME type
                - metadata: Generation metadata
        """
        # Extract inputs
        inputs = data.pop("inputs", data)
        if isinstance(inputs, dict):
            prompt = inputs.get("prompt", "")
        else:
            prompt = inputs
            
        params = data.get("parameters", {})
        
        # Create and validate config
        config = GenerationConfig(
            prompt=prompt,
            negative_prompt=params.get("negative_prompt", ""),
            num_frames=params.get("num_frames", 49),
            height=params.get("height", 320),
            width=params.get("width", 576),
            num_inference_steps=params.get("num_inference_steps", 50),
            guidance_scale=params.get("guidance_scale", 7.0),
            seed=params.get("seed", -1),
            fps=params.get("fps", 30),
            double_num_frames=params.get("double_num_frames", False),
            super_resolution=params.get("super_resolution", False),
            grain_amount=params.get("grain_amount", 0.0),
            quality=params.get("quality", 18),
            enable_audio=params.get("enable_audio", False),
            audio_prompt=params.get("audio_prompt", ""),
            audio_negative_prompt=params.get("audio_negative_prompt", "voices, voice, talking, speaking, speech"),
            enable_teacache=params.get("enable_teacache", False),

            # values: 0 (original), 0.1 (1.6x speedup), 0.15 (2.1x speedup).
            teacache_threshold=params.get("teacache_threshold", 0.15),
            
            enable_enhance_a_video=params.get("enable_enhance_a_video", False),
            enhance_a_video_weight=params.get("enhance_a_video_weight", 5.0),

            lora_model_name=params.get("lora_model_name", ""),
            lora_model_weight_file=params.get("lora_model_weight_file", ""),
            lora_model_trigger=params.get("lora_model_trigger", ""),
        ).validate_and_adjust()

        try:
            # Set random seeds
            if config.seed != -1:
                torch.manual_seed(config.seed)
                random.seed(config.seed)
                generator = torch.Generator(device=self.device).manual_seed(config.seed)
            else:
                generator = None

            # TeaCache is not wired up in this handler; the commented block below
            # is a placeholder showing how it would be toggled per request.
            #if config.enable_teacache:
            #    enable_teacache(
            #        self.pipeline.transformer,
            #        num_inference_steps=config.num_inference_steps,
            #        rel_l1_thresh=config.teacache_threshold
            #    )
            #else:
            #    disable_teacache(self.pipeline.transformer)

            with torch.amp.autocast_mode.autocast('cuda', torch.bfloat16), torch.no_grad(), torch.inference_mode():

                # Prepare generation parameters
                generation_kwargs = {
                    "prompt": config.prompt,

                    
                    # Failed to generate video: HunyuanVideoPipeline.__call__() got an unexpected keyword argument 'negative_prompt'
                    #"negative_prompt": config.negative_prompt,
                    
                    "num_frames": config.num_frames,
                    "height": config.height,
                    "width": config.width,
                    "num_inference_steps": config.num_inference_steps,
                    "guidance_scale": config.guidance_scale,
                    "generator": generator,
                    "output_type": "pt",
                }

                # Handle LoRA loading/unloading
                if hasattr(self, '_current_lora_model'):
                    if self._current_lora_model != (config.lora_model_name, config.lora_model_weight_file):
                        # Unload previous LoRA if it exists and is different
                        if support_image_prompt and hasattr(self.image_to_video, 'unload_lora_weights'):
                            self.image_to_video.unload_lora_weights()
                        else:
                            if hasattr(self.text_to_video, 'unload_lora_weights'):
                                self.text_to_video.unload_lora_weights()

                if config.lora_model_name:
                    # Load new LoRA
                    if support_image_prompt and hasattr(self.image_to_video, 'load_lora_weights'):
                        self.image_to_video.load_lora_weights(
                            config.lora_model_name,
                            weight_name=config.lora_model_weight_file if config.lora_model_weight_file else None,
                            token=hf_token,
                        )
                    else:
                        if hasattr(self.text_to_video, 'load_lora_weights'):
                            self.text_to_video.load_lora_weights(
                                config.lora_model_name,
                                weight_name=config.lora_model_weight_file if config.lora_model_weight_file else None,
                                token=hf_token,
                            )
                    self._current_lora_model = (config.lora_model_name, config.lora_model_weight_file)

                # Modify prompt if trigger word is provided
                if config.lora_model_trigger:
                    generation_kwargs["prompt"] = f"{config.lora_model_trigger} {generation_kwargs['prompt']}"

                # Check if image-to-video generation is requested
                if support_image_prompt and input_image:
                    processed_image = process_input_image(
                        input_image,
                        config.width,
                        config.height,
                        config.input_image_quality,
                    )
                    generation_kwargs["image"] = processed_image
                    frames = self.image_to_video(**generation_kwargs).frames
                else:
                    frames = self.text_to_video(**generation_kwargs).frames

                
                try:
                    loop = asyncio.get_event_loop()
                except RuntimeError:
                    loop = asyncio.new_event_loop()
                    asyncio.set_event_loop(loop)
                
                video_uri, metadata = loop.run_until_complete(self.process_frames(frames, config))
                
                torch.cuda.empty_cache()
                torch.cuda.reset_peak_memory_stats()
                gc.collect()
                
                return {
                    "video": video_uri,
                    "content-type": "video/mp4",
                    "metadata": metadata
                }
                
        except Exception as e:
            message = f"Error generating video ({str(e)})\n{traceback.format_exc()}"
            logger.error(message)
            raise RuntimeError(message)
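
# ---------------------------------------------------------------------------
# Minimal local smoke test (a sketch, not part of the deployed handler). It
# assumes the HunyuanVideo weights live under /repository and that a CUDA GPU
# is available; the payload mirrors the request format documented in
# EndpointHandler.__call__.
if __name__ == "__main__":
    handler = EndpointHandler("/repository")
    example_request = {
        "inputs": {"prompt": "a cat walking on a beach at sunset"},
        "parameters": {
            "num_frames": 49,
            "height": 320,
            "width": 576,
            "num_inference_steps": 30,
            "guidance_scale": 7.0,
            "seed": 42,
            "fps": 24,
        },
    }
    response = handler(example_request)
    print(response["content-type"])
    print(response["metadata"])
    print(f"data URI length: {len(response['video'])} characters")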