|
from functools import partial |
|
from typing import Sequence |
|
|
|
import torch as T |
|
|
|
|
|
def svd_truncated(mat: T.Tensor, rank: int):
    """Return the leading ``rank`` singular triplets of a 2-D matrix.

    The result follows the same convention as ``torch.svd_lowrank``: the
    right singular vectors are returned as columns, so that
    ``mat ~= lvecs @ diag(svals) @ rvecs.T``.

    Args:
        mat: Matrix of shape ``(m, n)`` to factorize.
        rank: Number of singular triplets to keep. Values above
            ``min(m, n)`` are effectively clipped to ``min(m, n)``.

    Returns:
        Tuple ``(lvecs, svals, rvecs)`` with shapes ``(m, k)``, ``(k,)`` and
        ``(n, k)`` where ``k = min(rank, m, n)``.
    """
    # The reduced SVD is enough because only the leading triplets are kept.
    # The full variant would materialize square (m, m) and (n, n) bases,
    # which is prohibitively large for strongly rectangular matrices.
    lvecs, svals, rvecs = T.linalg.svd(mat, full_matrices=False)
    return lvecs[:, :rank], svals[:rank], rvecs[:rank, :].T
|
|
|
|
|
def ttd(ten: T.Tensor, rank: Sequence[int], noiters: int = 1000,
        method: str = 'tsvd') -> Sequence[T.Tensor]:
    """Compute a tensor-train (TT) decomposition by sequential SVDs.

    The tensor is unfolded one mode at a time. Each unfolding is factorized
    with a truncated SVD; the left singular vectors become the TT-core of
    that mode and the remainder ``diag(S) @ V^T`` is carried over to the next
    mode. The final remainder is used as the last core as is, so no sign or
    scale information is lost. All cores but the last have orthonormal
    columns when unfolded to ``(rank[k] * ten.shape[k], rank[k + 1])``.

    NOTE: the right factor is applied with a plain transpose, so real-valued
    input is assumed.

    Args:
        ten: Tensor to decompose.
        rank: TT-ranks, one more entry than ``ten`` has dimensions. The
            first and the last entries must be 1.
        noiters: Forwarded as ``niter`` to ``torch.svd_lowrank``; only used
            when ``method == 'tsvd'``.
        method: ``'svd'`` for a truncated exact SVD (``svd_truncated``) or
            ``'tsvd'`` for the randomized ``torch.svd_lowrank``.

    Returns:
        List of cores; the ``k``-th core has shape
        ``(rank[k], ten.shape[k], rank[k + 1])``.

    Raises:
        ValueError: If the number of ranks does not match the tensor, the
            boundary ranks are not 1, a rank exceeds what the corresponding
            unfolding admits, or ``method`` is unknown.
    """
    if ten.ndim + 1 != len(rank):
        raise ValueError(
            f'Expected {ten.ndim + 1} ranks for a tensor with {ten.ndim} '
            f'dimensions, got {len(rank)}.')
    if rank[0] != 1 or rank[-1] != 1:
        raise ValueError('Boundary ranks rank[0] and rank[-1] must equal 1.')

    if method == 'svd':
        factorize = svd_truncated
    elif method == 'tsvd':
        factorize = partial(T.svd_lowrank, niter=noiters)
    else:
        raise ValueError(f'Unknown method: {method}.')

    cores = []
    core_shapes = list(zip(rank, ten.shape, rank[1:]))
    if not core_shapes:
        # Zero-dimensional input: there are no modes and hence no cores.
        return cores

    # Every mode except the last one is split off with a truncated SVD.
    for core_shape in core_shapes[:-1]:
        rank_left, mode_size, rank_right = core_shape
        mat = ten.reshape(rank_left * mode_size, -1)
        if rank_right > min(mat.shape):
            raise ValueError(
                f'Rank {rank_right} exceeds the maximal rank '
                f'{min(mat.shape)} of an unfolding of shape '
                f'{tuple(mat.shape)}.')

        lvecs, svals, rvecs = factorize(mat, rank_right)
        cores.append(lvecs.reshape(core_shape))

        # Keep the singular values in the remainder: the following
        # truncations then see the properly weighted rows, and the
        # orthonormal left factor keeps their errors from being amplified.
        ten = svals[:, None] * rvecs.T

    # The remainder already has shape (rank[-2], ten.shape[-1]). Factorizing
    # it once more would only split off an arbitrary +-1 factor, and dropping
    # that factor would flip the sign of the represented tensor.
    cores.append(ten.reshape(core_shapes[-1]))

    return cores
|
|