import spaces
import torch
import os
from diffusers import AutoencoderKLLTXVideo, LTXImageToVideoPipeline, LTXVideoTransformer3DModel
from transformers import T5EncoderModel
from diffusers.utils import export_to_video #, load_image #, PIL_INTERPOLATION

import gradio as gr
import numpy as np
import random
from PIL import Image
# import imageio.v3

# Numerical-precision policy: disable every reduced-precision shortcut
# (TF32 and bf16/fp16 reduced-precision reductions) so matmuls and
# convolutions run at full precision.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
torch.backends.cudnn.allow_tf32 = False
# Let cuDNN auto-tune and pick the fastest kernels; bit-exact determinism
# is explicitly not requested.
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True
#torch.backends.cuda.preferred_blas_library="cublas"
#torch.backends.cuda.preferred_linalg_library="cusolver"
torch.set_float32_matmul_precision("highest")

# Assign through os.environ rather than os.putenv(): putenv() does not
# update os.environ, so code in this process that reads os.environ would
# never see the flag.
# NOTE(review): huggingface_hub may read this flag when it is first imported;
# if so it has to be set before the diffusers/transformers imports to take
# effect -- confirm.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
# Optional Hugging Face access token; None when the variable is not set.
HF_TOKEN = os.getenv("HF_TOKEN")
os.environ["SAFETENSORS_FAST_GPU"] = "1"
# Upper bound (inclusive) for randomly drawn generator seeds.
MAX_SEED = np.iinfo(np.int64).max

# Single-file checkpoint holding the v0.9.1 transformer weights.
single_file_url = "https://huggingface.co/Lightricks/LTX-Video/ltx-video-2b-v0.9.1.safetensors"

# Build the pipeline without its transformer and text encoder; both are
# loaded separately below and attached inside generate_video().
pipe = LTXImageToVideoPipeline.from_pretrained(
    "Lightricks/LTX-Video",
    token=HF_TOKEN,
    transformer=None,
    text_encoder=None,
).to(torch.device("cuda"),torch.bfloat16)

# Use the same HF_TOKEN as the other loads. `token=True` requires a stored
# token to be present, while HF_TOKEN may legitimately be None (anonymous
# access to the public repo).
text_encoder = T5EncoderModel.from_pretrained(
    "Lightricks/LTX-Video",
    subfolder='text_encoder',
    token=HF_TOKEN,
).to(torch.device("cuda"),torch.bfloat16)
transformer = LTXVideoTransformer3DModel.from_single_file(
    single_file_url,
    token=HF_TOKEN,
).to(torch.device("cuda"),torch.bfloat16)

@spaces.GPU(duration=80)
def generate_video(
    image_url,
    prompt,
    negative_prompt,
    width,
    height,
    num_frames,
    guidance_scale,
    num_inference_steps,
    fps,
    progress=gr.Progress(track_tqdm=True)
):
    """Generate a video from a conditioning image and a text prompt.

    Args:
        image_url: Filesystem path of the input image (opened with PIL).
        prompt: Text prompt passed to the pipeline.
        negative_prompt: Negative text prompt passed to the pipeline.
        width: Output width in pixels.
        height: Output height in pixels.
        num_frames: Number of frames to generate.
        guidance_scale: Guidance scale passed to the pipeline.
        num_inference_steps: Number of inference steps.
        fps: Frame rate, used both as the pipeline's ``frame_rate`` and for
            the exported video file.
        progress: Gradio progress tracker (tracks tqdm progress bars).

    Returns:
        The path of the written video file, ``"output.mp4"``.
    """
    # Defensive: coerce slider values to int in case the UI delivers floats.
    width = int(width)
    height = int(height)
    num_frames = int(num_frames)
    num_inference_steps = int(num_inference_steps)
    fps = int(fps)

    # The pipeline is created without these components; attach them here.
    pipe.text_encoder=text_encoder
    pipe.transformer=transformer

    # A fresh random seed is drawn for every call.
    seed=random.randint(0, MAX_SEED)
    generator = torch.Generator(device="cuda").manual_seed(seed)

    image = Image.open(image_url).convert("RGB")
    # PIL's resize() returns a new image (it does not modify in place) and
    # takes the size as (width, height).
    image = image.resize((width, height), Image.LANCZOS)

    video = pipe(
        image=image,
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=width,
        height=height,
        num_frames=num_frames,
        frame_rate=fps,
        guidance_scale=guidance_scale,
        generator=generator,
        num_inference_steps=num_inference_steps,
        output_type='pt',
        max_sequence_length=512,
    ).frames
    # Take the first (only) video of the batch.
    video = video[0]
    # Move the channel axis (dim 1) to the end and convert to a float32
    # NumPy array on the CPU for export.
    video = video.permute(0, 2, 3, 1).cpu().detach().to(torch.float32).numpy()
    export_to_video(video, "output.mp4", fps=fps)
    return "output.mp4"

# Gradio UI: the inputs are listed in the same order as the positional
# parameters of generate_video().
iface = gr.Interface(
    fn=generate_video,
    inputs=[
        gr.Image(type="filepath", label="Image"),
        gr.Textbox(lines=2, label="Prompt"),
        gr.Textbox(lines=2, label="Negative Prompt", value="worst quality, inconsistent motion, blurry, jittery, distorted"),
        # The LTX pipeline expects width and height to be multiples of 32,
        # so the sliders only offer such values.
        gr.Slider(minimum=256, maximum=1024, step=32, value=704, label="Width"),
        gr.Slider(minimum=256, maximum=1024, step=32, value=704, label="Height"),
        # LTX is documented for frame counts of the form 8*k + 1; this grid
        # (9, 17, ..., 257) contains the default of 121.
        gr.Slider(minimum=9, maximum=257, step=8, value=121, label="Number of Frames"),
        gr.Slider(minimum=0.0, maximum=30.0, step=0.05, value=3.35, label="Guidance Scale"),
        gr.Slider(minimum=1, maximum=100, step=1, value=50, label="Number of Inference Steps"),
        gr.Slider(minimum=1, maximum=60, step=1, value=25, label="FPS"),
    ],
    outputs=gr.Video(label="Generated Video"),
    title="LTX-Video Test D",
    description="Generate video from image with LTX-Image-to-Video.",
)

iface.launch()