File size: 1,387 Bytes
4cacee8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from functools import partial
from typing import Sequence

import torch as T


def svd_truncated(mat: T.Tensor, rank: int):
    """Return the leading ``rank`` singular triplets of a matrix.

    Uses the same ``(U, S, V)`` return convention as ``torch.svd_lowrank``
    (right singular vectors as columns of ``V``, not rows of ``Vh``).

    Args:
        mat: Matrix of shape ``(m, n)``.
        rank: Number of singular triplets to keep. If it exceeds
            ``min(m, n)``, only ``min(m, n)`` triplets are returned.

    Returns:
        Tuple ``(U, S, V)`` with shapes ``(m, k)``, ``(k,)`` and ``(n, k)``
        where ``k = min(rank, m, n)``; ``U @ diag(S) @ V.T`` is the best
        rank-``k`` approximation of a real ``mat``.
    """
    # full_matrices=False computes only the first min(m, n) singular vectors.
    # The default full SVD would materialize an n x n (or m x m) factor that
    # the slicing below throws away, which is prohibitive for very wide
    # matrices such as the first matricization of a high-order tensor.
    lvecs, svals, rvecs_h = T.linalg.svd(mat, full_matrices=False)
    return lvecs[:, :rank], svals[:rank], rvecs_h[:rank, :].T


def ttd(ten: T.Tensor, rank: Sequence[int], noiters: int = 1000,
        method: str = 'tsvd') -> Sequence[T.Tensor]:
    """Decompose a tensor into tensor-train (TT) cores with TT-SVD.

    The remainder tensor is repeatedly matricized over its first two axes
    and factorized as ``U S V^T``. The left singular vectors ``U`` become
    the next core, and ``S V^T`` is carried over as the new remainder. The
    final remainder is reshaped into the last core, so no scale or sign
    information is dropped.

    Args:
        ten: Real tensor of shape ``(n_1, ..., n_d)`` to decompose.
        rank: TT ranks ``(r_0, r_1, ..., r_d)`` of length ``ten.ndim + 1``
            with ``r_0 == r_d == 1``. Each inner rank ``r_k`` must not exceed
            ``min(r_{k-1} * n_k, n_{k+1} * ... * n_d)``, otherwise the
            factorization cannot supply enough singular vectors.
        noiters: Number of subspace iterations passed as ``niter`` to
            ``torch.svd_lowrank``; only used when ``method == 'tsvd'``.
        method: ``'svd'`` for deterministic truncated SVD via
            ``svd_truncated``, or ``'tsvd'`` for randomized
            ``torch.svd_lowrank``.

    Returns:
        List of ``d`` cores, where core ``k`` has shape
        ``(r_{k-1}, n_k, r_k)``. Every core except the last is
        left-orthonormal: its ``(r_{k-1} * n_k, r_k)`` matricization has
        orthonormal columns. A 0-dimensional tensor yields an empty list.

    Raises:
        ValueError: If ``rank`` has the wrong length, a boundary rank is not
            1, or ``method`` is unknown.
    """
    if ten.ndim + 1 != len(rank):
        raise ValueError(f'Expected {ten.ndim + 1} ranks for a '
                         f'{ten.ndim}-dimensional tensor, got {len(rank)}.')
    if rank[0] != 1 or rank[-1] != 1:
        raise ValueError(f'Boundary ranks must be 1, got rank[0]={rank[0]} '
                         f'and rank[-1]={rank[-1]}.')

    if method == 'svd':
        factorize = svd_truncated
    elif method == 'tsvd':
        factorize = partial(T.svd_lowrank, niter=noiters)
    else:
        raise ValueError(f'Unknown method: {method}.')

    shape = ten.shape
    if not shape:
        # A 0-dimensional tensor has no modes and therefore no cores.
        return []

    cores = []
    rest = ten
    # Split off one left-orthonormal core for every mode except the last.
    for rank_prev, size, rank_next in zip(rank, shape[:-1], rank[1:]):
        # Matricization of the remainder over its first two axes.
        mat = rest.reshape(rank_prev * size, -1)
        lvecs, svals, rvecs = factorize(mat, rank_next)
        cores.append(lvecs.reshape(rank_prev, size, rank_next))
        # Carry S V^T forward (not the bare V^T) so later truncations see
        # the true weights of the remainder, as in TT-SVD.
        rest = svals[:, None] * rvecs.T

    # The final remainder, including its scale and sign, is the last core.
    cores.append(rest.reshape(rank[-2], shape[-1], rank[-1]))
    return cores