import torch 
from torch import Tensor, nn
from transformers import PreTrainedModel
from .config import AdapterConfig 


class Model(nn.Module):
    """Strided 1-D convolutional encoder paired with a transposed-conv decoder.

    ``encode`` maps ``num_channels`` input channels to ``num_filters`` feature
    channels with a bias-free ``nn.Conv1d`` using reflect padding, then applies
    ``tanh``. ``decode`` is a bias-free ``nn.ConvTranspose1d`` sharing the same
    kernel size, stride and padding, mapping ``num_filters`` channels back to
    ``num_channels``. ``decode`` is a plain linear layer and does not undo the
    ``tanh`` applied by ``encode``.

    For example, with ``window_length=128`` and ``stride=64`` (padding 32), an
    input whose length is a multiple of 64 is encoded to ``length / 64`` frames
    and decoded back to the original length.

    Args:
        num_channels: Channel count of the signal given to ``encode``.
        num_filters: Channel count of the encoded representation.
        window_length: Kernel size of both convolutions.
        stride: Hop between successive kernel applications in both convolutions.
    """

    def __init__(
        self,
        num_channels: int,
        num_filters: int,
        window_length: int,
        stride: int,
    ):
        super().__init__()
        # Public attribute; not read anywhere inside this class.
        self.stride = stride

        # Note: this is deliberately NOT (window_length - stride) // 2, which
        # differs when either value is odd.
        pad = window_length // 2 - stride // 2
        shared = dict(
            kernel_size=window_length,
            stride=stride,
            padding=pad,
            bias=False,
        )

        # Build ``conv`` before ``decode``: construction order decides both the
        # state_dict key order and which random draws each weight receives.
        # Reflect padding needs the input to be longer than ``pad`` samples.
        self.conv = nn.Conv1d(
            in_channels=num_channels,
            out_channels=num_filters,
            padding_mode="reflect",
            **shared,
        )
        # ConvTranspose1d only supports zero padding, so no padding_mode here.
        self.decode = nn.ConvTranspose1d(
            in_channels=num_filters,
            out_channels=num_channels,
            **shared,
        )

    def encode(self, x: Tensor) -> Tensor:
        """Convolve ``x`` with the analysis filters and squash with ``tanh``.

        The result therefore lies in ``[-1, 1]``.
        """
        features = self.conv(x)
        return features.tanh()


class Adapter(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around a fixed-size ``Model``.

    ``encode`` and ``decode`` forward their input unchanged to the wrapped
    ``Model`` instance stored as ``self.model``.
    """

    config_class = AdapterConfig

    def __init__(self, config: AdapterConfig):
        super().__init__(config)

        # NOTE(review): the architecture below is hard-coded; ``config`` is only
        # handed to ``PreTrainedModel.__init__`` and none of these values are
        # read from it. Confirm this is intended before sourcing them from the
        # config, since changing any of them alters checkpoint weight shapes.
        architecture = {
            "num_channels": 2,
            "num_filters": 128,
            "window_length": 128,
            "stride": 64,
        }
        self.model = Model(**architecture)

    def encode(self, x: Tensor) -> Tensor:
        """Return ``self.model.encode(x)``."""
        inner = self.model
        return inner.encode(x)

    def decode(self, x: Tensor) -> Tensor:
        """Return ``self.model.decode(x)``, the wrapped transposed convolution."""
        inner = self.model
        return inner.decode(x)