"""DDPM ancestral sampler with optional classifier-free guidance."""

import torch


def ddpm_sampler(
    net,
    batch,
    conditioning_keys=None,
    scheduler=None,
    uncond_tokens=None,
    num_steps=1000,
    cfg_rate=0,
    generator=None,
    use_confidence_sampling=False,
    use_uncond_token=True,
    confidence_value=1.0,
    unconfidence_value=0.0,
):
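    """Ancestral DDPM sampler with optional classifier-free guidance (CFG).

    `batch["y"]` holds the starting noise. `scheduler` is expected to map a
    continuous time t in [0, 1] to the signal level gamma(t) (alpha-bar in
    DDPM notation). When `cfg_rate > 0`, conditioning entries listed in
    `conditioning_keys` (plus their optional `*_mask` / `*_embeddings`
    variants) are stacked with `uncond_tokens` along the batch dimension so
    both guidance branches run in one forward pass. With
    `use_confidence_sampling`, a per-sample "confidence" scalar is attached to
    the batch, either instead of or in addition to dedicated unconditional
    tokens depending on `use_uncond_token`.
    """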
    if scheduler is None:
        raise ValueError("Scheduler must be provided")

    # Start from the provided noise sample; latents from a previous step (if
    # any) are threaded through the network on every iteration.
    x_cur = batch["y"].to(torch.float32)
    latents = batch.get("previous_latents")
    if use_confidence_sampling:
        batch["confidence"] = torch.full(
            (x_cur.shape[0],), confidence_value, device=x_cur.device
        )
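    # Discretize t in [1, 0] into num_steps + 1 points. The scheduler is
    # expected to map t to the signal level gamma(t) (alpha-bar in DDPM
    # notation): gamma ~ 0 at t = 1 (pure noise), gamma ~ 1 at t = 0 (data).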
    step_indices = torch.arange(num_steps + 1, dtype=torch.float32, device=x_cur.device)
    steps = 1 - step_indices / num_steps
    gammas = scheduler(steps)
    latents_cond = latents_uncond = latents
    # Autocast dtype for the network forward pass; bf16 (or fp16) could be
    # used on supported GPUs, but fp32 is kept here for numerical stability.
    dtype = torch.float32
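    # With CFG enabled, conditional and unconditional inputs are stacked along
    # the batch dimension so both guidance branches share one forward pass;
    # masks and embeddings are first padded to a common sequence length.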
    if cfg_rate > 0 and conditioning_keys is not None:
        stacked_batch = {}
        for key in conditioning_keys:
            if f"{key}_mask" in batch:
                if use_confidence_sampling and not use_uncond_token:
                    stacked_batch[f"{key}_mask"] = torch.cat(
                        [batch[f"{key}_mask"], batch[f"{key}_mask"]], dim=0
                    )
                else:
                    if (
                        batch[f"{key}_mask"].shape[1]
                        > uncond_tokens[f"{key}_mask"].shape[1]
                    ):
                        # Conditional mask is longer: embed the unconditional
                        # mask into a fully masked-out template of the same
                        # shape (False for boolean masks, -inf for additive
                        # ones).
                        uncond_mask = (
                            torch.zeros_like(batch[f"{key}_mask"])
                            if batch[f"{key}_mask"].dtype == torch.bool
                            else torch.full_like(batch[f"{key}_mask"], -torch.inf)
                        )
                        uncond_mask[:, : uncond_tokens[f"{key}_mask"].shape[1]] = (
                            uncond_tokens[f"{key}_mask"]
                        )
                    else:
                        uncond_mask = uncond_tokens[f"{key}_mask"]
                        # The unconditional mask is at least as long: pad the
                        # conditional mask up to its length, marking padded
                        # positions as masked out (False for boolean masks,
                        # -inf for additive ones) to match the branch above.
                        pad_len = (
                            uncond_tokens[f"{key}_mask"].shape[1]
                            - batch[f"{key}_mask"].shape[1]
                        )
                        pad_value = (
                            False
                            if batch[f"{key}_mask"].dtype == torch.bool
                            else -torch.inf
                        )
                        batch[f"{key}_mask"] = torch.cat(
                            [
                                batch[f"{key}_mask"],
                                torch.full(
                                    (batch[f"{key}_mask"].shape[0], pad_len),
                                    pad_value,
                                    device=batch[f"{key}_mask"].device,
                                    dtype=batch[f"{key}_mask"].dtype,
                                ),
                            ],
                            dim=1,
                        )
                    stacked_batch[f"{key}_mask"] = torch.cat(
                        [batch[f"{key}_mask"], uncond_mask], dim=0
                    )
            if f"{key}_embeddings" in batch:
                if use_confidence_sampling and not use_uncond_token:
                    stacked_batch[f"{key}_embeddings"] = torch.cat(
                        [
                            batch[f"{key}_embeddings"],
                            batch[f"{key}_embeddings"],
                        ],
                        dim=0,
                    )
                else:
                    if (
                        batch[f"{key}_embeddings"].shape[1]
                        > uncond_tokens[f"{key}_embeddings"].shape[1]
                    ):
                        # Zero-pad the unconditional embeddings up to the
                        # conditional length (mutates the caller's
                        # uncond_tokens dict).
                        uncond_tokens[f"{key}_embeddings"] = torch.cat(
                            [
                                uncond_tokens[f"{key}_embeddings"],
                                torch.zeros(
                                    uncond_tokens[f"{key}_embeddings"].shape[0],
                                    batch[f"{key}_embeddings"].shape[1]
                                    - uncond_tokens[f"{key}_embeddings"].shape[1],
                                    uncond_tokens[f"{key}_embeddings"].shape[2],
                                    device=uncond_tokens[f"{key}_embeddings"].device,
                                    dtype=uncond_tokens[f"{key}_embeddings"].dtype,
                                ),
                            ],
                            dim=1,
                        )
                    elif (
                        batch[f"{key}_embeddings"].shape[1]
                        < uncond_tokens[f"{key}_embeddings"].shape[1]
                    ):
                        batch[f"{key}_embeddings"] = torch.cat(
                            [
                                batch[f"{key}_embeddings"],
                                torch.zeros(
                                    batch[f"{key}_embeddings"].shape[0],
                                    uncond_tokens[f"{key}_embeddings"].shape[1]
                                    - batch[f"{key}_embeddings"].shape[1],
                                    batch[f"{key}_embeddings"].shape[2],
                                    device=batch[f"{key}_embeddings"].device,
                                ),
                            ],
                            dim=1,
                        )
                    stacked_batch[f"{key}_embeddings"] = torch.cat(
                        [
                            batch[f"{key}_embeddings"],
                            uncond_tokens[f"{key}_embeddings"],
                        ],
                        dim=0,
                    )
            elif key not in batch:
                raise ValueError(f"Conditioning key '{key}' is missing from batch")
            else:
                # Plain conditioning under the bare key; uncond_tokens is
                # assumed to be a dict keyed the same way as the batch.
                if isinstance(batch[key], torch.Tensor):
                    if use_confidence_sampling and not use_uncond_token:
                        stacked_batch[key] = torch.cat([batch[key], batch[key]], dim=0)
                    else:
                        stacked_batch[key] = torch.cat(
                            [batch[key], uncond_tokens[key]], dim=0
                        )
                elif isinstance(batch[key], list):
                    if use_confidence_sampling and not use_uncond_token:
                        stacked_batch[key] = [*batch[key], *batch[key]]
                    else:
                        stacked_batch[key] = [*batch[key], *uncond_tokens[key]]
                else:
                    raise ValueError(
                        "Conditioning must be a tensor or a list of tensors"
                    )
        if use_confidence_sampling:
            # The conditional half of the stacked batch gets the confident
            # value, the unconditional half the unconfident one.
            stacked_batch["confidence"] = torch.cat(
                [
                    torch.full(
                        (x_cur.shape[0],), confidence_value, device=x_cur.device
                    ),
                    torch.full(
                        (x_cur.shape[0],), unconfidence_value, device=x_cur.device
                    ),
                ],
                dim=0,
            )
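    # Ancestral DDPM loop: walk gamma from its t = 1 value toward t = 0; at
    # each step predict x_0 from the model's noise estimate, then sample from
    # the DDPM posterior q(x_{t-1} | x_t, x_0).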
    for step, (gamma_now, gamma_next) in enumerate(zip(gammas[:-1], gammas[1:])):
        with torch.autocast(device_type="cuda", dtype=dtype):
            if cfg_rate > 0 and conditioning_keys is not None:
                stacked_batch["y"] = torch.cat([x_cur, x_cur], dim=0)
                stacked_batch["gamma"] = gamma_now.expand(x_cur.shape[0] * 2)
                stacked_batch["previous_latents"] = (
                    torch.cat([latents_cond, latents_uncond], dim=0)
                    if latents is not None
                    else None
                )
                denoised_all, latents_all = net(stacked_batch)
                denoised_cond, denoised_uncond = denoised_all.chunk(2, dim=0)
                latents_cond, latents_uncond = latents_all.chunk(2, dim=0)
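                # Classifier-free guidance: extrapolate the conditional
                # prediction away from the unconditional one.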
                denoised = denoised_cond * (1 + cfg_rate) - denoised_uncond * cfg_rate
            else:
                batch["y"] = x_cur
                batch["gamma"] = gamma_now.expand(x_cur.shape[0])
                batch["previous_latents"] = latents
                denoised, latents = net(batch)
        # The network output is a noise (epsilon) prediction: recover x_0,
        # clamp it to the data range, then re-derive the noise estimate that
        # is consistent with the clamped x_0.
        x_pred = (x_cur - torch.sqrt(1 - gamma_now) * denoised) / torch.sqrt(gamma_now)
        x_pred = torch.clamp(x_pred, -1, 1)
        noise_pred = (x_cur - torch.sqrt(gamma_now) * x_pred) / torch.sqrt(
            1 - gamma_now
        )

        # Per-step alpha_t recovered from the cumulative signal levels:
        # alpha_t = gamma_now / gamma_next, clipped for numerical safety.
        log_alpha_t = torch.log(gamma_now) - torch.log(gamma_next)
        alpha_t = torch.clip(torch.exp(log_alpha_t), 0, 1)
        # Posterior mean and variance of the DDPM ancestral update.
        x_mean = torch.rsqrt(alpha_t) * (
            x_cur - torch.rsqrt(1 - gamma_now) * (1 - alpha_t) * noise_pred
        )
        var_t = 1 - alpha_t
        eps = torch.randn(x_cur.shape, device=x_cur.device, generator=generator)
        # Standard DDPM adds no noise on the final step.
        x_next = x_mean + torch.sqrt(var_t) * eps if step < num_steps - 1 else x_mean
        x_cur = x_next
    return x_cur.to(torch.float32)
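

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the sampler). `cosine_gamma` and
# `ToyNet` below are hypothetical stand-ins: any real `net` is expected to
# take a batch dict (with "y", "gamma", and optionally "previous_latents" and
# conditioning entries) and return a (noise_prediction, latents) pair, which
# is how `ddpm_sampler` calls it above.
if __name__ == "__main__":
    import math

    def cosine_gamma(t: torch.Tensor) -> torch.Tensor:
        # Cosine alpha-bar schedule (as in Nichol & Dhariwal, 2021):
        # gamma ~ 0 at t = 1 (pure noise), gamma ~ 1 at t = 0 (data).
        # Clamped away from 0 and 1 so the sqrt/log terms stay finite.
        return (torch.cos(0.5 * math.pi * t) ** 2).clamp(1e-5, 1 - 1e-5)

    class ToyNet(torch.nn.Module):
        # Minimal epsilon-prediction stand-in conditioned on (y, gamma); a
        # real model would also consume the conditioning entries.
        def __init__(self, dim: int = 8):
            super().__init__()
            self.proj = torch.nn.Linear(dim + 1, dim)

        def forward(self, batch):
            y, gamma = batch["y"], batch["gamma"]
            h = torch.cat([y, gamma[:, None]], dim=-1)
            # Returned latents are fed back as "previous_latents".
            return self.proj(h), h

    batch = {"y": torch.randn(4, 8), "previous_latents": None}
    sample = ddpm_sampler(ToyNet(), batch, scheduler=cosine_gamma, num_steps=50)
    print(sample.shape)  # torch.Size([4, 8])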