File size: 4,461 Bytes
1c9edc2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import torch
from typing import Tuple, Callable
def hacer_nada(x: torch.Tensor, modo: str = None):
    return x
def brujeria_mps(entrada, dim, indice):
    if entrada.shape[-1] == 1:
        return torch.gather(entrada.unsqueeze(-1), dim - 1 if dim < 0 else dim, indice.unsqueeze(-1)).squeeze(-1)
    else:
        return torch.gather(entrada, dim, indice)
def emparejamiento_suave_aleatorio_2d(
    metrica: torch.Tensor,
    ancho: int,
    alto: int,
    paso_x: int,
    paso_y: int,
    radio: int,
    sin_aleatoriedad: bool = False,
    generador: torch.Generator = None
) -> Tuple[Callable, Callable]:
    lote, num_nodos, _ = metrica.shape
    if radio <= 0:
        return hacer_nada, hacer_nada
    recopilar = brujeria_mps if metrica.device.type == "mps" else torch.gather
    with torch.no_grad():
        alto_paso_y, ancho_paso_x = alto // paso_y, ancho // paso_x
        if sin_aleatoriedad:
            indice_aleatorio = torch.zeros(alto_paso_y, ancho_paso_x, 1, device=metrica.device, dtype=torch.int64)
        else:
            indice_aleatorio = torch.randint(paso_y * paso_x, size=(alto_paso_y, ancho_paso_x, 1), device=generador.device, generator=generador).to(metrica.device)
        vista_buffer_indice = torch.zeros(alto_paso_y, ancho_paso_x, paso_y * paso_x, device=metrica.device, dtype=torch.int64)
        vista_buffer_indice.scatter_(dim=2, index=indice_aleatorio, src=-torch.ones_like(indice_aleatorio, dtype=indice_aleatorio.dtype))
        vista_buffer_indice = vista_buffer_indice.view(alto_paso_y, ancho_paso_x, paso_y, paso_x).transpose(1, 2).reshape(alto_paso_y * paso_y, ancho_paso_x * paso_x)
        if (alto_paso_y * paso_y) < alto or (ancho_paso_x * paso_x) < ancho:
            buffer_indice = torch.zeros(alto, ancho, device=metrica.device, dtype=torch.int64)
            buffer_indice[:(alto_paso_y * paso_y), :(ancho_paso_x * paso_x)] = vista_buffer_indice
        else:
            buffer_indice = vista_buffer_indice
        indice_aleatorio = buffer_indice.reshape(1, -1, 1).argsort(dim=1)
        del buffer_indice, vista_buffer_indice
        num_destino = alto_paso_y * ancho_paso_x
        indices_a = indice_aleatorio[:, num_destino:, :]
        indices_b = indice_aleatorio[:, :num_destino, :]
        def dividir(x):
            canales = x.shape[-1]
            origen = recopilar(x, dim=1, index=indices_a.expand(lote, num_nodos - num_destino, canales))
            destino = recopilar(x, dim=1, index=indices_b.expand(lote, num_destino, canales))
            return origen, destino
        metrica = metrica / metrica.norm(dim=-1, keepdim=True)
        a, b = dividir(metrica)
        puntuaciones = a @ b.transpose(-1, -2)
        radio = min(a.shape[1], radio)
        nodo_max, nodo_indice = puntuaciones.max(dim=-1)
        indice_borde = nodo_max.argsort(dim=-1, descending=True)[..., None]
        indice_no_emparejado = indice_borde[..., radio:, :]
        indice_origen = indice_borde[..., :radio, :]
        indice_destino = recopilar(nodo_indice[..., None], dim=-2, index=indice_origen)
    def fusionar(x: torch.Tensor, modo="mean") -> torch.Tensor:
        origen, destino = dividir(x)
        n, t1, c = origen.shape
        no_emparejado = recopilar(origen, dim=-2, index=indice_no_emparejado.expand(n, t1 - radio, c))
        origen = recopilar(origen, dim=-2, index=indice_origen.expand(n, radio, c))
        destino = destino.scatter_reduce(-2, indice_destino.expand(n, radio, c), origen, reduce=modo)
        return torch.cat([no_emparejado, destino], dim=1)
    def desfusionar(x: torch.Tensor) -> torch.Tensor:
        longitud_no_emparejado = indice_no_emparejado.shape[1]
        no_emparejado, destino = x[..., :longitud_no_emparejado, :], x[..., longitud_no_emparejado:, :]
        _, _, c = no_emparejado.shape
        origen = recopilar(destino, dim=-2, index=indice_destino.expand(lote, radio, c))
        salida = torch.zeros(lote, num_nodos, c, device=x.device, dtype=x.dtype)
        salida.scatter_(dim=-2, index=indices_b.expand(lote, num_destino, c), src=destino)
        salida.scatter_(dim=-2, index=recopilar(indices_a.expand(lote, indices_a.shape[1], 1), dim=1, index=indice_no_emparejado).expand(lote, longitud_no_emparejado, c), src=no_emparejado)
        salida.scatter_(dim=-2, index=recopilar(indices_a.expand(lote, indices_a.shape[1], 1), dim=1, index=indice_origen).expand(lote, radio, c), src=origen)
        return salida
    return fusionar, desfusionar