File size: 2,121 Bytes
d526dbf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
r""" CHM 4D kernel (psi, iso, and full) generator """

import torch

from .geometry import Geometry


class KernelGenerator:
    """Group the positions of a 4D CHM kernel into parameter-sharing sets.

    Every 4D position ``(src_i, src_j, trg_i, trg_j)`` is mapped to a string
    key by :meth:`build_key`; positions with equal keys end up in the same
    group. ``ktype`` selects the rule:

    * ``'iso'``  -- key on the source/target offset distance.
    * ``'psi'``  -- key on the larger and smaller of the source/target
      distances to the kernel center, plus the offset distance.
    * ``'full'`` -- no grouping; :meth:`generate` returns ``None``.
    """

    def __init__(self, ksz, ktype):
        """
        Args:
            ksz: kernel size along each of the four dimensions.
            ktype: ``'psi'``, ``'iso'`` or ``'full'``.
        """
        self.ksz = ksz
        # Iterable of 4-element (src_i, src_j, trg_i, trg_j) indices built by
        # Geometry.init_idx4d -- presumably every position of the ksz^4 grid.
        self.idx4d = Geometry.init_idx4d(ksz)
        # Not read by any method of this class; presumably used by callers,
        # so it is kept -- verify before removing.
        self.kernel = torch.zeros((ksz, ksz, ksz, ksz))
        self.center = (ksz // 2, ksz // 2)
        self.ktype = ktype

    def quadrant(self, crd):
        """Locate a 2D coordinate relative to the kernel center.

        Args:
            crd: ``(horizontal, vertical)`` coordinate pair.

        Returns:
            ``(horz_quad, vert_quad)``; each component is ``-1`` when the
            coordinate is smaller than the center on that axis, ``1`` when it
            is larger, and ``0`` when it is equal.
        """
        def _side(value, pivot):
            # Sign of (value - pivot): -1, 0 or 1.
            return (value > pivot) - (value < pivot)

        return _side(crd[0], self.center[0]), _side(crd[1], self.center[1])

    def generate(self):
        """Return the parameter groups for this kernel type.

        Returns:
            ``None`` when ``ktype == 'full'``; otherwise the dict built by
            :meth:`generate_chm_kernel`.
        """
        if self.ktype == 'full':
            return None
        return self.generate_chm_kernel()

    def generate_chm_kernel(self):
        """Group flattened 4D kernel positions by their :meth:`build_key` key.

        Returns:
            dict mapping each key string to the list of 1D coordinates
            (from ``Geometry.get_coord1d``) of the positions sharing that key,
            appended in ``self.idx4d`` iteration order.

        Raises:
            NotImplementedError: if ``ktype`` is neither ``'iso'`` nor
                ``'psi'`` and ``self.idx4d`` is non-empty.
        """
        param_dict = {}
        for src_i, src_j, trg_i, trg_j in self.idx4d:
            src_crd = (src_i, src_j)
            trg_crd = (trg_i, trg_j)
            d_tail = Geometry.get_distance(src_crd, self.center)
            d_head = Geometry.get_distance(trg_crd, self.center)
            d_off = Geometry.get_distance(src_crd, trg_crd)
            # Passed as (j, i) because quadrant() treats crd[0] as the
            # horizontal axis -- presumably j is the column index.
            horz_quad, vert_quad = self.quadrant((src_j, src_i))

            key = self.build_key(horz_quad, vert_quad, d_head, d_tail, src_crd, trg_crd, d_off)
            coord1d = Geometry.get_coord1d((src_i, src_j, trg_i, trg_j), self.ksz)
            param_dict.setdefault(key, []).append(coord1d)

        return param_dict

    def build_key(self, horz_quad, vert_quad, d_head, d_tail, src_crd, trg_crd, d_off):
        """Return the grouping key of one 4D kernel position.

        Only the distances influence the key for the supported kernel types;
        the quadrant and coordinate arguments are accepted but ignored.

        Args:
            horz_quad, vert_quad: output of :meth:`quadrant` (unused).
            d_head: distance of the target coordinate to the center.
            d_tail: distance of the source coordinate to the center.
            src_crd, trg_crd: source/target 2D coordinates (unused).
            d_off: distance between source and target coordinates.

        Returns:
            ``'iso'``: ``'%d' % d_off``; ``'psi'``:
            ``'%d_%d_%d' % (max(d_head, d_tail), min(d_head, d_tail), d_off)``.
            ``%d`` truncates non-integer distances to integers.

        Raises:
            NotImplementedError: for any other ``ktype``.
        """
        if self.ktype == 'iso':
            return '%d' % d_off
        elif self.ktype == 'psi':
            d_max = max(d_head, d_tail)
            d_min = min(d_head, d_tail)
            return '%d_%d_%d' % (d_max, d_min, d_off)
        else:
            raise NotImplementedError('not implemented.')