import os

import torch
import torch.nn as nn
import torch.nn.functional as F

# from . import diffusion_utils as utils  # required by the reference implementation below
from .molecule_utils import graph_to_smiles, check_valid
from .transformer import Transformer
from .visualize_utils import MolecularVisualization

class GraphDiT(nn.Module):
    """Placeholder for the Graph Diffusion Transformer.

    The methods below are stubs that mirror the public API of the reference
    implementation kept (commented out) at the bottom of this file.
    """

    def __init__(
        self,
        model_config_path,
        data_info_path,
        model_dtype,
    ):
        super().__init__()

    def init_model(self, model_dir):
        pass

    def disable_grads(self):
        pass

    def generate(self, properties, guide_scale, num_nodes, number_chain_steps):
        # Placeholder return matching the (smiles, image_list) contract of the
        # reference implementation below.
        return None, []


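# ---------------------------------------------------------------------------
# Reference implementation, kept commented out below. It depends on the
# `diffusion_utils` import that is likewise commented at the top of this
# file; uncomment both together to restore the full model.
# ---------------------------------------------------------------------------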
# class GraphDiT(nn.Module):
#     def __init__(
#         self,
#         model_config_path,
#         data_info_path,
#         model_dtype,
#     ):
#         super().__init__()
#         dm_cfg, data_info = utils.load_config(model_config_path, data_info_path)

#         input_dims = data_info.input_dims
#         output_dims = data_info.output_dims
#         nodes_dist = data_info.nodes_dist
#         active_index = data_info.active_index

#         self.model_config = dm_cfg
#         self.data_info = data_info
#         self.T = dm_cfg.diffusion_steps
#         self.Xdim = input_dims["X"]
#         self.Edim = input_dims["E"]
#         self.ydim = input_dims["y"]
#         self.Xdim_output = output_dims["X"]
#         self.Edim_output = output_dims["E"]
#         self.ydim_output = output_dims["y"]
#         self.node_dist = nodes_dist
#         self.active_index = active_index
#         self.max_n_nodes = data_info.max_n_nodes
#         self.atom_decoder = data_info.atom_decoder
#         self.hidden_size = dm_cfg.hidden_size
#         self.mol_visualizer = MolecularVisualization(self.atom_decoder)

#         self.denoiser = Transformer(
#             max_n_nodes=self.max_n_nodes,
#             hidden_size=dm_cfg.hidden_size,
#             depth=dm_cfg.depth,
#             num_heads=dm_cfg.num_heads,
#             mlp_ratio=dm_cfg.mlp_ratio,
#             drop_condition=dm_cfg.drop_condition,
#             Xdim=self.Xdim,
#             Edim=self.Edim,
#             ydim=self.ydim,
#         )

#         self.model_dtype = model_dtype
#         self.noise_schedule = utils.PredefinedNoiseScheduleDiscrete(
#             dm_cfg.diffusion_noise_schedule, timesteps=dm_cfg.diffusion_steps
#         )
#         x_marginals = data_info.node_types.to(self.model_dtype) / torch.sum(
#             data_info.node_types.to(self.model_dtype)
#         )
#         e_marginals = data_info.edge_types.to(self.model_dtype) / torch.sum(
#             data_info.edge_types.to(self.model_dtype)
#         )
#         x_marginals = x_marginals / x_marginals.sum()
#         e_marginals = e_marginals / e_marginals.sum()

#         xe_conditions = data_info.transition_E.to(self.model_dtype)
#         xe_conditions = xe_conditions[self.active_index][:, self.active_index]

#         xe_conditions = xe_conditions.sum(dim=1)
#         ex_conditions = xe_conditions.t()
#         xe_conditions = xe_conditions / xe_conditions.sum(dim=-1, keepdim=True)
#         ex_conditions = ex_conditions / ex_conditions.sum(dim=-1, keepdim=True)
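#         # Collapse the node-to-node axis, then row-normalize: each node-type
#         # row of xe_conditions (node -> edge) and ex_conditions (edge -> node)
#         # becomes a valid categorical distribution.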

#         self.transition_model = utils.MarginalTransition(
#             x_marginals=x_marginals,
#             e_marginals=e_marginals,
#             xe_conditions=xe_conditions,
#             ex_conditions=ex_conditions,
#             y_classes=self.ydim_output,
#             n_nodes=self.max_n_nodes,
#         )
#         self.limit_dist = utils.PlaceHolder(X=x_marginals, E=e_marginals, y=None)
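#         # limit_dist holds the marginal node/edge distributions that pure
#         # noise converges to at t = T; `generate` draws its initial z_T from it.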

#     def init_model(self, model_dir):
#         model_file = os.path.join(model_dir, 'model.pt')
#         if os.path.exists(model_file):
#             self.denoiser.load_state_dict(torch.load(model_file, map_location='cpu', weights_only=True))
#         else:
#             raise FileNotFoundError(f"Model file not found: {model_file}")

#     def disable_grads(self):
#         self.denoiser.disable_grads()
    
#     def forward(
#         self, x, edge_index, edge_attr, graph_batch, properties, no_label_index
#     ):
#         raise NotImplementedError("Training forward pass is not implemented.")

#     def _forward(self, noisy_data, unconditioned=False):
#         noisy_x, noisy_e, properties = (
#             noisy_data["X_t"].to(self.model_dtype),
#             noisy_data["E_t"].to(self.model_dtype),
#             noisy_data["y_t"].to(self.model_dtype).clone(),
#         )
#         node_mask, timestep = (
#             noisy_data["node_mask"],
#             noisy_data["t"],
#         )
        
#         pred = self.denoiser(
#             noisy_x,
#             noisy_e,
#             node_mask,
#             properties,
#             timestep,
#             unconditioned=unconditioned,
#         )
#         return pred

#     def apply_noise(self, X, E, y, node_mask):
#         """Sample noise and apply it to the data."""

#         # Sample a timestep t.
#         # When evaluating, the loss for t=0 is computed separately
#         lowest_t = 0 if self.training else 1
#         t_int = torch.randint(
#             lowest_t, self.T + 1, size=(X.size(0), 1), device=X.device
#         ).to(
#             self.model_dtype
#         )  # (bs, 1)
#         s_int = t_int - 1

#         t_float = t_int / self.T
#         s_float = s_int / self.T

#         # beta_t and alpha_s_bar are used for denoising/loss computation
#         beta_t = self.noise_schedule(t_normalized=t_float)  # (bs, 1)
#         alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s_float)  # (bs, 1)
#         alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t_float)  # (bs, 1)
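#         # ᾱ_t is the cumulative noise level; with marginal transitions, Q̄_t is
#         # presumably the usual blend ᾱ_t · I + (1 - ᾱ_t) · 1 mᵀ toward the data
#         # marginals m (an assumption about MarginalTransition's construction).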

#         Qtb = self.transition_model.get_Qt_bar(
#             alpha_t_bar, X.device
#         )  # (bs, dx_in, dx_out), (bs, de_in, de_out)

#         bs, n, d = X.shape
#         X_all = torch.cat([X, E.reshape(bs, n, -1)], dim=-1)
#         prob_all = X_all @ Qtb.X
#         probX = prob_all[:, :, : self.Xdim_output]
#         probE = prob_all[:, :, self.Xdim_output :].reshape(bs, n, n, -1)
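#         # Nodes and (flattened) edge rows are concatenated so one matmul with
#         # Qtb.X applies the node transition and the node-conditioned edge
#         # transition jointly; the product is then split back at Xdim_output.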

#         sampled_t = utils.sample_discrete_features(
#             probX=probX, probE=probE, node_mask=node_mask
#         )

#         X_t = F.one_hot(sampled_t.X, num_classes=self.Xdim_output)
#         E_t = F.one_hot(sampled_t.E, num_classes=self.Edim_output)
#         assert (X.shape == X_t.shape) and (E.shape == E_t.shape)

#         y_t = y
#         z_t = utils.PlaceHolder(X=X_t, E=E_t, y=y_t).type_as(X_t).mask(node_mask)

#         noisy_data = {
#             "t_int": t_int,
#             "t": t_float,
#             "beta_t": beta_t,
#             "alpha_s_bar": alpha_s_bar,
#             "alpha_t_bar": alpha_t_bar,
#             "X_t": z_t.X,
#             "E_t": z_t.E,
#             "y_t": z_t.y,
#             "node_mask": node_mask,
#         }
#         return noisy_data

#     @torch.no_grad()
#     def generate(
#         self,
#         properties,
#         guide_scale=1.,
#         num_nodes=None,
#         number_chain_steps=50,
#     ):
#         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#         properties = [float('nan') if x is None else x for x in properties]
#         properties = torch.tensor(properties, dtype=torch.float).reshape(1, -1).to(device)
#         batch_size = properties.size(0)
#         assert batch_size == 1
#         if num_nodes is None:
#             num_nodes = self.node_dist.sample_n(batch_size, device)
#         else:
#             num_nodes = torch.LongTensor([num_nodes]).to(device)

#         arange = (
#             torch.arange(self.max_n_nodes, device=device)
#             .unsqueeze(0)
#             .expand(batch_size, -1)
#         )
#         node_mask = arange < num_nodes.unsqueeze(1)

#         z_T = utils.sample_discrete_feature_noise(
#             limit_dist=self.limit_dist, node_mask=node_mask
#         )
#         X, E = z_T.X, z_T.E

#         assert (E == torch.transpose(E, 1, 2)).all()

#         if number_chain_steps > 0:
#             chain_X_size = torch.Size((number_chain_steps, X.size(1)))
#             chain_E_size = torch.Size((number_chain_steps, E.size(1), E.size(2)))
#             chain_X = torch.zeros(chain_X_size)
#             chain_E = torch.zeros(chain_E_size)

#         # Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1.
#         y = properties
#         for s_int in reversed(range(0, self.T)):
#             s_array = s_int * torch.ones((batch_size, 1)).type_as(y)
#             t_array = s_array + 1
#             s_norm = s_array / self.T
#             t_norm = t_array / self.T

#             # Sample z_s
#             sampled_s, discrete_sampled_s = self.sample_p_zs_given_zt(
#                 s_norm, t_norm, X, E, y, node_mask, guide_scale, device
#             )
#             X, E, y = sampled_s.X, sampled_s.E, sampled_s.y
            
#             if number_chain_steps > 0:
#                 # Save the first graph in the batch into the chain
#                 write_index = (s_int * number_chain_steps) // self.T
#                 chain_X[write_index] = discrete_sampled_s.X[:1]
#                 chain_E[write_index] = discrete_sampled_s.E[:1]

#         # Sample
#         sampled_s = sampled_s.mask(node_mask, collapse=True)
#         X, E, y = sampled_s.X, sampled_s.E, sampled_s.y

#         molecule_list = []
#         n = num_nodes[0]
#         atom_types = X[0, :n].cpu()
#         edge_types = E[0, :n, :n].cpu()
#         molecule_list.append([atom_types, edge_types])
#         smiles = graph_to_smiles(molecule_list, self.atom_decoder)[0]

#         # Visualize Chains
#         if number_chain_steps > 0:
#             final_X_chain = X[:1]
#             final_E_chain = E[:1]

#             chain_X[0] = final_X_chain  # index 0 becomes the last frame after reverse_tensor below
#             chain_E[0] = final_E_chain

#             chain_X = utils.reverse_tensor(chain_X)
#             chain_E = utils.reverse_tensor(chain_E)

#             # Repeat last frame to see final sample better
#             chain_X = torch.cat([chain_X, chain_X[-1:].repeat(10, 1)], dim=0)
#             chain_E = torch.cat([chain_E, chain_E[-1:].repeat(10, 1, 1)], dim=0)
#             mol_img_list = self.mol_visualizer.visualize_chain(chain_X.numpy(), chain_E.numpy())
#         else:
#             mol_img_list = []

#         return smiles, mol_img_list

#     def check_valid(self, smiles):
#         return check_valid(smiles)
    
#     def sample_p_zs_given_zt(
#         self, s, t, X_t, E_t, properties, node_mask, guide_scale, device
#     ):
#         """Samples from zs ~ p(zs | zt). Only used during sampling.
#         if last_step, return the graph prediction as well"""
#         bs, n, _ = X_t.shape
#         beta_t = self.noise_schedule(t_normalized=t)  # (bs, 1)
#         alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s)
#         alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t)

#         # Neural net predictions
#         noisy_data = {
#             "X_t": X_t,
#             "E_t": E_t,
#             "y_t": properties,
#             "t": t,
#             "node_mask": node_mask,
#         }

#         def get_prob(noisy_data, unconditioned=False):
#             pred = self._forward(noisy_data, unconditioned=unconditioned)

#             # Normalize predictions
#             pred_X = F.softmax(pred.X, dim=-1)  # bs, n, d0
#             pred_E = F.softmax(pred.E, dim=-1)  # bs, n, n, d0

#             # Retrieve transitions matrix
#             Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, device)
#             Qsb = self.transition_model.get_Qt_bar(alpha_s_bar, device)
#             Qt = self.transition_model.get_Qt(beta_t, device)

#             Xt_all = torch.cat([X_t, E_t.reshape(bs, n, -1)], dim=-1)
#             predX_all = torch.cat([pred_X, pred_E.reshape(bs, n, -1)], dim=-1)

#             unnormalized_probX_all = utils.reverse_diffusion(
#                 predX_0=predX_all, X_t=Xt_all, Qt=Qt.X, Qsb=Qsb.X, Qtb=Qtb.X
#             )
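#             # `utils.reverse_diffusion` presumably forms the discrete-diffusion
#             # posterior q(z_s | z_t, x_0) ∝ (z_t · Qtᵀ) ⊙ (x_0 · Qsb), weighted
#             # by the network's prediction of x_0 (an assumption inferred from
#             # the arguments; the exact contract lives in diffusion_utils).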

#             unnormalized_prob_X = unnormalized_probX_all[:, :, : self.Xdim_output]
#             unnormalized_prob_E = unnormalized_probX_all[
#                 :, :, self.Xdim_output :
#             ].reshape(bs, n * n, -1)

#             unnormalized_prob_X[torch.sum(unnormalized_prob_X, dim=-1) == 0] = 1e-5
#             unnormalized_prob_E[torch.sum(unnormalized_prob_E, dim=-1) == 0] = 1e-5

#             prob_X = unnormalized_prob_X / torch.sum(
#                 unnormalized_prob_X, dim=-1, keepdim=True
#             )  # bs, n, d_t-1
#             prob_E = unnormalized_prob_E / torch.sum(
#                 unnormalized_prob_E, dim=-1, keepdim=True
#             )  # bs, n * n, d_t-1
#             prob_E = prob_E.reshape(bs, n, n, pred_E.shape[-1])

#             return prob_X, prob_E

#         prob_X, prob_E = get_prob(noisy_data)

#         ### Guidance
#         if guide_scale != 1:
#             uncon_prob_X, uncon_prob_E = get_prob(
#                 noisy_data, unconditioned=True
#             )
#             prob_X = (
#                 uncon_prob_X
#                 * (prob_X / uncon_prob_X.clamp_min(1e-5)) ** guide_scale
#             )
#             prob_E = (
#                 uncon_prob_E
#                 * (prob_E / uncon_prob_E.clamp_min(1e-5)) ** guide_scale
#             )
#             prob_X = prob_X / prob_X.sum(dim=-1, keepdim=True).clamp_min(1e-5)
#             prob_E = prob_E / prob_E.sum(dim=-1, keepdim=True).clamp_min(1e-5)
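#         # Classifier-free guidance in probability space: the update
#         # p ∝ p_uncond · (p_cond / p_uncond)^scale reduces to p_cond at
#         # scale = 1 and sharpens the conditional signal for scale > 1;
#         # the renormalization above restores valid categoricals.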

#         # assert ((prob_X.sum(dim=-1) - 1).abs() < 1e-3).all()
#         # assert ((prob_E.sum(dim=-1) - 1).abs() < 1e-3).all()

#         sampled_s = utils.sample_discrete_features(
#             prob_X, prob_E, node_mask=node_mask, step=s[0, 0].item()
#         )

#         X_s = F.one_hot(sampled_s.X, num_classes=self.Xdim_output).to(self.model_dtype)
#         E_s = F.one_hot(sampled_s.E, num_classes=self.Edim_output).to(self.model_dtype)

#         assert (E_s == torch.transpose(E_s, 1, 2)).all()
#         assert (X_t.shape == X_s.shape) and (E_t.shape == E_s.shape)

#         out_one_hot = utils.PlaceHolder(X=X_s, E=E_s, y=properties)
#         out_discrete = utils.PlaceHolder(X=X_s, E=E_s, y=properties)

#         return out_one_hot.mask(node_mask).type_as(properties), out_discrete.mask(
#             node_mask, collapse=True
#         ).type_as(properties)
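

if __name__ == "__main__":
    # Minimal smoke test against the stub API above (a sketch: the config and
    # data-info paths are hypothetical placeholders, and the stub `generate`
    # returns placeholder values until the reference implementation is restored).
    model = GraphDiT(
        model_config_path="model_config.yaml",  # hypothetical path
        data_info_path="data.meta.json",        # hypothetical path
        model_dtype=torch.float32,
    )
    model.disable_grads()
    smiles, chain_imgs = model.generate(
        properties=[0.5, None],  # None marks a property left unconditioned
        guide_scale=2.0,
        num_nodes=None,
        number_chain_steps=0,
    )
    print("sampled SMILES:", smiles, "| chain frames:", len(chain_imgs))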