# Copyright (c) 2023, Dan Fu and Simran Arora.
# Adapted from https://github.com/HazyResearch/safari/blob/main/src/models/sequence/hyena.py

import torch.nn as nn
from einops import rearrange
import opt_einsum as oe

from .hyena_utils import HyenaFilter

contract = oe.contract


class MonarchMixerSequenceMixing(nn.Module):
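    """Monarch Mixer sequence mixing layer (gated long convolution).

    The input is projected into two gates and a value stream, passed through a
    depthwise short convolution for local context, and mixed along the sequence
    dimension with an implicitly parameterized Hyena long-convolution filter
    (optionally bidirectional). A second long convolution over the un-gated
    input can be enabled as a residual branch (``residual_long_conv``).

    Adapted from the Hyena operator in HazyResearch/safari (see module header).
    """
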
    def __init__(
        self,
        d_model,
        l_max=128,
        dropout=0.0,
        hyena_kernel_lr=None,
        bidirectional=False,
        hyena_lr_pos_emb=1e-5,
        hyena_w=10,
        hyena_w_mod=1,
        hyena_wd=0.1,
        hyena_emb_dim=3,
        hyena_filter_dropout=0.0,
        hyena_filter_order=16,
        residual_long_conv=False,
        hyena_training_additions=False,
    ):
        super().__init__()

        self.d_model = d_model
        self.l_max = l_max
        self.kernel_lr = hyena_kernel_lr
        self.channels = 1
        self.bidirectional = bidirectional
        self.residual_long_conv = residual_long_conv
        self.NUM_PROJECTIONS = 3

        print(f"-- Bidirectional: {self.bidirectional}")
        print(f"-- Using Long Conv Residual: {self.residual_long_conv}")
        print(f"-- Hyena w: {hyena_w}")
        print(f"-- Hyena w mod: {hyena_w_mod}")
        print(f"-- Hyena filter order: {hyena_filter_order}")
        print(f"-- Hyena filter dropout: {hyena_filter_dropout}")
        print(f"-- Hyena filter wd: {hyena_wd}")
        print(f"-- Hyena filter emb dim: {hyena_emb_dim}")
        print(f"-- Hyena filter lr: {hyena_kernel_lr}")
        print(f"-- Hyena filter lr pos emb: {hyena_lr_pos_emb}")

        self.filter_fn = HyenaFilter(
            self.d_model,
            order=hyena_filter_order,
            seq_len=self.l_max,
            dropout=hyena_filter_dropout,
            bidirectional=self.bidirectional,
            lr=hyena_kernel_lr,
            lr_pos_emb=hyena_lr_pos_emb,
            w=hyena_w,  # frequency of periodic activations
            w_mod=hyena_w_mod,
            wd=hyena_wd,  # weight decay of kernel parameters
            emb_dim=hyena_emb_dim,
        )
        
        if self.residual_long_conv:
            self.filter_fn2 = HyenaFilter(
                self.d_model,
                order=hyena_filter_order,
                seq_len=self.l_max,
                dropout=hyena_filter_dropout,
                bidirectional=self.bidirectional,
                lr=hyena_kernel_lr,
                lr_pos_emb=hyena_lr_pos_emb,
                w=hyena_w,  # frequency of periodic activations
                w_mod=hyena_w_mod,
                wd=hyena_wd,  # weight decay of kernel parameters
                emb_dim=hyena_emb_dim,
            )
        
        # setup projections: one linear layer produces the two gates and the
        # value stream (NUM_PROJECTIONS = 3 streams of width d_model)
        self.in_linear = nn.Linear(d_model, self.NUM_PROJECTIONS * d_model)
        self.out_linear = nn.Linear(d_model, d_model)
        self.hyena_training_additions = hyena_training_additions
        if self.hyena_training_additions:
            self.act = nn.Identity()
            self.drop = nn.Dropout(dropout)
            self.layernorm = nn.LayerNorm(d_model)
        
        # setup short conv: depthwise Conv1d over all three projected streams;
        # the output is truncated back to length L in forward()
        total_width = self.d_model * self.NUM_PROJECTIONS
        self.short_filter = nn.Conv1d(
            in_channels=total_width,
            out_channels=total_width,
            kernel_size=3,
            groups=total_width,
            padding=2,
        )


    def forward(self, u, **kwargs):
        # u: (batch, seq_len, d_model)
        if self.hyena_training_additions:
            u = self.layernorm(u)
        L = u.size(-2)

        # input projection to the three streams, then move channels first
        u_orig = u
        u = self.in_linear(u)
        u = rearrange(u, "b l d -> b d l")

        # depthwise short conv for local context, truncated back to length L
        uc = self.short_filter(u)[..., :L]

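        # split into the two gates (x1, x2) and the value stream v:
        # x1 gates v before the long convolution, x2 gates its output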
        x1, x2, v = uc.split(self.d_model, dim=1)
        
        v = v * x1
        if self.hyena_training_additions:
            v = self.drop(v)

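        # materialize the implicit long-conv filter (and its reverse when
        # bidirectional), then convolve the gated value stream with it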
        k = self.filter_fn.filter(L, device=u.device)
        k = rearrange(k, "c l d -> c d l")[0] # `c` is always 1 by default

        if self.bidirectional:
            k_rev = self.filter_fn.filter_rev(L, device=u.device)
            k_rev = rearrange(k_rev, "c l d -> c d l")[0] # `c` is always 1 by default
        else:
            k_rev = None

        y = self.filter_fn(v, L, k_fwd=k, k_rev=k_rev, bias=self.filter_fn.bias[None, :, None])

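        # optional residual branch: a second long convolution applied to the
        # layer input u_orig (no gating)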
        if self.residual_long_conv:
            k2 = self.filter_fn2.filter(L, device=u.device)
            k2 = rearrange(k2, "c l d -> c d l")[0]

            if self.bidirectional:
                k2_rev = self.filter_fn2.filter_rev(L, device=u.device)
                k2_rev = rearrange(k2_rev, "c l d -> c d l")[0] # `c` is always 1 by default
            else:
                k2_rev = None                

            yu = self.filter_fn2(u_orig.transpose(-1, -2), L, k_fwd=k2, k_rev=k2_rev, bias=self.filter_fn2.bias[None, :, None])
        
        # post gating
        y = y * x2

        if self.residual_long_conv:
            y = y + yu

        # back to (batch, seq_len, d_model), then output projection
        y = y.transpose(-1, -2)
        if self.hyena_training_additions:
            y = self.drop(self.act(y))
        y = self.out_linear(y)

        return y, None
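

if __name__ == "__main__":
    # Minimal smoke-test sketch. It assumes this file is run from within its
    # package (e.g. `python -m <package>.<this_module>`) so that the relative
    # import of HyenaFilter resolves; the package/module names are illustrative.
    import torch

    mixer = MonarchMixerSequenceMixing(d_model=64, l_max=128)
    u = torch.randn(2, 128, 64)  # (batch, seq_len, d_model)
    y, _ = mixer(u)
    print("output shape:", tuple(y.shape))  # expected: (2, 128, 64)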