import torch
from torch import nn

class ImprovedGRUModel(nn.Module):
    """Stacked (optionally bidirectional) GRU followed by a two-layer MLP head.

    The MLP head is applied independently at every timestep of the GRU output,
    so the model produces one prediction vector per input step. The GRU's
    final hidden state is discarded.

    Args:
        input_size: Number of features per timestep fed to the GRU.
        hidden_size: GRU hidden size per direction; also the width of the
            intermediate fully connected layer.
        output_size: Number of features produced per timestep.
        num_layers: Number of stacked GRU layers.
        bidirectional: If True, run the GRU in both directions and concatenate
            the two directions' features before the head.
        dropout_rate: Dropout probability used between stacked GRU layers
            (only when ``num_layers > 1``) and after the first FC layer.

    Shapes:
        The GRU is built with ``batch_first=True``, so a batched input is
        ``(batch, seq_len, input_size)`` and the output is
        ``(batch, seq_len, output_size)``.
    """

    def __init__(self,
                 input_size=1080,
                 hidden_size=240,
                 output_size=24,
                 num_layers=2,
                 bidirectional=True,
                 dropout_rate=0.1):
        super().__init__()

        directions = 2 if bidirectional else 1
        # nn.GRU only applies dropout between stacked layers; with a single
        # layer a non-zero value has no effect (and PyTorch warns), so force 0.
        between_layer_dropout = 0 if num_layers <= 1 else dropout_rate
        # Features per timestep coming out of the GRU (directions concatenated).
        gru_feature_width = hidden_size * directions

        self.hidden_size = hidden_size
        self.num_directions = directions

        # Submodules are created and registered in this exact order
        # (gru, fc1, dropout, fc2) so parameter initialization under a fixed
        # seed and the state_dict key layout stay stable.
        self.gru = nn.GRU(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=between_layer_dropout,
            bidirectional=bidirectional,
        )
        self.fc1 = nn.Linear(gru_feature_width, hidden_size)
        self.dropout = nn.Dropout(dropout_rate)
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Map an input sequence to per-timestep outputs.

        Args:
            x: Input sequence tensor, batch dimension first when batched.

        Returns:
            Tensor with the same leading (batch/sequence) dimensions as the
            GRU output and a last dimension of size ``output_size``.
        """
        # nn.GRU returns (outputs, h_n); only the per-step outputs are used.
        per_step_features = self.gru(x)[0]
        hidden = torch.relu(self.fc1(per_step_features))
        return self.fc2(self.dropout(hidden))