from typing import Optional, Tuple

import torch
import triton
import triton.language as tl

from fla.utils import contiguous
|
|
@triton.autotune(
    configs=[
        triton.Config({'BD': 32}, num_warps=1),
        triton.Config({'BD': 32}, num_warps=2),
        triton.Config({'BD': 32}, num_warps=4),
        triton.Config({'BD': 32}, num_warps=8),
        triton.Config({'BD': 64}, num_warps=1),
        triton.Config({'BD': 64}, num_warps=2),
        triton.Config({'BD': 64}, num_warps=4),
        triton.Config({'BD': 64}, num_warps=8),
        triton.Config({'BD': 128}, num_warps=1),
        triton.Config({'BD': 128}, num_warps=2),
        triton.Config({'BD': 128}, num_warps=4),
        triton.Config({'BD': 128}, num_warps=8),
    ],
    key=['D']
)
@triton.heuristics({
    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
    'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.jit
def fused_recurrent_hgrn_fwd_kernel(
    x,
    g,
    o,
    h0,
    ht,
    offsets,
    T: tl.constexpr,
    D: tl.constexpr,
    BD: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    USE_OFFSETS: tl.constexpr
):
    i_d, i_n = tl.program_id(0), tl.program_id(1)
    if USE_OFFSETS:
        # variable-length mode: each sequence has its own [bos, eos) range
        bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64)
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T

    o_d = i_d * BD + tl.arange(0, BD)
    mask = o_d < D

    p_x = x + bos * D + o_d
    p_g = g + bos * D + o_d
    p_o = o + bos * D + o_d

    b_h = tl.zeros([BD], dtype=tl.float32)
    if USE_INITIAL_STATE:
        p_h0 = h0 + i_n * D + o_d
        b_h += tl.load(p_h0, mask=mask, other=0).to(tl.float32)
    for _ in range(0, T):
        b_x = tl.load(p_x, mask=mask, other=0).to(tl.float32)
        b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
        # HGRN recurrence: h_t = exp(g_t) * h_{t-1} + x_t
        b_h = tl.exp(b_g) * b_h + b_x
        tl.store(p_o, b_h.to(p_o.dtype.element_ty), mask=mask)

        # advance all pointers to the next timestep
        p_x += D
        p_g += D
        p_o += D

    if STORE_FINAL_STATE:
        p_ht = ht + i_n * D + o_d
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask)
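
# An illustrative, unoptimized PyTorch reference for the recurrence computed by
# the kernel above: h_t = exp(g_t) * h_{t-1} + x_t with o_t = h_t.
# The name `naive_recurrent_hgrn_ref` is a hypothetical helper used only to
# document the math and to ease testing; it is not part of the fla API.
def naive_recurrent_hgrn_ref(
    x: torch.Tensor,
    g: torch.Tensor,
    initial_state: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    B, T, D = x.shape
    # accumulate in fp32, mirroring the kernel
    h = x.new_zeros(B, D, dtype=torch.float) if initial_state is None else initial_state.float()
    o = []
    for t in range(T):
        h = g[:, t].float().exp() * h + x[:, t].float()
        o.append(h.to(x.dtype))
    return torch.stack(o, dim=1), h.to(x.dtype)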
|
|
@triton.autotune(
    configs=[
        triton.Config({'BD': 32}, num_warps=1),
        triton.Config({'BD': 32}, num_warps=2),
        triton.Config({'BD': 32}, num_warps=4),
        triton.Config({'BD': 32}, num_warps=8),
        triton.Config({'BD': 64}, num_warps=1),
        triton.Config({'BD': 64}, num_warps=2),
        triton.Config({'BD': 64}, num_warps=4),
        triton.Config({'BD': 64}, num_warps=8),
        triton.Config({'BD': 128}, num_warps=1),
        triton.Config({'BD': 128}, num_warps=2),
        triton.Config({'BD': 128}, num_warps=4),
        triton.Config({'BD': 128}, num_warps=8),
    ],
    key=['D']
)
@triton.heuristics({
    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
    'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None,
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.jit
def fused_recurrent_hgrn_bwd_kernel(
    g,
    o,
    h0,
    dx,
    dg,
    do,
    dht,
    dh0,
    offsets,
    T: tl.constexpr,
    D: tl.constexpr,
    BD: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    USE_FINAL_STATE_GRADIENT: tl.constexpr,
    USE_OFFSETS: tl.constexpr
):
    i_d, i_n = tl.program_id(0), tl.program_id(1)
    if USE_OFFSETS:
        bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64)
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T

    o_d = i_d * BD + tl.arange(0, BD)
    mask = o_d < D

    # pointers start at the last timestep; p_o points to o_{t-1}, i.e. h_{t-1}
    p_g = g + (bos + T - 1) * D + o_d
    p_o = o + (bos + T - 2) * D + o_d
    p_dx = dx + (bos + T - 1) * D + o_d
    p_dg = dg + (bos + T - 1) * D + o_d
    p_do = do + (bos + T - 1) * D + o_d

    b_dh = tl.zeros([BD], dtype=tl.float32)
    if USE_FINAL_STATE_GRADIENT:
        p_dht = dht + i_n * D + o_d
        b_dh += tl.load(p_dht, mask=mask, other=0).to(tl.float32)

    for i in range(T - 1, -1, -1):
        b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
        b_do = tl.load(p_do, mask=mask, other=0).to(tl.float32)
        # b_o holds h_{t-1}: the previous output, the initial state at t = 0,
        # or zeros if no initial state was given
        if i > 0:
            b_o = tl.load(p_o, mask=mask, other=0).to(tl.float32)
        elif USE_INITIAL_STATE:
            b_o = tl.load(h0 + i_n * D + o_d, mask=mask, other=0).to(tl.float32)
        else:
            b_o = tl.zeros([BD], dtype=tl.float32)

        # given h_t = exp(g_t) * h_{t-1} + x_t:
        #   dx_t = dh_t, dg_t = dh_t * exp(g_t) * h_{t-1},
        #   and dh_{t-1} accumulates dh_t * exp(g_t)
        b_dh = b_dh + b_do
        b_dx = b_dh
        b_dh = b_dh * tl.exp(b_g)
        b_dg = b_dh * b_o
        tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), mask=mask)
        tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), mask=mask)

        p_g -= D
        p_o -= D
        p_dx -= D
        p_dg -= D
        p_do -= D

    if USE_INITIAL_STATE:
        p_dh0 = dh0 + i_n * D + o_d
        tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), mask=mask)
|
|
def fused_recurrent_hgrn_fwd(
    x: torch.Tensor,
    g: torch.Tensor,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    offsets: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    B, T, D = x.shape
    N = B if offsets is None else len(offsets) - 1

    o = torch.empty_like(x)
    final_state = x.new_empty(N, D) if output_final_state else None

    def grid(meta): return (triton.cdiv(D, meta['BD']), N)
    fused_recurrent_hgrn_fwd_kernel[grid](
        x=x,
        g=g,
        o=o,
        h0=initial_state,
        ht=final_state,
        offsets=offsets,
        T=T,
        D=D
    )
    return o, final_state
|
|
def fused_recurrent_hgrn_bwd(
    g: torch.Tensor,
    o: torch.Tensor,
    do: torch.Tensor,
    dht: Optional[torch.Tensor] = None,
    initial_state: Optional[torch.Tensor] = None,
    offsets: Optional[torch.LongTensor] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    B, T, D = do.shape
    N = B if offsets is None else len(offsets) - 1

    dx = torch.empty_like(o, dtype=torch.float)
    dg = torch.empty_like(g, dtype=torch.float)
    dh0 = torch.empty_like(initial_state, dtype=torch.float) if initial_state is not None else None
    def grid(meta): return (triton.cdiv(D, meta['BD']), N)
    fused_recurrent_hgrn_bwd_kernel[grid](
        g=g,
        o=o,
        h0=initial_state,
        dx=dx,
        dg=dg,
        do=do,
        dht=dht,
        dh0=dh0,
        offsets=offsets,
        T=T,
        D=D
    )
    return dx, dg, dh0
|
|
class FusedRecurrentHGRNFunction(torch.autograd.Function):

    @staticmethod
    @contiguous
    def forward(
        ctx,
        x: torch.Tensor,
        g: torch.Tensor,
        initial_state: Optional[torch.Tensor] = None,
        output_final_state: bool = False,
        offsets: Optional[torch.LongTensor] = None
    ):
        o, ht = fused_recurrent_hgrn_fwd(
            x=x,
            g=g,
            initial_state=initial_state,
            output_final_state=output_final_state,
            offsets=offsets
        )
        ctx.save_for_backward(g, o, initial_state)
        ctx.offsets = offsets
        return o, ht

    @staticmethod
    @contiguous
    def backward(ctx, do, dht=None):
        g, o, initial_state = ctx.saved_tensors
        offsets = ctx.offsets

        dx, dg, dh0 = fused_recurrent_hgrn_bwd(
            g=g,
            o=o,
            do=do,
            dht=dht,
            initial_state=initial_state,
            offsets=offsets
        )
        return dx, dg, dh0, None, None
|
|
def fused_recurrent_hgrn(
    x: torch.Tensor,
    g: torch.Tensor,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    offsets: Optional[torch.LongTensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Args:
        x (torch.Tensor):
            Inputs of shape `[B, T, D]`.
        g (torch.Tensor):
            Forget gates of shape `[B, T, D]`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, D]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `[N, D]`. Default: `False`.
        offsets (Optional[torch.LongTensor]):
            Offsets of shape `[N+1]` defining the bos/eos positions of `N` variable-length sequences in the batch.
            For example,
            if `offsets` is `[0, 1, 3, 6, 10, 15]`, there are `N=5` sequences with lengths 1, 2, 3, 4 and 5 respectively.
            If provided, the inputs are concatenated and the batch size `B` is expected to be 1.
            Default: `None`.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, D]`.
        final_state (torch.Tensor):
            Final state of shape `[N, D]` if `output_final_state=True` else `None`.

    Examples::
        >>> import torch
        >>> import torch.nn.functional as F
        >>> from einops import rearrange
        >>> from fla.ops.hgrn import fused_recurrent_hgrn
        # inputs with equal lengths
        >>> B, T, D = 4, 2048, 512
        >>> x = torch.randn(B, T, D, device='cuda')
        >>> g = F.logsigmoid(torch.randn(B, T, D, device='cuda'))
        >>> h0 = torch.randn(B, D, device='cuda')
        >>> o, ht = fused_recurrent_hgrn(x, g, initial_state=h0, output_final_state=True)
        # for variable-length inputs, the batch size `B` is expected to be 1 and `offsets` is required
        >>> x, g = map(lambda x: rearrange(x, 'b t d -> 1 (b t) d'), (x, g))
        # for a batch with 4 sequences, offsets with 5 start/end positions are expected
        >>> offsets = x.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
        >>> o_var, ht_var = fused_recurrent_hgrn(x, g, initial_state=h0, output_final_state=True, offsets=offsets)
        >>> assert o.allclose(o_var.view(o.shape))
        >>> assert ht.allclose(ht_var)
    """
    return FusedRecurrentHGRNFunction.apply(
        x,
        g,
        initial_state,
        output_final_state,
        offsets
    )
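
# Optional sanity-check sketch (illustrative, not part of the library): it compares
# the fused Triton path against autograd through a plain PyTorch loop over the same
# recurrence h_t = exp(g_t) * h_{t-1} + x_t. It assumes a CUDA device is available
# and fp32 inputs; the tolerances below are illustrative.
if __name__ == '__main__':
    import torch.nn.functional as F

    B, T, D = 2, 64, 100
    x = torch.randn(B, T, D, device='cuda', requires_grad=True)
    g = F.logsigmoid(torch.randn(B, T, D, device='cuda')).requires_grad_()
    h0 = torch.randn(B, D, device='cuda', requires_grad=True)

    # fused path
    o, ht = fused_recurrent_hgrn(x, g, initial_state=h0, output_final_state=True)
    (o.square().sum() + ht.square().sum()).backward()
    dx, dg, dh0 = x.grad, g.grad, h0.grad
    x.grad = g.grad = h0.grad = None

    # reference path: a plain loop over timesteps
    h, o_ref = h0, []
    for t in range(T):
        h = g[:, t].exp() * h + x[:, t]
        o_ref.append(h)
    o_ref = torch.stack(o_ref, dim=1)
    (o_ref.square().sum() + h.square().sum()).backward()

    for a, b in [(o, o_ref), (ht, h), (dx, x.grad), (dg, g.grad), (dh0, h0.grad)]:
        assert torch.allclose(a, b, rtol=1e-4, atol=1e-4), (a - b).abs().max()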
|
|