Spaces:
Runtime error
Runtime error
File size: 4,430 Bytes
fb4fac3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 |
import cupy as cp
# remap: rebuild target_style by patch voting through the nearest-neighbour field.
#
# Launch (inferred from the indexing below): 2D blocks tile (row x < height,
# column y < width); blockIdx.z selects one slab of every array.
# Layout as indexed by the kernel (contiguous, row-major; B >= gridDim.z):
#   source_style, target_style : float (B, height + 2*pad_size, width + 2*pad_size, channel)
#   nnf                        : int   (B, height, width, 2), unpadded; [..., 0] row, [..., 1] column
# Each neighbour (x + px, y + py) inside the patch window and inside the image
# proposes source pixel nnf(x + px, y + py) - (px, py) for pixel (x, y). The
# in-bounds proposals are summed into target_style and averaged. Only the interior
# (unpadded) region of target_style is written.
# NOTE(review): the kernel accumulates with +=, so target_style is presumably
# zero-filled by the caller; otherwise its prior contents leak into the average.
# NOTE(review): CuPy maps bare Python ints to 64-bit integers, while the scalar
# parameters here are 32-bit `int`; callers presumably pass numpy/cupy int32. Confirm.
remapping_kernel = cp.RawKernel(r'''
extern "C" __global__
void remap(
    const int height,
    const int width,
    const int channel,
    const int patch_size,
    const int pad_size,
    const float* source_style,
    const int* nnf,
    float* target_style
) {
    // Half-width of the window; an even patch_size effectively uses patch_size - 1.
    const int r = (patch_size - 1) / 2;
    // x walks rows (bounded by height), y walks columns (bounded by width).
    const int x = blockDim.x * blockIdx.x + threadIdx.x;
    const int y = blockDim.y * blockIdx.y + threadIdx.y;
    if (x >= height || y >= width) return;
    // Per-slab base offsets in 64-bit: with 32-bit arithmetic these products wrap
    // once a batch of padded images exceeds INT_MAX elements.
    const long long z = (long long)blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
    const long long z_nnf = (long long)blockIdx.z * height * width * 2;
    // Row stride of the padded images.
    const int padded_width = width + pad_size * 2;
    const int pid = (x + pad_size) * padded_width + (y + pad_size);
    // Clip the window so every neighbour lies inside the unpadded image.
    const int min_px = x < r ? -x : -r;
    const int max_px = x + r > height - 1 ? height - 1 - x : r;
    const int min_py = y < r ? -y : -r;
    const int max_py = y + r > width - 1 ? width - 1 - y : r;
    int num = 0;
    for (int px = min_px; px <= max_px; px++) {
        for (int py = min_py; py <= max_py; py++) {
            // Neighbour (x + px, y + py) is matched to nnf[nid]; shifting that match
            // by (-px, -py) gives the source pixel it proposes for (x, y).
            const int nid = (x + px) * width + y + py;
            const int x_ = nnf[z_nnf + nid * 2 + 0] - px;
            const int y_ = nnf[z_nnf + nid * 2 + 1] - py;
            if (x_ < 0 || y_ < 0 || x_ >= height || y_ >= width) continue;
            const int pid_ = (x_ + pad_size) * padded_width + (y_ + pad_size);
            num++;
            for (int c = 0; c < channel; c++) {
                target_style[z + pid * channel + c] += source_style[z + pid_ * channel + c];
            }
        }
    }
    // No in-bounds proposal (only possible when nnf holds out-of-range matches):
    // leave the pixel untouched instead of dividing by zero and writing NaN/Inf.
    if (num == 0) return;
    for (int c = 0; c < channel; c++) {
        target_style[z + pid * channel + c] /= num;
    }
}
''', 'remap')
# patch_error: per-pixel sum of squared differences between the target patch
# centred at (x, y) and the source patch centred at nnf(x, y), over all channels.
#
# Launch (inferred from the indexing below): 2D blocks tile (row x < height,
# column y < width); blockIdx.z selects one slab of every array.
# Layout as indexed by the kernel (contiguous, row-major; B >= gridDim.z):
#   source, target : float (B, height + 2*pad_size, width + 2*pad_size, channel)
#   nnf            : int   (B, height, width, 2), unpadded; [..., 0] row, [..., 1] column
#   error          : float (B, height, width), overwritten for every in-range pixel
# Preconditions (not checked): pad_size >= (patch_size - 1) // 2 and every nnf
# entry lies in [0, height) x [0, width); otherwise reads leave the padded images.
# NOTE(review): CuPy maps bare Python ints to 64-bit integers, while the scalar
# parameters here are 32-bit `int`; callers presumably pass numpy/cupy int32. Confirm.
patch_error_kernel = cp.RawKernel(r'''
extern "C" __global__
void patch_error(
    const int height,
    const int width,
    const int channel,
    const int patch_size,
    const int pad_size,
    const float* source,
    const int* nnf,
    const float* target,
    float* error
) {
    // Half-width of the window; an even patch_size effectively uses patch_size - 1.
    const int r = (patch_size - 1) / 2;
    // x walks rows (bounded by height), y walks columns (bounded by width).
    const int x = blockDim.x * blockIdx.x + threadIdx.x;
    const int y = blockDim.y * blockIdx.y + threadIdx.y;
    if (x >= height || y >= width) return;
    // Per-slab base offsets in 64-bit: with 32-bit arithmetic these products wrap
    // once a batch of padded images exceeds INT_MAX elements.
    const long long z = (long long)blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
    const long long z_nnf = (long long)blockIdx.z * height * width * 2 + (long long)(x * width + y) * 2;
    // Row stride of the padded images.
    const int padded_width = width + pad_size * 2;
    // Source pixel currently matched to target pixel (x, y).
    const int x_ = nnf[z_nnf + 0];
    const int y_ = nnf[z_nnf + 1];
    float e = 0.0f;
    // Full window with no clipping: the padding must cover the r-pixel overhang.
    for (int px = -r; px <= r; px++) {
        for (int py = -r; py <= r; py++) {
            const int pid = (x + pad_size + px) * padded_width + y + pad_size + py;
            const int pid_ = (x_ + pad_size + px) * padded_width + y_ + pad_size + py;
            for (int c = 0; c < channel; c++) {
                const float diff = target[z + pid * channel + c] - source[z + pid_ * channel + c];
                e += diff * diff;
            }
        }
    }
    error[(long long)blockIdx.z * height * width + x * width + y] = e;
}
''', 'patch_error')
# pairwise_patch_error: per-pixel sum of squared differences between the patch of
# source_a centred at nnf_a(x, y) and the patch of source_b centred at nnf_b(x, y),
# over all channels.
#
# Launch (inferred from the indexing below): 2D blocks tile (row x < height,
# column y < width); blockIdx.z selects one slab of every array.
# Layout as indexed by the kernel (contiguous, row-major; B >= gridDim.z):
#   source_a, source_b : float (B, height + 2*pad_size, width + 2*pad_size, channel)
#   nnf_a, nnf_b       : int   (B, height, width, 2), unpadded; [..., 0] row, [..., 1] column
#   error              : float (B, height, width), overwritten for every in-range pixel
# Preconditions (not checked): pad_size >= (patch_size - 1) // 2 and every nnf
# entry lies in [0, height) x [0, width); otherwise reads leave the padded images.
# NOTE(review): CuPy maps bare Python ints to 64-bit integers, while the scalar
# parameters here are 32-bit `int`; callers presumably pass numpy/cupy int32. Confirm.
pairwise_patch_error_kernel = cp.RawKernel(r'''
extern "C" __global__
void pairwise_patch_error(
    const int height,
    const int width,
    const int channel,
    const int patch_size,
    const int pad_size,
    const float* source_a,
    const int* nnf_a,
    const float* source_b,
    const int* nnf_b,
    float* error
) {
    // Half-width of the window; an even patch_size effectively uses patch_size - 1.
    const int r = (patch_size - 1) / 2;
    // x walks rows (bounded by height), y walks columns (bounded by width).
    const int x = blockDim.x * blockIdx.x + threadIdx.x;
    const int y = blockDim.y * blockIdx.y + threadIdx.y;
    if (x >= height || y >= width) return;
    // Per-slab base offsets in 64-bit: with 32-bit arithmetic these products wrap
    // once a batch of padded images exceeds INT_MAX elements.
    const long long z = (long long)blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
    const long long z_nnf = (long long)blockIdx.z * height * width * 2 + (long long)(x * width + y) * 2;
    // Row stride of the padded images.
    const int padded_width = width + pad_size * 2;
    // Pixels matched to (x, y) in each of the two fields.
    const int x_a = nnf_a[z_nnf + 0];
    const int y_a = nnf_a[z_nnf + 1];
    const int x_b = nnf_b[z_nnf + 0];
    const int y_b = nnf_b[z_nnf + 1];
    float e = 0.0f;
    // Full window with no clipping: the padding must cover the r-pixel overhang.
    for (int px = -r; px <= r; px++) {
        for (int py = -r; py <= r; py++) {
            const int pid_a = (x_a + pad_size + px) * padded_width + y_a + pad_size + py;
            const int pid_b = (x_b + pad_size + px) * padded_width + y_b + pad_size + py;
            for (int c = 0; c < channel; c++) {
                const float diff = source_a[z + pid_a * channel + c] - source_b[z + pid_b * channel + c];
                e += diff * diff;
            }
        }
    }
    error[(long long)blockIdx.z * height * width + x * width + y] = e;
}
''', 'pairwise_patch_error')
|