# Spaces: Running on L40S
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
class Modulation(nn.Module):
    """Condition-dependent affine modulation: ``x * (1 + scale) + shift``.

    The conditioning tensor goes through an optional ``Linear`` layer, a SiLU and a
    final ``Linear`` that emits ``2 * embedding_dim`` values. Those are split along
    dim 1 into ``scale`` and ``shift``, and both are broadcast over dim 1 of ``x``.

    NOTE(review): the split/broadcast axes suggest ``condition`` is
    ``(batch, condition_dim)`` and ``x`` is ``(batch, seq, embedding_dim)`` — confirm
    against callers.

    Args:
        embedding_dim: Size of the last dimension of ``x``.
        condition_dim: Size of the conditioning features.
        zero_init: If True, the final projection starts with zero weight and bias,
            so ``scale`` and ``shift`` are zero and the module returns ``x`` unchanged
            at initialization.
        single_layer: If True, the first projection is replaced by ``nn.Identity``.
    """

    def __init__(
        self,
        embedding_dim: int,
        condition_dim: int,
        zero_init: bool = False,
        single_layer: bool = False,
    ):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear1 = (
            nn.Identity() if single_layer else nn.Linear(condition_dim, condition_dim)
        )
        self.linear2 = nn.Linear(condition_dim, embedding_dim * 2)

        if zero_init:
            # Zero only the output projection; the first layer keeps its default init.
            for param in (self.linear2.weight, self.linear2.bias):
                nn.init.zeros_(param)

    def forward(self, x: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
        """Return ``x`` scaled and shifted by parameters predicted from ``condition``."""
        modulation = self.linear2(self.silu(self.linear1(condition)))
        scale, shift = modulation.chunk(2, dim=1)
        # `[:, None]` inserts an axis at dim 1 (same as unsqueeze(1)) so the
        # parameters broadcast over dim 1 of `x`.
        return x * (1 + scale[:, None]) + shift[:, None]
class FeedForward(nn.Module):
    r"""
    A feed-forward layer.

    The layer is an `nn.ModuleList` applied in order:
    1. an input projection + activation module selected by `activation_fn`;
    2. dropout;
    3. an output `nn.Linear(inner_dim, dim_out)`;
    4. optionally, a final dropout.

    Here `inner_dim = int(dim * mult)`.

    Parameters:
        dim (`int`): The number of channels in the input.
        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in
            feed-forward. One of `"gelu"`, `"gelu-approximate"`, `"geglu"`, `"geglu-approximate"`.
        final_dropout (`bool`, *optional*, defaults to False): Apply a final dropout.

    Raises:
        ValueError: If `activation_fn` is not one of the supported names.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
        final_dropout: bool = False,
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim

        # One if/elif chain: exactly one branch builds the input projection + activation.
        if activation_fn == "gelu":
            act_fn = GELU(dim, inner_dim)
        elif activation_fn == "gelu-approximate":
            act_fn = GELU(dim, inner_dim, approximate="tanh")
        elif activation_fn == "geglu":
            act_fn = GEGLU(dim, inner_dim)
        elif activation_fn == "geglu-approximate":
            # NOTE: this name selects ApproximateGELU (sigmoid-based, not gated);
            # the mapping is kept unchanged for compatibility.
            act_fn = ApproximateGELU(dim, inner_dim)
        else:
            # Without this branch an unknown name fell through and crashed with
            # an UnboundLocalError on `act_fn`; fail fast with a clear message.
            raise ValueError(
                f"Unsupported activation_fn {activation_fn!r}; expected one of "
                "'gelu', 'gelu-approximate', 'geglu', 'geglu-approximate'."
            )

        self.net = nn.ModuleList([])
        # project in
        self.net.append(act_fn)
        # project dropout
        self.net.append(nn.Dropout(dropout))
        # project out
        self.net.append(nn.Linear(inner_dim, dim_out))
        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
        if final_dropout:
            self.net.append(nn.Dropout(dropout))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply every module in `self.net` in order and return the result."""
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states
class Attention(nn.Module): | |
def __init__( | |
self, | |
query_dim: int, | |
heads: int = 8, | |
dim_head: int = 64, | |
dropout: float = 0.0, | |
bias: bool = False, | |
out_bias: bool = True, | |
): | |
super().__init__() | |
self.inner_dim = dim_head * heads | |
self.num_heads = heads | |
self.scale = dim_head**-0.5 | |
self.dropout = dropout | |
# Linear projections | |
self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) | |
self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias) | |
self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias) | |
# Output projection | |
self.to_out = nn.ModuleList( | |
[ | |
nn.Linear(self.inner_dim, query_dim, bias=out_bias), | |
nn.Dropout(dropout), | |
] | |
) | |
def forward( | |
self, | |
hidden_states: torch.Tensor, | |
attention_mask: Optional[torch.Tensor] = None, | |
) -> torch.Tensor: | |
batch_size, sequence_length, _ = hidden_states.shape | |
# Project queries, keys, and values | |
query = self.to_q(hidden_states) | |
key = self.to_k(hidden_states) | |
value = self.to_v(hidden_states) | |
# Reshape for multi-head attention | |
query = query.reshape( | |
batch_size, sequence_length, self.num_heads, -1 | |
).transpose(1, 2) | |
key = key.reshape(batch_size, sequence_length, self.num_heads, -1).transpose( | |
1, 2 | |
) | |
value = value.reshape( | |
batch_size, sequence_length, self.num_heads, -1 | |
).transpose(1, 2) | |
# Compute scaled dot product attention | |
hidden_states = torch.nn.functional.scaled_dot_product_attention( | |
query, | |
key, | |
value, | |
attn_mask=attention_mask, | |
scale=self.scale, | |
) | |
# Reshape and project output | |
hidden_states = hidden_states.transpose(1, 2).reshape( | |
batch_size, sequence_length, self.inner_dim | |
) | |
# Apply output projection and dropout | |
for module in self.to_out: | |
hidden_states = module(hidden_states) | |
return hidden_states | |
class BasicTransformerBlock(nn.Module):
    """Pre-LayerNorm transformer block: self-attention, then feed-forward.

    Each sub-layer reads a LayerNorm of its input, and its output is added back to
    that input (a residual connection).

    Args:
        dim: Channel size of the hidden states.
        num_attention_heads: Number of heads in the self-attention layer.
        attention_head_dim: Channels per attention head.
        activation_fn: Activation name forwarded to ``FeedForward``.
        attention_bias: Whether the attention q/k/v projections have a bias.
        norm_elementwise_affine: Whether both LayerNorms have learnable affine parameters.
        norm_eps: Epsilon of both LayerNorms.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        activation_fn: str = "geglu",
        attention_bias: bool = False,
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
    ):
        super().__init__()
        norm_kwargs = dict(elementwise_affine=norm_elementwise_affine, eps=norm_eps)

        # Self-attention sub-layer.
        self.norm1 = nn.LayerNorm(dim, **norm_kwargs)
        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            bias=attention_bias,
        )

        # Feed-forward sub-layer. The attribute name `norm3` (there is no `norm2`) is
        # kept so parameter names are unchanged — presumably for checkpoint compatibility.
        self.norm3 = nn.LayerNorm(dim, **norm_kwargs)
        self.ff = FeedForward(dim, activation_fn=activation_fn)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        """Run the attention and feed-forward sub-layers, each with a residual add."""
        attn_output = self.attn1(self.norm1(hidden_states), attention_mask=attention_mask)
        hidden_states = attn_output + hidden_states
        return self.ff(self.norm3(hidden_states)) + hidden_states
class GELU(nn.Module):
    r"""
    Linear projection followed by GELU, with optional tanh approximation via `approximate="tanh"`.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
        approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
    """

    def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out)
        self.approximate = approximate

    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
        """Apply GELU to `gate`, computing in float32 on the MPS backend."""
        if gate.device.type == "mps":
            # mps: gelu is not implemented for float16, so upcast and cast back.
            upcast = F.gelu(gate.float(), approximate=self.approximate)
            return upcast.to(gate.dtype)
        return F.gelu(gate, approximate=self.approximate)

    def forward(self, hidden_states):
        """Project `hidden_states` to `dim_out` channels and apply GELU."""
        return self.gelu(self.proj(hidden_states))
class GEGLU(nn.Module):
    r"""
    A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.

    The input is projected to `2 * dim_out` channels and split in half along the last
    dimension. The output is `first_half * gelu(second_half)`.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
        """Apply exact GELU to `gate`, computing in float32 on the MPS backend."""
        if gate.device.type != "mps":
            return F.gelu(gate)
        # mps: gelu is not implemented for float16
        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)

    def forward(self, hidden_states, scale: float = 1.0):
        """Return the GELU-gated projection of `hidden_states`.

        Args:
            hidden_states: Input whose last dimension has `dim_in` channels.
            scale: Unused. It is accepted only so existing callers that pass it
                keep working.
        """
        hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
        return hidden_states * self.gelu(gate)
class ApproximateGELU(nn.Module):
    r"""
    Linear projection followed by the sigmoid approximation of GELU,
    `x * sigmoid(1.702 * x)`. For more details, see section 2 of
    https://arxiv.org/abs/1606.08415.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project `x` to `dim_out` channels and apply the sigmoid-approximated GELU."""
        projected = self.proj(x)
        gate = torch.sigmoid(1.702 * projected)
        return projected * gate