import torch import torch.nn as nn import torch.nn.functional as F import torch.linalg as linalg from tqdm import tqdm def extract_conv( weight: nn.Parameter|torch.Tensor, mode = 'fixed', mode_param = 0, device = 'cpu', ) -> tuple[nn.Parameter, nn.Parameter]: out_ch, in_ch, kernel_size, _ = weight.shape U, S, Vh = linalg.svd(weight.reshape(out_ch, -1).to(device)) if mode=='fixed': lora_rank = mode_param elif mode=='threshold': assert mode_param>=0 lora_rank = torch.sum(S>mode_param) elif mode=='ratio': assert 1>=mode_param>=0 min_s = torch.max(S)*mode_param lora_rank = torch.sum(S>min_s) elif mode=='percentile': assert 1>=mode_param>=0 s_cum = torch.cumsum(S, dim=0) min_cum_sum = mode_param * torch.sum(S) lora_rank = torch.sum(s_cum tuple[nn.Parameter, nn.Parameter]: out_ch, in_ch = weight.shape U, S, Vh = linalg.svd(weight.to(device)) if mode=='fixed': lora_rank = mode_param elif mode=='threshold': assert mode_param>=0 lora_rank = torch.sum(S>mode_param) elif mode=='ratio': assert 1>=mode_param>=0 min_s = torch.max(S)*mode_param lora_rank = torch.sum(S>min_s) elif mode=='percentile': assert 1>=mode_param>=0 s_cum = torch.cumsum(S, dim=0) min_cum_sum = mode_param * torch.sum(S) lora_rank = torch.sum(s_cum