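# NOTE: the three items below are page-header residue from the hosting site, not part of the module:
#   uploader: zaydzuhri | commit message: "Training in progress, step 2500" | commit: 0094a2a (verified)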
# -*- coding: utf-8 -*-
# Copyright (c) 2024, Songlin Yang, Yu Zhang
from typing import Optional, Tuple
import torch
import triton
import triton.language as tl
from fla.utils import contiguous
# Forward kernel for the element-wise gated recurrence
#     h_t = exp(g_t) * h_{t-1} + x_t,    o_t = h_t
#
# Launch layout (see the indexing below): program axis 0 selects a block of `BD`
# consecutive feature channels, program axis 1 selects one sequence. Each program walks
# its sequence front to back, keeping the running state `b_h` in float32.
#
# Addressing: `x`, `g` and `o` are indexed as flat buffers with `D` elements per time
# step (time step `r` starts at element offset `r * D`); `h0` and `ht` are indexed with
# `D` elements per sequence.
#
# The autotuner picks `BD` / `num_warps` among the configs below; `key=['D']` makes the
# choice depend on the feature dimension.
@triton.autotune(
    configs=[
        triton.Config({'BD': 32}, num_warps=1),
        triton.Config({'BD': 32}, num_warps=2),
        triton.Config({'BD': 32}, num_warps=4),
        triton.Config({'BD': 32}, num_warps=8),
        triton.Config({'BD': 64}, num_warps=1),
        triton.Config({'BD': 64}, num_warps=2),
        triton.Config({'BD': 64}, num_warps=4),
        triton.Config({'BD': 64}, num_warps=8),
        triton.Config({'BD': 128}, num_warps=1),
        triton.Config({'BD': 128}, num_warps=2),
        triton.Config({'BD': 128}, num_warps=4),
        triton.Config({'BD': 128}, num_warps=8),
    ],
    key=['D']
)
# The boolean compile-time flags are derived from which optional pointers were passed.
@triton.heuristics({
    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
    'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.jit
def fused_recurrent_hgrn_fwd_kernel(
    x,
    g,
    o,
    h0,
    ht,
    offsets,
    T: tl.constexpr,
    D: tl.constexpr,
    BD: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    USE_OFFSETS: tl.constexpr
):
    # i_d: index of the feature block, i_n: index of the sequence handled by this program.
    i_d, i_n = tl.program_id(0), tl.program_id(1)
    if USE_OFFSETS:
        # Variable-length mode: sequence i_n occupies time steps [bos, eos) of the flat
        # buffer, and the loop length T is replaced by this sequence's own length.
        bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64)
        T = eos - bos
    else:
        # Fixed-length mode: sequence i_n occupies time steps [i_n * T, i_n * T + T).
        bos, eos = i_n * T, i_n * T + T

    # Feature indices covered by this program; the mask guards the tail block when
    # D is not a multiple of BD.
    o_d = i_d * BD + tl.arange(0, BD)
    mask = o_d < D

    # Pointers to the first time step of this sequence.
    p_x = x + bos * D + o_d
    p_g = g + bos * D + o_d
    p_o = o + bos * D + o_d

    # Running hidden state, zero unless an initial state is supplied.
    b_h = tl.zeros([BD], dtype=tl.float32)
    if USE_INITIAL_STATE:
        p_h0 = h0 + i_n * D + o_d
        b_h += tl.load(p_h0, mask=mask, other=0).to(tl.float32)
    for _ in range(0, T):
        b_x = tl.load(p_x, mask=mask, other=0).to(tl.float32)
        b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
        # g is applied through exp(), i.e. it is consumed as a log-space gate.
        b_h = tl.exp(b_g) * b_h + b_x
        # The state is written back in the element type of the output buffer.
        tl.store(p_o, b_h.to(p_o.dtype.element_ty), mask=mask)

        # Advance all pointers by one time step.
        p_x += D
        p_g += D
        p_o += D

    if STORE_FINAL_STATE:
        # Persist the state reached after the last time step of this sequence.
        p_ht = ht + i_n * D + o_d
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask)
# Backward kernel for the recurrence h_t = exp(g_t) * h_{t-1} + x_t (with o_t = h_t).
#
# Same launch layout as the forward kernel: program axis 0 selects a block of `BD`
# feature channels, program axis 1 selects one sequence. Each program walks its sequence
# from the last time step to the first, carrying the gradient w.r.t. the hidden state in
# `b_dh` (float32) and writing, per time step:
#     dx_t = dh_t
#     dg_t = dh_t * exp(g_t) * h_{t-1}
# where dh_t = do_t + (gradient carried over from step t + 1).
@triton.autotune(
    configs=[
        triton.Config({'BD': 32}, num_warps=1),
        triton.Config({'BD': 32}, num_warps=2),
        triton.Config({'BD': 32}, num_warps=4),
        triton.Config({'BD': 32}, num_warps=8),
        triton.Config({'BD': 64}, num_warps=1),
        triton.Config({'BD': 64}, num_warps=2),
        triton.Config({'BD': 64}, num_warps=4),
        triton.Config({'BD': 64}, num_warps=8),
        triton.Config({'BD': 128}, num_warps=1),
        triton.Config({'BD': 128}, num_warps=2),
        triton.Config({'BD': 128}, num_warps=4),
        triton.Config({'BD': 128}, num_warps=8),
    ],
    key=['D']
)
# The boolean compile-time flags are derived from which optional pointers were passed.
@triton.heuristics({
    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
    'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.jit
def fused_recurrent_hgrn_bwd_kernel(
    g,
    o,
    h0,
    dx,
    dg,
    do,
    dht,
    dh0,
    offsets,
    T: tl.constexpr,
    D: tl.constexpr,
    BD: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    USE_FINAL_STATE_GRADIENT: tl.constexpr,
    USE_OFFSETS: tl.constexpr
):
    # i_d: index of the feature block, i_n: index of the sequence handled by this program.
    i_d, i_n = tl.program_id(0), tl.program_id(1)
    if USE_OFFSETS:
        # Variable-length mode: sequence i_n occupies time steps [bos, eos) of the flat
        # buffer, and the loop length T is replaced by this sequence's own length.
        bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64)
        T = eos - bos
    else:
        # Fixed-length mode: sequence i_n occupies time steps [i_n * T, i_n * T + T).
        bos, eos = i_n * T, i_n * T + T

    # Feature indices covered by this program; the mask guards the tail block.
    o_d = i_d * BD + tl.arange(0, BD)
    mask = o_d < D

    # All pointers start at the LAST time step of the sequence, except `p_o`, which starts
    # one step earlier: at step t it supplies the previous state h_{t-1} = o_{t-1}.
    p_g = g + (bos + T - 1) * D + o_d
    p_o = o + (bos + T - 2) * D + o_d
    p_dx = dx + (bos + T - 1) * D + o_d
    p_dg = dg + (bos + T - 1) * D + o_d
    p_do = do + (bos + T - 1) * D + o_d

    # Gradient flowing into the hidden state from "the future"; seeded with the gradient
    # of the final state when one is provided.
    b_dh = tl.zeros([BD], dtype=tl.float32)
    if USE_FINAL_STATE_GRADIENT:
        p_dht = dht + i_n * D + o_d
        b_dh += tl.load(p_dht, mask=mask, other=0).to(tl.float32)

    for i in range(T - 1, -1, -1):
        b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
        b_do = tl.load(p_do, mask=mask, other=0).to(tl.float32)
        # b_o holds the state preceding step i: the previous output for i > 0, otherwise
        # the initial state (or zeros when there is none). `p_o` is never dereferenced at
        # i == 0, where it points one step before the start of the sequence.
        if i > 0:
            b_o = tl.load(p_o, mask=mask, other=0).to(tl.float32)
        elif USE_INITIAL_STATE:
            b_o = tl.load(h0 + i_n * D + o_d, mask=mask, other=0).to(tl.float32)
        else:
            b_o = tl.zeros([BD], dtype=tl.float32)

        # Total gradient w.r.t. h_i, which is also the gradient w.r.t. x_i.
        b_dh = b_dh + b_do
        b_dx = b_dh
        # Propagate through the gate: this is the gradient w.r.t. h_{i-1} ...
        b_dh = b_dh * tl.exp(b_g)
        # ... and multiplying by h_{i-1} gives the gradient w.r.t. g_i.
        b_dg = b_dh * b_o
        tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), mask=mask)
        tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), mask=mask)

        # Step all pointers back by one time step.
        p_g -= D
        p_o -= D
        p_dx -= D
        p_dg -= D
        p_do -= D

    if USE_INITIAL_STATE:
        # After the loop `b_dh` is the gradient w.r.t. the initial state.
        p_dh0 = dh0 + i_n * D + o_d
        tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), mask=mask)
def fused_recurrent_hgrn_fwd(
    x: torch.Tensor,
    g: torch.Tensor,
    initial_state: torch.Tensor = None,
    output_final_state: bool = False,
    offsets: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Allocate the outputs and launch the forward HGRN kernel.

    Args:
        x: Inputs; unpacked as a 3-D tensor `[B, T, D]`.
        g: Gates, passed to the kernel unchanged.
        initial_state: Optional initial hidden state, forwarded as `h0`.
        output_final_state: Whether to allocate and fill a final-state buffer.
        offsets: Optional sequence boundaries; when given, the number of
            sequences is `len(offsets) - 1` instead of `B`.

    Returns:
        `(o, final_state)` where `o` has the shape and dtype of `x`, and
        `final_state` is an `[N, D]` tensor created from `x`, or `None`
        when `output_final_state` is false.
    """
    batch_size, seq_len, hidden_dim = x.shape
    if offsets is None:
        num_seqs = batch_size
    else:
        num_seqs = len(offsets) - 1

    # Both buffers are uninitialised; the kernel is responsible for filling them.
    out = torch.empty_like(x)
    last_state = None
    if output_final_state:
        last_state = x.new_empty(num_seqs, hidden_dim)

    # One program per (feature block, sequence); BD is chosen by the autotuner.
    launch = fused_recurrent_hgrn_fwd_kernel[
        lambda meta: (triton.cdiv(hidden_dim, meta['BD']), num_seqs)
    ]
    launch(
        x=x,
        g=g,
        o=out,
        h0=initial_state,
        ht=last_state,
        offsets=offsets,
        T=seq_len,
        D=hidden_dim
    )
    return out, last_state
def fused_recurrent_hgrn_bwd(
    g: torch.Tensor,
    o: torch.Tensor,
    do: torch.Tensor,
    dht: Optional[torch.Tensor] = None,
    initial_state: Optional[torch.Tensor] = None,
    offsets: Optional[torch.LongTensor] = None
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Allocate the gradient buffers and launch the backward HGRN kernel.

    Args:
        g: Gates used in the forward pass.
        o: Outputs of the forward pass (the hidden states).
        do: Gradient w.r.t. `o`; unpacked as a 3-D tensor `[B, T, D]`.
        dht: Optional gradient w.r.t. the final state.
        initial_state: Optional initial hidden state used in the forward pass.
        offsets: Optional sequence boundaries; when given, the number of
            sequences is `len(offsets) - 1` instead of `B`.

    Returns:
        `(dx, dg, dh0)`. `dx` and `dg` are float32 tensors shaped like `o` and
        `g`. `dh0` is a float32 tensor shaped like `initial_state`, or `None`
        when no initial state was given.
    """
    B, T, D = do.shape
    N = B if offsets is None else len(offsets) - 1

    # Gradients are always materialised in float32, whatever the input dtype.
    dx = torch.empty_like(o, dtype=torch.float)
    dg = torch.empty_like(g, dtype=torch.float)
    dh0 = None
    if initial_state is not None:
        dh0 = torch.empty_like(initial_state, dtype=torch.float)

    # One program per (feature block, sequence); BD is chosen by the autotuner.
    def grid(meta):
        return (triton.cdiv(D, meta['BD']), N)

    fused_recurrent_hgrn_bwd_kernel[grid](
        g=g,
        o=o,
        h0=initial_state,
        dx=dx,
        dg=dg,
        do=do,
        dht=dht,
        dh0=dh0,
        offsets=offsets,
        T=T,
        D=D
    )
    return dx, dg, dh0
class FusedRecurrentHGRNFunction(torch.autograd.Function):
    """Autograd binding around the fused recurrent HGRN forward/backward launchers."""

    @staticmethod
    @contiguous
    def forward(
        ctx,
        x: torch.Tensor,
        g: torch.Tensor,
        initial_state: torch.Tensor = None,
        output_final_state: bool = False,
        offsets: Optional[torch.LongTensor] = None
    ):
        """Run the forward launcher and stash what the backward pass needs.

        Returns the `(o, ht)` pair produced by `fused_recurrent_hgrn_fwd`.
        """
        outputs, final_state = fused_recurrent_hgrn_fwd(
            x=x, g=g, initial_state=initial_state,
            output_final_state=output_final_state, offsets=offsets,
        )
        # The backward pass needs the gates, the hidden states and the initial state;
        # `offsets` is kept as a plain attribute on the context.
        ctx.save_for_backward(g, outputs, initial_state)
        ctx.offsets = offsets
        return outputs, final_state

    @staticmethod
    @contiguous
    def backward(ctx, do, dht=None):
        """Compute gradients for `x`, `g` and `initial_state`.

        The two trailing `None`s correspond to `output_final_state` and `offsets`,
        which receive no gradient.
        """
        gates, outputs, h0 = ctx.saved_tensors
        grads = fused_recurrent_hgrn_bwd(
            g=gates, o=outputs, do=do, dht=dht,
            initial_state=h0, offsets=ctx.offsets,
        )
        return (*grads, None, None)
def fused_recurrent_hgrn(
    x: torch.Tensor,
    g: torch.Tensor,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    offsets: Optional[torch.LongTensor] = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    r"""
    Fused recurrent HGRN: computes `h_t = exp(g_t) * h_{t-1} + x_t` along the time axis.

    Args:
        x (torch.Tensor):
            inputs of shape `[B, T, D]`.
        g (torch.Tensor):
            Forget gates of shape `[B, T, D]`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, D]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `[N, D]`. Default: `False`.
        offsets (Optional[torch.LongTensor]):
            Offsets of shape `[N+1]` defining the bos/eos positions of `N` variable-length sequences in the batch.
            For example,
            if `offsets` is `[0, 1, 3, 6, 10, 15]`, there are `N=5` sequences with lengths 1, 2, 3, 4 and 5 respectively.
            If provided, the inputs are concatenated and the batch size `B` is expected to be 1.
            Default: `None`.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, D]`.
        final_state (torch.Tensor):
            Final state of shape `[N, D]` if `output_final_state=True` else `None`.

    Raises:
        ValueError:
            If `x` and `g` differ in shape, if `offsets` is given while the batch size is not 1,
            or if `initial_state` does not hold exactly one row per input sequence.

    Examples::
        >>> import torch
        >>> import torch.nn.functional as F
        >>> from einops import rearrange
        >>> from fla.ops.hgrn import fused_recurrent_hgrn
        # inputs with equal lengths
        >>> B, T, D = 4, 2048, 512
        >>> x = torch.randn(B, T, D, device='cuda')
        >>> g = F.logsigmoid(torch.randn(B, T, D, device='cuda'))
        >>> h0 = torch.randn(B, D, device='cuda')
        >>> o, ht = fused_recurrent_hgrn(x, g, initial_state=h0, output_final_state=True)
        # for variable-length inputs, the batch size `B` is expected to be 1 and `offsets` is required
        >>> x, g = map(lambda x: rearrange(x, 'b t d -> 1 (b t) d'), (x, g))
        # for a batch with 4 sequences, offsets with 5 start/end positions are expected
        >>> offsets = x.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
        >>> o_var, ht_var = fused_recurrent_hgrn(x, g, initial_state=h0, output_final_state=True, offsets=offsets)
        >>> assert o.allclose(o_var.view(o.shape))
        >>> assert ht.allclose(ht_var)
    """
    # The kernels index `x` and `g` with the same flat offsets and `initial_state` with one
    # row per sequence, without any bounds checks. Reject inconsistent inputs here instead
    # of letting the GPU read out of bounds.
    if x.shape != g.shape:
        raise ValueError(
            f"`x` and `g` are expected to have the same shape, "
            f"but got {tuple(x.shape)} and {tuple(g.shape)}."
        )
    if offsets is not None:
        if x.shape[0] != 1:
            raise ValueError(
                f"The batch size is expected to be 1 rather than {x.shape[0]} when using `offsets`. "
                f"Please flatten variable-length inputs before processing."
            )
        num_seqs = len(offsets) - 1
    else:
        num_seqs = x.shape[0]
    if initial_state is not None and initial_state.shape[0] != num_seqs:
        raise ValueError(
            f"The number of initial states is expected to be equal to the number of input sequences, "
            f"i.e., {num_seqs} rather than {initial_state.shape[0]}."
        )
    return FusedRecurrentHGRNFunction.apply(
        x,
        g,
        initial_state,
        output_final_state,
        offsets
    )