# -*- coding: utf-8 -*-
# Copyright (c) 2024, Songlin Yang, Yu Zhang

from typing import Optional, Tuple

import torch

from fla.ops.common.fused_recurrent import fused_recurrent


def fused_recurrent_gla(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    gk: Optional[torch.Tensor] = None,
    gv: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    reverse: bool = False,
    offsets: Optional[torch.LongTensor] = None,
    head_first: bool = True
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    r"""
    Gated linear attention (GLA) computed with the fused recurrent kernel.

    This is a thin wrapper that validates variable-length arguments, fills in the
    default `scale`, and forwards everything to `fused_recurrent` with only the
    key/value gates (`gk`/`gv`) set.

    Args:
        q (torch.Tensor):
            queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
        k (torch.Tensor):
            keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
        v (torch.Tensor):
            values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
        gk (Optional[torch.Tensor]):
            Forget gates of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]` applied to keys.
            The examples below pass log-space gates (`F.logsigmoid(...)`). Default: `None`.
        gv (Optional[torch.Tensor]):
            Forget gates of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]` applied to values.
            Default: `None`.
        scale (Optional[float]):
            Scale factor for the attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, H, K, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        output_final_state (bool):
            Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
        reverse (bool):
            If `True`, process the state passing in reverse order. Default: `False`.
        offsets (Optional[torch.LongTensor]):
            Offsets of shape `[N+1]` defining the bos/eos positions of `N` variable-length sequences in the batch.
            For example,
            if `offsets` is `[0, 1, 3, 6, 10, 15]`, there are `N=5` sequences with lengths 1, 2, 3, 4 and 5 respectively.
            If provided, the inputs are concatenated and the batch size `B` is expected to be 1.
            Default: `None`.
        head_first (bool):
            Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
            Default: `True`.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
        final_state (Optional[torch.Tensor]):
            Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.

    Raises:
        ValueError:
            If `offsets` is given and the batch size of `q` is not 1, or if the number of
            initial states does not match the number of sequences described by `offsets`.
        RuntimeError:
            If `offsets` is given together with `head_first=True`.

    Examples::
        >>> import torch
        >>> import torch.nn.functional as F
        >>> from einops import rearrange
        >>> from fla.ops.gla import fused_recurrent_gla
        # inputs with equal lengths
        >>> B, T, H, K, V = 4, 2048, 4, 512, 512
        >>> q = torch.randn(B, T, H, K, device='cuda')
        >>> k = torch.randn(B, T, H, K, device='cuda')
        >>> v = torch.randn(B, T, H, V, device='cuda')
        >>> g = F.logsigmoid(torch.randn(B, T, H, K, device='cuda'))
        >>> h0 = torch.randn(B, H, K, V, device='cuda')
        >>> o, ht = fused_recurrent_gla(q, k, v, g,
                                        initial_state=h0,
                                        output_final_state=True,
                                        head_first=False)
        # for variable-length inputs, the batch size `B` is expected to be 1 and `offsets` is required
        >>> q, k, v, g = map(lambda x: rearrange(x, 'b t h d -> 1 (b t) h d'), (q, k, v, g))
        # for a batch with 4 sequences, offsets with 5 start/end positions are expected
        >>> offsets = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
        >>> o_var, ht_var = fused_recurrent_gla(q, k, v, g,
                                                initial_state=h0,
                                                output_final_state=True,
                                                offsets=offsets,
                                                head_first=False)
        >>> assert o.allclose(o_var.view(o.shape))
        >>> assert ht.allclose(ht_var)
    """
    if offsets is not None:
        # Variable-length mode: all sequences are packed along the time axis of a single batch row.
        if q.shape[0] != 1:
            raise ValueError(
                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `offsets`. "
                "Please flatten variable-length inputs before processing."
            )
        if head_first:
            raise RuntimeError("Sequences with variable lengths are not supported for head-first mode")
        # `offsets` holds N+1 boundaries, so there must be exactly one initial state per sequence.
        if initial_state is not None and initial_state.shape[0] != len(offsets) - 1:
            raise ValueError(
                f"The number of initial states is expected to be equal to the number of input sequences, "
                f"i.e., {len(offsets) - 1} rather than {initial_state.shape[0]}."
            )
    if scale is None:
        # Default to 1 / sqrt(K), where K is the key head dimension.
        scale = k.shape[-1] ** -0.5
    # GLA gating is passed only through the per-key / per-value gates; `g` is left unset.
    o, final_state = fused_recurrent(
        q=q,
        k=k,
        v=v,
        g=None,
        gk=gk,
        gv=gv,
        scale=scale,
        initial_state=initial_state,
        output_final_state=output_final_state,
        reverse=reverse,
        offsets=offsets,
        head_first=head_first
    )
    return o, final_state