""" op.py """ |
|
import math |
|
from packaging.version import parse as VersionParse |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from einops import rearrange |
|
from transformers.models.t5.modeling_t5 import T5LayerNorm as RMSNorm |


def get_layer_norm(dim: int, layer_norm_type: str = "layer_norm", layer_norm_eps: float = 1e-5):
    """Get layer normalization layer.

    Args:
        dim (int): Feature dimension
        layer_norm_type (str): "layer_norm" or "rms_norm"
        layer_norm_eps (float): Epsilon value for numerical stability

    Returns:
        nn.Module: Layer normalization layer
    """
    if layer_norm_type == "rms_norm":
        # T5LayerNorm scales by the root mean square only (no mean subtraction
        # or bias), i.e. it is an RMSNorm, hence the import alias.
        return RMSNorm(hidden_size=dim, eps=layer_norm_eps)
    else:
        return nn.LayerNorm(normalized_shape=dim, eps=layer_norm_eps)
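

# Usage sketch (illustrative addition, not part of the original module):
# build both normalization variants and apply them to a hypothetical
# (B, T, F) feature tensor.
def _demo_get_layer_norm():
    x = torch.randn(2, 10, 64)
    ln = get_layer_norm(64, layer_norm_type="layer_norm")
    rms = get_layer_norm(64, layer_norm_type="rms_norm")
    return ln(x), rms(x)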


def check_all_elements_equal(x: torch.Tensor) -> bool:
    """Return True if every element of x equals the first element.

    Note: for tensors with more than one dimension, x[0] is the first slice
    along dim 0, so this checks that every slice equals the first slice.
    """
    return x.eq(x[0]).all().item()
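

# Usage sketch (illustrative addition): the 1-D case this helper covers
# most directly.
def _demo_check_all_elements_equal():
    assert check_all_elements_equal(torch.tensor([3, 3, 3]))
    assert not check_all_elements_equal(torch.tensor([3, 1, 3]))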


def minmax_normalize(x: torch.Tensor, eps: float = 0.008) -> torch.FloatTensor:
    """Min-max normalization:

    x_norm = (x - x_min) / (x_max - x_min + eps)

    Args:
        x (torch.Tensor): (B, T, F)

    Returns:
        torch.Tensor: (B, T, F) with values in [0, 1] (the maximum falls
            slightly below 1 because of eps)
    """
    # Per-batch-item statistics over all (T, F) elements. The flattening
    # order does not affect min/max, so the same pattern is used for both.
    x_max = rearrange(x, "b t f -> b (t f)").max(1, keepdim=True)[0]
    x_min = rearrange(x, "b t f -> b (t f)").min(1, keepdim=True)[0]
    # (B, 1) -> (B, 1, 1) so the statistics broadcast against (B, T, F).
    # eps also guards against division by zero for constant inputs.
    x_max = x_max[:, None, :]
    x_min = x_min[:, None, :]
    return (x - x_min) / (x_max - x_min + eps)
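

# Usage sketch (illustrative addition): normalize a hypothetical batch of
# spectrogram-like features and check the advertised output range.
def _demo_minmax_normalize():
    x = torch.randn(2, 100, 80)  # hypothetical (B, T, F) shape
    y = minmax_normalize(x)
    assert y.min().item() >= 0.0 and y.max().item() <= 1.0
    return y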


def count_parameters(model):
    """Return (num_trainable_params, num_params) for an nn.Module."""
    num_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    num_params = sum(p.numel() for p in model.parameters())
    return num_trainable_params, num_params
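

# Usage sketch (illustrative addition): nn.Linear(10, 4) has 10 * 4 weights
# plus 4 biases = 44 parameters, all trainable by default.
def _demo_count_parameters():
    trainable, total = count_parameters(nn.Linear(10, 4))
    assert (trainable, total) == (44, 44)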


def adjust_b_to_gcd(a, b, min_gcd=16):
    """
    Adjust the value of b to ensure the GCD(a, b) is at least min_gcd with minimum change to b.

    Parameters:
    - a (int): A positive integer
    - b (int): A positive integer
    - min_gcd (int): The minimum desired GCD

    Returns:
    - int: The adjusted value of b
    """
    current_gcd = math.gcd(a, b)

    # Nothing to do if b already satisfies the constraint.
    if current_gcd >= min_gcd:
        return b

    if a < min_gcd:
        raise ValueError("a must be at least as large as min_gcd.")

    # Search outward from b in both directions, preferring the smaller change
    # (upward wins ties). The upward search always terminates: the next
    # multiple of a above b has gcd(a, .) == a >= min_gcd.
    adjusted_b_up = b
    adjusted_b_down = b

    while True:
        adjusted_b_up += 1
        adjusted_b_down -= 1

        if math.gcd(a, adjusted_b_up) >= min_gcd:
            return adjusted_b_up
        # Only consider positive candidates downward: gcd(a, 0) == a would
        # otherwise spuriously pass the check and return a non-positive b.
        elif adjusted_b_down > 0 and math.gcd(a, adjusted_b_down) >= min_gcd:
            return adjusted_b_down
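

# Worked example (illustrative addition): with a = 512, b = 100, min_gcd = 16,
# gcd(512, 100) = 4 < 16. Searching outward from 100, the first candidate with
# gcd >= 16 is 96 (gcd(512, 96) = 32), four steps down versus twelve steps up
# to 112, so 96 is returned.
def _demo_adjust_b_to_gcd():
    assert adjust_b_to_gcd(512, 100) == 96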


def optional_compiler_disable(func):
    if VersionParse(torch.__version__) >= VersionParse("2.1"):
        # torch.compiler.disable requires PyTorch >= 2.1 (hence the version
        # gate); it keeps func out of torch.compile tracing.
        return torch.compiler.disable(func)
    else:
        # Older PyTorch lacks torch.compiler, so return func unchanged.
        return func


def optional_compiler_dynamic(func):
    # Compile with dynamic shapes so varying input sizes do not trigger
    # recompilation (torch.compile is available from PyTorch 2.0).
    return torch.compile(func, dynamic=True)
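

# Usage sketch (illustrative addition): both helpers work as decorators.
# _stable_op is kept out of compiled graphs; _shape_varying_op is compiled
# with dynamic shapes. The function names here are hypothetical.
@optional_compiler_disable
def _stable_op(x: torch.Tensor) -> torch.Tensor:
    return x * 2


@optional_compiler_dynamic
def _shape_varying_op(x: torch.Tensor) -> torch.Tensor:
    return x.sum(dim=-1)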