import torch
import torch.nn.functional as F
import numpy as np
from typing import Optional, List, Dict, Union

from ..builder import ARCHITECTURES, build_loss, build_submodule
from ..utils.gaussian_diffusion import create_named_schedule_sampler, build_diffusion
from ..utils.mask_helper import expand_mask_to_all
from .base_architecture import BaseArchitecture


def set_requires_grad(nets: Union[torch.nn.Module, List[torch.nn.Module]], requires_grad: bool = False) -> None:
    """Set ``requires_grad`` on every parameter of the given network(s).

    Args:
        nets (nn.Module | list[nn.Module] | tuple[nn.Module]): A single network
            or a list/tuple of networks. ``None`` entries are skipped.
        requires_grad (bool): Value assigned to ``requires_grad`` of each
            parameter. Defaults to False.
    """
    # Accept tuples as well as lists: a tuple would otherwise be wrapped as a
    # single "network" and fail on `.parameters()`.
    if not isinstance(nets, (list, tuple)):
        nets = [nets]
    for net in nets:
        if net is None:
            continue
        for param in net.parameters():
            param.requires_grad = requires_grad


@ARCHITECTURES.register_module()
class MotionDiffusion(BaseArchitecture):
    """Diffusion-based architecture for training on and sampling motion sequences.

    In training mode, ``forward`` returns a dict of losses. Otherwise it runs
    the configured sampling loop and returns the output of ``split_results``.

    Args:
        dataset_name (Optional[str]): Name of the dataset. Only ``'kit_ml'``
            and ``'human_ml3d'`` are supported; for these, ``mean.npy`` and
            ``std.npy`` are loaded from ``data/datasets/<dataset_name>/`` into
            ``self.mean`` / ``self.std``. ``None`` skips loading.
        model (dict): Configuration passed to ``build_submodule``.
        loss_recon (dict): Configuration passed to ``build_loss`` for the
            reconstruction loss.
        loss_reduction (str): ``'frame'`` normalises the masked loss over the
            whole batch; any other value normalises per sample and then
            averages. Defaults to ``'frame'``.
        use_loss_score (bool): If True, the element-wise loss is multiplied by
            ``kwargs['score']`` (which must then be supplied to ``forward``).
            Defaults to False.
        diffusion_train (dict): Configuration passed to ``build_diffusion`` for
            training.
        diffusion_test (dict): Configuration passed to ``build_diffusion`` for
            inference.
        sampler_type (str): Name passed to ``create_named_schedule_sampler``.
            Defaults to ``'uniform'``.
        init_cfg (dict): Initialization config forwarded to the base class.
        inference_type (str): ``'ddpm'`` uses ``p_sample_loop``; any other
            value uses ``ddim_sample_loop`` with ``eta=0``. Defaults to
            ``'ddpm'``.

    Raises:
        NotImplementedError: If ``dataset_name`` is neither ``None`` nor one of
            the supported datasets.
    """

    def __init__(self,
                 dataset_name: Optional[str] = None,
                 model: dict = None,
                 loss_recon: dict = None,
                 loss_reduction: str = "frame",
                 use_loss_score: bool = False,
                 diffusion_train: dict = None,
                 diffusion_test: dict = None,
                 sampler_type: str = 'uniform',
                 init_cfg: dict = None,
                 inference_type: str = 'ddpm',
                 **kwargs):
        super().__init__(init_cfg=init_cfg, **kwargs)
        self.model = build_submodule(model)
        self.loss_recon = build_loss(loss_recon)
        self.diffusion_train = build_diffusion(diffusion_train)
        self.diffusion_test = build_diffusion(diffusion_test)
        self.sampler = create_named_schedule_sampler(sampler_type, self.diffusion_train)
        self.inference_type = inference_type
        self.loss_reduction = loss_reduction
        self.use_loss_score = use_loss_score
        self.dataset_name = dataset_name

        # Dataset statistics live under a per-dataset directory; paths are
        # relative to the current working directory.
        if self.dataset_name in ("kit_ml", "human_ml3d"):
            stats_dir = f"data/datasets/{self.dataset_name}"
            self.mean = np.load(f"{stats_dir}/mean.npy")
            self.std = np.load(f"{stats_dir}/std.npy")
        elif self.dataset_name is not None:
            raise NotImplementedError(
                f"Unsupported dataset_name: {self.dataset_name!r}")

    def forward(self, **kwargs) -> Union[Dict, List]:
        """Compute training losses or sample motion, depending on ``self.training``.

        Args:
            **kwargs: Input batch. Keys read here:
                ``motion``, ``motion_mask``, ``motion_length`` and
                ``motion_metas`` (each entry must contain ``'text'``) are
                required; ``num_intervals`` (default 1), ``sample_idx``,
                ``clip_feat`` and ``inference_kwargs`` are optional;
                ``score`` is required when ``use_loss_score`` is True.

        Returns:
            dict or list: In training mode, a dict with ``'recon_loss'`` plus
            any entries from ``self.model.aux_loss()``. Otherwise, the value
            returned by ``self.split_results``.
        """
        motion = kwargs['motion'].float()
        motion_mask = kwargs['motion_mask'].float()
        motion_length = kwargs['motion_length']
        num_intervals = kwargs.get('num_intervals', 1)
        sample_idx = kwargs.get('sample_idx', None)
        clip_feat = kwargs.get('clip_feat', None)
        B, T = motion.shape[:2]
        text = [kwargs['motion_metas'][i]['text'] for i in range(B)]

        if self.training:
            # One diffusion timestep per sample.
            t, _ = self.sampler.sample(B, motion.device)
            output = self.diffusion_train.training_losses(
                model=self.model,
                x_start=motion,
                t=t,
                model_kwargs={
                    'motion_mask': motion_mask,
                    'motion_length': motion_length,
                    'text': text,
                    'clip_feat': clip_feat,
                    'sample_idx': sample_idx,
                    'num_intervals': num_intervals
                }
            )
            pred, target = output['pred'], output['target']
            recon_loss = self.loss_recon(pred, target, reduction_override='none')

            if self.use_loss_score:
                loss_score = kwargs['score']
                recon_loss = recon_loss * loss_score.view(B, 1, -1)

            # Average over the last (feature) axis, then zero out masked frames.
            # Assumes motion_mask is (B, T) here -- TODO confirm against the data pipeline.
            recon_loss = recon_loss.mean(dim=-1) * motion_mask
            recon_loss_batch = recon_loss.sum(dim=1) / motion_mask.sum(dim=1)
            recon_loss_frame = recon_loss.sum() / motion_mask.sum()

            if self.loss_reduction == "frame":
                recon_loss = recon_loss_frame
            else:
                recon_loss = recon_loss_batch

            # Loss-aware samplers receive the per-sample loss regardless of
            # which reduction is reported.
            if hasattr(self.sampler, "update_with_local_losses"):
                self.sampler.update_with_local_losses(t, recon_loss_batch)

            loss = {'recon_loss': recon_loss.mean()}
            if hasattr(self.model, 'aux_loss'):
                loss.update(self.model.aux_loss())
            return loss

        dim_pose = kwargs['motion'].shape[-1]
        model_kwargs = self.model.get_precompute_condition(
            device=motion.device, text=text, **kwargs
        )
        model_kwargs.update({
            'motion_mask': motion_mask,
            'sample_idx': sample_idx,
            'motion_length': motion_length,
            'num_intervals': num_intervals
        })

        inference_kwargs = kwargs.get('inference_kwargs', {})
        if self.inference_type == 'ddpm':
            output = self.diffusion_test.p_sample_loop(
                self.model, (B, T, dim_pose), clip_denoised=False, progress=False,
                model_kwargs=model_kwargs, **inference_kwargs
            )
        else:
            output = self.diffusion_test.ddim_sample_loop(
                self.model, (B, T, dim_pose), clip_denoised=False, progress=False,
                model_kwargs=model_kwargs, eta=0, **inference_kwargs
            )

        # `post_process` is optional: supply a default so that models without
        # the attribute are skipped instead of raising AttributeError.
        post_process = getattr(self.model, "post_process", None)
        if post_process is not None:
            output = post_process(output)

        # The input batch dict is reused as the result container.
        results = kwargs
        results['pred_motion'] = output
        results = self.split_results(results)
        return results


@ARCHITECTURES.register_module()
class UnifiedMotionDiffusion(BaseArchitecture):
    """Diffusion-based motion architecture with text/music/speech/video conditions.

    In training mode, ``forward`` returns a dict of losses. Otherwise it builds
    a diffusion process for the batch's dataset, runs the configured sampling
    loop and returns the output of ``split_results``.

    Args:
        model (dict): Configuration passed to ``build_submodule``.
        loss_recon (dict): Configuration passed to ``build_loss`` for the
            reconstruction loss.
        loss_reduction (str): ``'frame'`` normalises the masked loss over the
            whole batch; any other value normalises per sample and then
            averages. Defaults to ``'frame'``.
        random_mask (float): Stored as ``self.random_mask``; not referenced
            elsewhere in this class. Defaults to 0.
        diffusion_train (dict): Configuration passed to ``build_diffusion`` for
            training.
        diffusion_test_dict (dict): Inference diffusion configuration. Must
            contain a ``'base'`` config and one entry per dataset name; the
            per-dataset entry is injected as the ``'respace'`` key of the base
            config at inference time.
        sampler_type (str): Name passed to ``create_named_schedule_sampler``.
            Defaults to ``'uniform'``.
        init_cfg (dict): Initialization config forwarded to the base class.
            When not None, ``init_weights`` is called at the end of
            construction.
        inference_type (str): ``'ddpm'`` uses ``p_sample_loop``; any other
            value uses ``ddim_sample_loop`` with ``eta=0``. Defaults to
            ``'ddpm'``.
        body_scale (float): Body scale passed to ``expand_mask_to_all``.
            Defaults to 1.0.
        hand_scale (float): Hand scale passed to ``expand_mask_to_all``.
            Defaults to 1.0.
        face_scale (float): Face scale passed to ``expand_mask_to_all``.
            Defaults to 1.0.
        train_repeat (int): Number of times each training batch is tiled along
            the batch axis before the forward pass. Defaults to 1 (no tiling).
        loss_weight (str): Currently ignored; ``self.loss_weight`` is always
            initialised to None.
    """

    # Conditioning modalities; each contributes `<name>_word_feat`,
    # `<name>_seq_feat` and `<name>_cond` entries in the batch.
    _MODALITIES = ('text', 'music', 'speech', 'video')

    def __init__(self,
                 model: dict = None,
                 loss_recon: dict = None,
                 loss_reduction: str = "frame",
                 random_mask: float = 0,
                 diffusion_train: dict = None,
                 diffusion_test_dict: dict = None,
                 sampler_type: str = 'uniform',
                 init_cfg: dict = None,
                 inference_type: str = 'ddpm',
                 body_scale: float = 1.0,
                 hand_scale: float = 1.0,
                 face_scale: float = 1.0,
                 train_repeat: int = 1,
                 loss_weight: str = None,
                 **kwargs):
        super().__init__(init_cfg=init_cfg, **kwargs)
        self.model = build_submodule(model)
        self.loss_recon = build_loss(loss_recon)
        self.diffusion_train = build_diffusion(diffusion_train)
        self.diffusion_test_dict = diffusion_test_dict
        self.sampler = create_named_schedule_sampler(sampler_type, self.diffusion_train)
        self.inference_type = inference_type
        self.loss_reduction = loss_reduction
        self.random_mask = random_mask
        self.body_scale = body_scale
        self.hand_scale = hand_scale
        self.face_scale = face_scale
        self.train_repeat = train_repeat
        # NOTE(review): the `loss_weight` argument is not used. `forward`
        # expects `self.loss_weight` to be a NumPy array (it calls
        # `torch.from_numpy` on it), so whoever enables the weighting must
        # assign that array explicitly -- confirm the intended loading.
        self.loss_weight = None
        if init_cfg is not None:
            self.init_weights()

    def repeat_data(self, **kwargs):
        """Tile the batch ``self.train_repeat`` times along the batch axis.

        Args:
            **kwargs: Input batch. ``motion``, ``motion_mask``,
                ``motion_length`` and ``motion_metas`` are required; any
                present ``<modality>_seq_feat``, ``<modality>_word_feat`` and
                ``<modality>_cond`` entries are tiled as well.

        Returns:
            dict: The same dict with the tiled entries replaced. Returned
            untouched when ``self.train_repeat == 1``.
        """
        N = self.train_repeat
        if N == 1:
            return kwargs

        kwargs['motion'] = kwargs['motion'].float().repeat(N, 1, 1)
        kwargs['motion_mask'] = kwargs['motion_mask'].float().repeat(N, 1, 1)
        kwargs['motion_length'] = kwargs['motion_length'].repeat(N, 1)
        # List repetition yields the same sample order as `Tensor.repeat`
        # along the batch axis (the whole batch repeated N times).
        kwargs['motion_metas'] = kwargs['motion_metas'] * N

        # Repeat arguments per feature kind.
        # NOTE(review): these presume 2-D `seq_feat`/`cond` and 3-D
        # `word_feat` tensors -- confirm against the data pipeline.
        repeat_args = (
            ('seq_feat', (N, 1)),
            ('word_feat', (N, 1, 1)),
            ('cond', (N, 1)),
        )
        for modality in self._MODALITIES:
            for suffix, reps in repeat_args:
                key = f'{modality}_{suffix}'
                if key in kwargs:
                    kwargs[key] = kwargs[key].repeat(*reps)
        return kwargs

    def _collect_condition_kwargs(self, kwargs: dict, motion: torch.Tensor) -> dict:
        """Gather the per-modality conditioning entries, filling in defaults.

        Missing ``*_word_feat`` / ``*_seq_feat`` entries default to None;
        missing ``*_cond`` entries default to a zero vector with one element
        per sample, matching the dtype/device of ``motion``.
        """
        B = motion.shape[0]
        cond_kwargs = {}
        for modality in self._MODALITIES:
            word_key = f'{modality}_word_feat'
            seq_key = f'{modality}_seq_feat'
            cond_key = f'{modality}_cond'
            cond_kwargs[word_key] = kwargs.get(word_key, None)
            cond_kwargs[seq_key] = kwargs.get(seq_key, None)
            cond_kwargs[cond_key] = kwargs.get(
                cond_key, torch.zeros(B).type_as(motion))
        return cond_kwargs

    def forward(self, **kwargs) -> Dict:
        """Compute training losses or sample motion, depending on ``self.training``.

        Args:
            **kwargs: Input batch. ``motion``, ``motion_mask``,
                ``motion_length`` and ``motion_metas`` are required. Optional
                keys: ``num_intervals`` (default 1), ``sample_idx``, the
                per-modality conditioning entries, and at inference time
                ``inference_kwargs`` and ``context_mask``. At inference,
                ``motion_metas[0]['meta_data']['dataset_name']`` selects the
                respacing entry of ``diffusion_test_dict``.

        Returns:
            dict: In training mode, a dict with ``'recon_loss'`` plus any
            entries from ``self.model.aux_loss()``. Otherwise, the value
            returned by ``self.split_results``.
        """
        if self.training:
            kwargs = self.repeat_data(**kwargs)

        motion = kwargs['motion'].float()
        B, T = motion.shape[:2]
        motion_mask = kwargs['motion_mask'].float()
        motion_length = kwargs['motion_length']
        num_intervals = kwargs.get('num_intervals', 1)
        motion_metas = kwargs['motion_metas']

        # Conditioning features (text, music, speech, video).
        cond_kwargs = self._collect_condition_kwargs(kwargs, motion)

        if self.training:
            # One diffusion timestep per sample.
            t, _ = self.sampler.sample(B, motion.device)

            model_kwargs = {
                'motion_mask': motion_mask,
                'motion_length': motion_length,
                'num_intervals': num_intervals,
                'motion_metas': motion_metas,
            }
            model_kwargs.update(cond_kwargs)
            output = self.diffusion_train.training_losses(
                model=self.model,
                x_start=motion,
                t=t,
                model_kwargs=model_kwargs)
            pred, target = output['pred'], output['target']
            recon_loss = self.loss_recon(pred, target, reduction_override='none')

            # Expand the motion mask with the body/hand/face scales so it can
            # be multiplied element-wise with the unreduced loss.
            motion_mask = expand_mask_to_all(
                motion_mask, self.body_scale, self.hand_scale, self.face_scale
            )
            if self.loss_weight is not None:
                # Weights are selected per sample via the model's dataset index.
                loss_weight = torch.from_numpy(self.loss_weight).type_as(motion_mask)
                dataset_idx = self.model.dataset_idx
                loss_weight = loss_weight.index_select(0, dataset_idx).unsqueeze(1)
                motion_mask = motion_mask * loss_weight
                recon_loss = (recon_loss * motion_mask).sum(dim=-1)
                motion_mask = motion_mask.sum(dim=-1)
            else:
                recon_loss = (recon_loss * motion_mask).mean(dim=-1)
                motion_mask = motion_mask.mean(dim=-1)

            recon_loss_batch = recon_loss.sum(dim=1) / motion_mask.sum(dim=1)
            recon_loss_frame = recon_loss.sum() / motion_mask.sum()

            # Determine final reconstruction loss.
            if self.loss_reduction == "frame":
                recon_loss = recon_loss_frame
            else:
                recon_loss = recon_loss_batch

            # Loss-aware samplers receive the per-sample loss regardless of
            # which reduction is reported.
            if hasattr(self.sampler, "update_with_local_losses"):
                self.sampler.update_with_local_losses(t, recon_loss_batch)

            loss = {'recon_loss': recon_loss.mean()}

            # Add auxiliary loss if applicable.
            if hasattr(self.model, 'aux_loss'):
                loss.update(self.model.aux_loss())

            return loss

        # Inference (DDPM or DDIM sampling).
        sample_idx = kwargs.get('sample_idx', None)
        dim_pose = 669  # Fixed dimension for the motion output
        model_kwargs = self.model.get_precompute_condition(
            device=motion.device, **kwargs
        )
        model_kwargs.update({
            'motion_mask': motion_mask,
            'sample_idx': sample_idx,
            'motion_length': motion_length,
            'num_intervals': num_intervals,
            'motion_metas': motion_metas,
        })
        model_kwargs.update(cond_kwargs)

        # A caller-supplied `inference_kwargs` dict is extended in place.
        inference_kwargs = kwargs.get('inference_kwargs', {})
        inference_kwargs['gt_motion'] = motion
        inference_kwargs['context_mask'] = kwargs.get('context_mask', None)

        # The dataset of the first sample selects the respacing for the batch.
        dataset_name = motion_metas[0]['meta_data']['dataset_name']
        # Work on a shallow copy so the stored base config is not mutated on
        # every inference call.
        diffusion_test_cfg = self.diffusion_test_dict['base'].copy()
        diffusion_test_cfg.update(dict(respace=self.diffusion_test_dict[dataset_name]))
        diffusion_test = build_diffusion(diffusion_test_cfg)
        if self.inference_type == 'ddpm':
            output = diffusion_test.p_sample_loop(
                self.model, (B, T, dim_pose), clip_denoised=False,
                progress=False, model_kwargs=model_kwargs, **inference_kwargs
            )
        else:
            output = diffusion_test.ddim_sample_loop(
                self.model, (B, T, dim_pose), clip_denoised=False,
                progress=False, model_kwargs=model_kwargs, eta=0,
                **inference_kwargs
            )

        # `post_process` is optional: supply a default so that models without
        # the attribute are skipped instead of raising AttributeError.
        post_process = getattr(self.model, "post_process", None)
        if post_process is not None:
            output = post_process(output)

        # The input batch dict is reused as the result container.
        results = kwargs
        results['pred_motion'] = output
        results = self.split_results(results)

        return results