from dataclasses import dataclass
from pathlib import Path
import pathlib
from typing import Dict, Any, Optional, Tuple
import asyncio
import base64
import io
import pprint
import logging
import random
import traceback
import os
import numpy as np
import torch
from diffusers import LTXPipeline, LTXImageToVideoPipeline
from diffusers.hooks import apply_enhance_a_video, EnhanceAVideoConfig
from PIL import Image

from teacache import TeaCacheConfig, enable_teacache, disable_teacache
from varnish import Varnish
from varnish.utils import is_truthy, process_input_image

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# Get token from environment
hf_token = os.getenv("HF_API_TOKEN")

# Constraints
MAX_LARGE_SIDE = 1280
MAX_SMALL_SIDE = 768 # should be 720 but it must be divisible by 32
MAX_FRAMES = (8 * 21) + 1 # = 169; visual glitches appear after about 169 frames, so we cap it there

# Check environment variable for pipeline support
support_image_prompt = is_truthy(os.getenv("SUPPORT_INPUT_IMAGE_PROMPT"))

@dataclass
class GenerationConfig:
    """Configuration for video generation"""

    # general content settings
    prompt: str = ""
    negative_prompt: str = "saturated, highlight, overexposed, highlighted, overlit, shaking, too bright, worst quality, inconsistent motion, blurry, jittery, distorted, cropped, watermarked, watermark, logo, subtitle, subtitles, lowres"

    # video model settings (will be used during generation of the initial raw video clip)
    # we use small values to make things a bit faster
    width: int = 768
    height: int = 416


    # this is a hack to fool LTX-Video into believing our input image is an actual video frame with poor encoding quality
    # after a quick benchmark, the value 70 seems like a sweet spot
    input_image_quality: int = 70

    # users may tend to always set this to the max, to get as much usable content as possible (which is MAX_FRAMES, i.e. 169).
    # The value must be a multiple of 8, plus 1 frame.
    # visual glitches appear after about 169 frames, so we don't need more actually
    num_frames: int = (8 * 14) + 1

    # values between 3.0 and 4.0 are nice
    guidance_scale: float = 3.5
    
    num_inference_steps: int = 50

    # reproducible generation settings
    seed: int = -1  # -1 means random seed

    # varnish settings (will be used for post-processing after the raw video clip has been generated)
    fps: int = 30 # FPS of the final video (only applied at the very end, when converting to mp4)
    double_num_frames: bool = False # if True, the number of frames will be multiplied by 2 using RIFE
    super_resolution: bool = False # if True, the resolution will be multiplied by 2 using Real_ESRGAN
    
    grain_amount: float = 0.0 # be careful, adding film grain can negatively impact video compression

    # audio settings
    enable_audio: bool = False  # Whether to generate audio
    audio_prompt: str = ""  # Text prompt for audio generation
    audio_negative_prompt: str = "voices, voice, talking, speaking, speech" # Negative prompt for audio generation

    # The range of the CRF scale is 0–51, where:
    # 0 is lossless (for 8 bit only, for 10 bit use -qp 0)
    # 23 is the default
    # 51 is worst quality possible
    # A lower value generally leads to higher quality, and a subjectively sane range is 17–28.
    # Consider 17 or 18 to be visually lossless or nearly so;
    # it should look the same or nearly the same as the input but it isn't technically lossless.
    # The range is exponential, so increasing the CRF value +6 results in roughly half the bitrate / file size, while -6 leads to roughly twice the bitrate.
    quality: int = 18
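    # e.g. with the default of 18, raising the CRF to 24 roughly halves the output file size,
    # while lowering it to 12 roughly doubles it (following the +/-6 rule described above)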

    # TeaCache settings
    enable_teacache: bool = True
    teacache_threshold: float = 0.05 # values: 0 (original), 0.03 (1.6x speedup), 0.05 (2.1x speedup).

    # Enhance-A-Video settings
    enable_enhance_a_video: bool = True
    enhance_a_video_weight: float = 5.0

    # LoRA settings
    lora_model_name: str = ""  # HuggingFace repo ID or path to LoRA model
    lora_model_weight_file: str = ""  # Specific weight file to load from the LoRA model
    lora_model_trigger: str = ""  # Optional trigger word to prepend to the prompt
    
    def validate_and_adjust(self) -> 'GenerationConfig':
        """Validate and adjust parameters to meet constraints"""
        # First check if it's one of our explicitly allowed resolutions
        if not ((self.width == MAX_LARGE_SIDE and self.height == MAX_SMALL_SIDE) or 
                (self.width == MAX_SMALL_SIDE and self.height == MAX_LARGE_SIDE)):
            # For other resolutions, ensure total pixels don't exceed max
            MAX_TOTAL_PIXELS = MAX_SMALL_SIDE * MAX_LARGE_SIDE # 983,040 pixels (the nominal 720p budget would be 921,600 = 1280 * 720)
            
            # If total pixels exceed maximum, scale down proportionally
            total_pixels = self.width * self.height
            if total_pixels > MAX_TOTAL_PIXELS:
                scale = (MAX_TOTAL_PIXELS / total_pixels) ** 0.5
                self.width = max(128, min(MAX_LARGE_SIDE, round(self.width * scale / 32) * 32))
                self.height = max(128, min(MAX_LARGE_SIDE, round(self.height * scale / 32) * 32))
            else:
                # Round dimensions to nearest multiple of 32
                self.width = max(128, min(MAX_LARGE_SIDE, round(self.width / 32) * 32))
                self.height = max(128, min(MAX_LARGE_SIDE, round(self.height / 32) * 32))
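        # Worked example for the pixel-budget branch above (illustrative values):
        # a 1920x1080 request is 2,073,600 px > 983,040 px, so scale = sqrt(983040 / 2073600) ~= 0.689;
        # width becomes round(1920 * 0.689 / 32) * 32 = 1312, clamped to MAX_LARGE_SIDE = 1280,
        # and height becomes round(1080 * 0.689 / 32) * 32 = 736, i.e. the clip is rendered at 1280x736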
        
        # Adjust number of frames to be in format 8k + 1
        k = (self.num_frames - 1) // 8
        self.num_frames = min((k * 8) + 1, MAX_FRAMES)
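        # e.g. a request for 100 frames becomes ((100 - 1) // 8) * 8 + 1 = 97 frames,
        # and anything above 169 is capped to MAX_FRAMES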
    
        # Set random seed if not specified
        if self.seed == -1:
            self.seed = random.randint(0, 2**32 - 1)
    
        return self

class EndpointHandler:
    """Handles video generation requests using LTX models and Varnish post-processing"""
    
    def __init__(self, model_path: str = ""):
        """Initialize the handler with LTX models and Varnish

        Args:
            model_path: Path to LTX model weights
        """
        # Enable TF32 for potential speedup on Ampere GPUs
        #torch.backends.cuda.matmul.allow_tf32 = True

        if support_image_prompt:
            self.image_to_video = LTXImageToVideoPipeline.from_pretrained(
                model_path,
                torch_dtype=torch.bfloat16
            ).to("cuda")

        else:
            # Initialize models with bfloat16 precision
            self.text_to_video = LTXPipeline.from_pretrained(
                model_path,
                torch_dtype=torch.bfloat16
            ).to("cuda")

        # Initialize LoRA tracking
        self._current_lora_model = None

        #if support_image_prompt:
        #    # Enable CPU offload for memory efficiency
        #    self.image_to_video.enable_model_cpu_offload()
        #    # Inject enhance-a-video functionality
        #    inject_enhance_for_ltx(self.image_to_video.transformer)
        #else:
        #    # Enable CPU offload for memory efficiency
        #    self.text_to_video.enable_model_cpu_offload()
        #    # Inject enhance-a-video functionality
        #    inject_enhance_for_ltx(self.text_to_video.transformer)
         

        # Initialize Varnish for post-processing
        self.varnish = Varnish(
            device="cuda",
            model_base_dir="/repository/varnish",

            # there is currently a bug with MMAudio and/or torch and/or the weight format and/or version..
            # not sure how to fix that.. :/
            #
            # it says:
            #   File "dist-packages/varnish.py", line 152, in __init__
            #     self._setup_mmaudio()
            #   File "dist-packages/varnish/varnish.py", line 165, in _setup_mmaudio
            #     net.load_weights(torch.load(model.model_path, map_location=self.device, weights_only=False))
            #                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
            #   File "dist-packages/torch/serialization.py", line 1384, in load
            #     return _legacy_load(
            #            ^^^^^^^^^^^^^
            #   File "dist-packages/torch/serialization.py", line 1628, in _legacy_load
            #     magic_number = pickle_module.load(f, **pickle_load_args)
            #                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
            # _pickle.UnpicklingError: invalid load key, '<'.
            enable_mmaudio=True,
        )

        # Store TeaCache config for each model
        self.text_to_video_teacache = None
        self.image_to_video_teacache = None

    def _configure_teacache(self, model, config: GenerationConfig):
        """Configure TeaCache for a model based on generation config
        
        Args:
            model: The model to configure TeaCache for
            config: Generation configuration
        """
        if config.enable_teacache:
            # Create and enable TeaCache if it should be enabled
            teacache_config = TeaCacheConfig(
                enabled=True,
                rel_l1_thresh=config.teacache_threshold,
                num_inference_steps=config.num_inference_steps
            )
            enable_teacache(model.transformer.__class__, teacache_config)
            logger.info(f"TeaCache enabled with threshold {config.teacache_threshold}")
        else:
            # Disable TeaCache if it was previously enabled
            if hasattr(model.transformer.__class__, 'teacache_config'):
                disable_teacache(model.transformer.__class__)
                logger.info("TeaCache disabled")
    
    async def process_frames(
        self,
        frames: torch.Tensor,
        config: GenerationConfig
    ) -> tuple[str, dict]:
        """Post-process generated frames using Varnish
        
        Args:
            frames: Generated video frames tensor
            config: Generation configuration
            
        Returns:
            Tuple of (video data URI, metadata dictionary)
        """
        try:
            # Process video with Varnish
            result = await self.varnish(
                input_data=frames, # note: this might contain a certain number of frames eg. 97, which will get doubled if double_num_frames is True
                fps=config.fps, # this is the FPS of the final output video. This number can be used by Varnish to calculate the duration of a clip ((using frames * factor) / fps etc)
                double_num_frames=config.double_num_frames, # if True, the number of frames will be multiplied by 2 using RIFE
                super_resolution=config.super_resolution, # if True, the resolution will be multiplied by 2 using Real_ESRGAN
                grain_amount=config.grain_amount,
                enable_audio=config.enable_audio,
                audio_prompt=config.audio_prompt,
                audio_negative_prompt=config.audio_negative_prompt, 
            )
            
            # Convert to data URI
            video_uri = await result.write(type="data-uri", quality=config.quality)
            
            # Collect metadata
            metadata = {
                "width": result.metadata.width,
                "height": result.metadata.height,
                "num_frames": result.metadata.frame_count,
                "fps": result.metadata.fps,
                "duration": result.metadata.duration,
                "seed": config.seed,
            }
            
            return video_uri, metadata
    
        except Exception as e:
            logger.error(f"Error in process_frames: {str(e)}")
            raise RuntimeError(f"Failed to process frames: {str(e)}")


    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Process incoming requests for video generation
        
        Args:
            data: Request data containing:
                - inputs (dict): Dictionary containing input, which can be either "prompt" (text field) or "image" (input image)
                - parameters (dict):
                    - prompt (required, string): list of concepts to keep in the video.
                    - negative_prompt (optional, string): list of concepts to ignore in the video.
                    - width (optional, int, default to 768): width, or horizontal size in pixels.
                    - height (optional, int, default to 416): height, or vertical size in pixels.
                    - input_image_quality (optional, int, default to 70): this is a trick we use to convert a "pristine" image into a "dirty" video frame. This helps fool LTX-Video into turning the image into an animated one.
                    - num_frames (optional, int, default to 113): the number of frames must be a multiple of 8, plus 1 frame.
                    - guidance_scale (optional, float, default to 3.5): Guidance scale (values between 3.0 and 4.0 are nice)
                    - num_inference_steps (optional, int, default to 50): number of inference steps
                    - seed (optional, int, default to -1): set a random number generator seed, -1 means random seed.
                    - fps (optional, int, default to 30): FPS of the final video (eg. 24, 25, 30, 60)
                    - double_num_frames (optional, bool): if enabled, the number of frames will be multiplied by 2 using RIFE
                    - super_resolution (optional, bool): if enabled, the resolution will be multiplied by 2 using Real_ESRGAN
                    - grain_amount (optional, float): amount of film grain to add to the output video
                    - enable_audio (optional, bool): automatically generate an audio track
                    - audio_prompt (optional, str): prompt to use for the audio generation (concepts to add)
                    - audio_negative_prompt (optional, str): negative prompt to use for the audio generation (concepts to ignore)
                    - quality (optional, int, default to 18): The range of the CRF scale is 0–51, where 0 is lossless (for 8 bit only, for 10 bit use -qp 0), 23 is the default, and 51 is worst quality possible.
                    - enable_teacache (optional, bool, default to True): Generate faster at the cost of a slight quality loss
                    - teacache_threshold (optional, float, default to 0.05): Amount of cache, 0 (original), 0.03 (1.6x speedup), 0.05 (Default, 2.1x speedup).
                    - enable_enhance_a_video (optional, bool, default to True): enable the enhance_a_video optimization
                    - enhance_a_video_weight (optional, float, default to 5.0): amount of video enhancement to apply
                    - lora_model_name (optional, str, default to ""): HuggingFace repo ID or path to LoRA model
                    - lora_model_weight_file (optional, str, default to ""): Specific weight file to load from the LoRA model
                    - lora_model_trigger (optional, str, default to ""): Optional trigger word to prepend to the prompt
        Returns:
            Dictionary containing:
                - video: Base64 encoded MP4 data URI
                - content-type: MIME type
                - metadata: Generation metadata
        """
        inputs = data.get("inputs", dict())
        
        input_prompt = inputs.get("prompt", "")
        input_image = inputs.get("image")
        
        params = data.get("parameters", dict())

        if not input_image and not input_prompt:
            raise ValueError("Either prompt or image must be provided")

        #logger.debug(f"Raw parameters:")
        # pprint.pprint(params)

        # Create and validate configuration
        config = GenerationConfig(
            # general content settings
            prompt=input_prompt,
            negative_prompt=params.get("negative_prompt", GenerationConfig.negative_prompt),

            # video model settings (will be used during generation of the initial raw video clip)
            width=params.get("width", GenerationConfig.width),
            height=params.get("height", GenerationConfig.height),
            input_image_quality=params.get("input_image_quality", GenerationConfig.input_image_quality),
            num_frames=params.get("num_frames", GenerationConfig.num_frames),
            guidance_scale=params.get("guidance_scale", GenerationConfig.guidance_scale),
            num_inference_steps=params.get("num_inference_steps", GenerationConfig.num_inference_steps),

            # reproducible generation settings
            seed=params.get("seed", GenerationConfig.seed),
            
            # varnish settings (will be used for post-processing after the raw video clip has been generated)
            fps=params.get("fps", GenerationConfig.fps), # FPS of the final video (only applied at the the very end, when converting to mp4)
            double_num_frames=params.get("double_num_frames", GenerationConfig.double_num_frames), # if True, the number of frames will be multiplied by 2 using RIFE
            super_resolution=params.get("super_resolution", GenerationConfig.super_resolution), # if True, the resolution will be multiplied by 2 using Real_ESRGAN
            grain_amount=params.get("grain_amount", GenerationConfig.grain_amount),
            enable_audio=params.get("enable_audio", GenerationConfig.enable_audio),
            audio_prompt=params.get("audio_prompt", GenerationConfig.audio_prompt),
            audio_negative_prompt=params.get("audio_negative_prompt", GenerationConfig.audio_negative_prompt),
            quality=params.get("quality", GenerationConfig.quality),
            
            # TeaCache settings
            enable_teacache=params.get("enable_teacache", True),

            # values: 0 (original), 0.03 (1.6x speedup), 0.05 (2.1x speedup).
            teacache_threshold=params.get("teacache_threshold", 0.05),

            
            # Add enhance-a-video settings
            enable_enhance_a_video=params.get("enable_enhance_a_video", True),
            enhance_a_video_weight=params.get("enhance_a_video_weight", 5.0),

            # LoRA settings
            lora_model_name=params.get("lora_model_name", ""),
            lora_model_weight_file=params.get("lora_model_weight_file", ""),
            lora_model_trigger=params.get("lora_model_trigger", ""),
        ).validate_and_adjust()
        
        #logger.debug(f"Global request settings:")
        #pprint.pprint(config)

        try:
            with torch.inference_mode():
                # Set random seeds
                random.seed(config.seed)
                np.random.seed(config.seed)
                torch.manual_seed(config.seed)
                generator = torch.Generator(device='cuda')
                generator = generator.manual_seed(config.seed)

                # Configure enhance-a-video
                #if config.enable_enhance_a_video:
                #    enable_enhance()
                #    set_enhance_weight(config.enhance_a_video_weight)
                
                # Prepare generation parameters for the video model (we omit params that are destined to Varnish, or things like the seed which is set externally)
                generation_kwargs = {
                    # general content settings
                    "prompt": config.prompt,
                    "negative_prompt": config.negative_prompt,
        
                    # video model settings (will be used during generation of the initial raw video clip)
                    "width": config.width,
                    "height": config.height,
                    "num_frames": config.num_frames,
                    "guidance_scale": config.guidance_scale,
                    "num_inference_steps": config.num_inference_steps,
 
                    # constants
                    "output_type": "pt",
                    "generator": generator,

                    # Timestep for decoding VAE noise: the timestep at which generated video is decoded
                    "decode_timestep": 0.05,
                    
                    # Noise level for decoding VAE noise: the interpolation factor between random noise and denoised latents at the decode timestep
                    "decode_noise_scale": 0.025,
                }
                #logger.info(f"Video model generation settings:")
                #pprint.pprint(generation_kwargs)

                # Handle LoRA loading/unloading
                if hasattr(self, '_current_lora_model'):
                    if self._current_lora_model != (config.lora_model_name, config.lora_model_weight_file):
                        # Unload previous LoRA if it exists and is different
                        if hasattr(self, 'text_to_video') and hasattr(self.text_to_video, 'unload_lora_weights'):
                            self.text_to_video.unload_lora_weights()
                        
                        if support_image_prompt and hasattr(self.image_to_video, 'unload_lora_weights'):
                            self.image_to_video.unload_lora_weights()
    
                if config.lora_model_name:
                    # Load new LoRA
                    if hasattr(self, 'text_to_video') and hasattr(self.text_to_video, 'load_lora_weights'):
                        self.text_to_video.load_lora_weights(
                            config.lora_model_name,
                            weight_name=config.lora_model_weight_file if config.lora_model_weight_file else None,
                            token=hf_token,
                        )
                    if support_image_prompt and hasattr(self.image_to_video, 'load_lora_weights'):
                        self.image_to_video.load_lora_weights(
                            config.lora_model_name,
                            weight_name=config.lora_model_weight_file if config.lora_model_weight_file else None,
                            token=hf_token,
                        )
                    self._current_lora_model = (config.lora_model_name, config.lora_model_weight_file)
    
                # Modify prompt if trigger word is provided
                if config.lora_model_trigger:
                    generation_kwargs["prompt"] = f"{config.lora_model_trigger} {generation_kwargs['prompt']}"

                enhance_a_video_config = EnhanceAVideoConfig(
                    weight=config.enhance_a_video_weight if config.enable_enhance_a_video else 0.0,
                    # doing some testing
                    num_frames_callback=lambda: (8 + 1),
                    # num_frames_callback=lambda: config.num_frames,
                    # num_frames_callback=lambda: (config.num_frames - 1),
                    
                    _attention_type=1
                )
                
                # Check if image-to-video generation is requested
                if support_image_prompt and input_image:
                    self._configure_teacache(self.image_to_video, config)
                    processed_image = process_input_image(
                        input_image,
                        config.width,
                        config.height,
                        config.input_image_quality,
                    )
                    generation_kwargs["image"] = processed_image
                    # disabled (we cannot install the hook multiple times, we would have to uninstall it first or find another way to dynamically enable it, eg. using the weight only)
                    # apply_enhance_a_video(self.image_to_video.transformer, enhance_a_video_config)
                    frames = self.image_to_video(**generation_kwargs).frames
                else:
                    self._configure_teacache(self.text_to_video, config)
                    # disabled (we cannot install the hook multiple times, we would have to uninstall it first or find another way to dynamically enable it, eg. using the weight only)
                    # apply_enhance_a_video(self.text_to_video.transformer, enhance_a_video_config)
                    frames = self.text_to_video(**generation_kwargs).frames

                try:
                    loop = asyncio.get_event_loop()
                except RuntimeError:
                    loop = asyncio.new_event_loop()
                    asyncio.set_event_loop(loop)
                
                video_uri, metadata = loop.run_until_complete(self.process_frames(frames, config))
                
                return {
                    "video": video_uri,
                    "content-type": "video/mp4",
                    "metadata": metadata
                }

        except Exception as e:
            message = f"Error generating video ({str(e)})\n{traceback.format_exc()}"
            print(message)
            raise RuntimeError(message)
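

# --- Minimal local usage sketch (not part of the endpoint runtime) ---
# The Inference Endpoints runtime normally instantiates EndpointHandler once and calls it per
# request, so the block below is only an illustration of the expected payload shape. It assumes
# a CUDA device is available and that the LTX weights live at the hypothetical path "/repository".
if __name__ == "__main__":
    handler = EndpointHandler(model_path="/repository")
    result = handler({
        "inputs": {"prompt": "a timelapse of clouds rolling over a mountain ridge"},
        "parameters": {
            "width": 768,
            "height": 416,
            "num_frames": (8 * 14) + 1,
            "num_inference_steps": 30,
            "fps": 30,
            "seed": 42,
        },
    })
    # The response contains a base64 "data:video/mp4" URI plus generation metadata.
    print(result["content-type"])
    print(result["metadata"])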