# Hosting-page status captured alongside this file: "Spaces: Runtime error".
import cupy as cp | |
# CUDA kernel `remap`: for every target pixel, averages the source pixels voted
# for by the nearest-neighbour-field (NNF) entries of all patch positions that
# overlap that pixel.
#
# Memory layout implied by the indexing inside the kernel:
#   source_style / target_style : (batch, height + 2*pad_size, width + 2*pad_size, channel), float32
#   nnf                         : (batch, height, width, 2), int32, stored as (row, col)
# Launch layout: grid/block x covers rows (height), y covers columns (width),
# blockIdx.z selects the batch item.
#
# The kernel ADDS the votes to whatever is already stored in `target_style` and
# then divides by the vote count, so the caller has to pass a zeroed buffer to
# obtain a plain average.
remapping_kernel = cp.RawKernel(r'''
extern "C" __global__
void remap(
    const int height,
    const int width,
    const int channel,
    const int patch_size,
    const int pad_size,
    const float* source_style,
    const int* nnf,
    float* target_style
) {
    // 64-bit offsets: batch * padded_height * padded_width * channel can
    // exceed the range of a 32-bit int for large batches / resolutions.
    typedef long long idx_t;

    const int r = (patch_size - 1) / 2;
    const int x = blockDim.x * blockIdx.x + threadIdx.x;  // row
    const int y = blockDim.y * blockIdx.y + threadIdx.y;  // column
    if (x >= height || y >= width) return;

    const idx_t padded_height = (idx_t)height + 2 * (idx_t)pad_size;
    const idx_t padded_width = (idx_t)width + 2 * (idx_t)pad_size;
    // Start of this batch item's padded image and of its NNF slice.
    const idx_t image_base = (idx_t)blockIdx.z * padded_height * padded_width * channel;
    const idx_t nnf_base = (idx_t)blockIdx.z * height * width * 2;

    // Output pixel (x, y) inside the padded target image.
    float* out = target_style
        + image_base
        + ((idx_t)(x + pad_size) * padded_width + (y + pad_size)) * channel;

    // Clamp the patch window so that (x + px, y + py) stays inside the image.
    const int min_px = x < r ? -x : -r;
    const int max_px = x + r > height - 1 ? height - 1 - x : r;
    const int min_py = y < r ? -y : -r;
    const int max_py = y + r > width - 1 ? width - 1 - y : r;

    int num = 0;  // number of valid votes accumulated for this pixel
    for (int px = min_px; px <= max_px; px++) {
        for (int py = min_py; py <= max_py; py++) {
            // The neighbour at offset (px, py) maps to some source location;
            // shifting that location back by (px, py) yields its vote for us.
            const idx_t nid = (idx_t)(x + px) * width + (y + py);
            const int x_ = nnf[nnf_base + nid * 2 + 0] - px;
            const int y_ = nnf[nnf_base + nid * 2 + 1] - py;
            if (x_ < 0 || y_ < 0 || x_ >= height || y_ >= width) continue;

            const float* src = source_style
                + image_base
                + ((idx_t)(x_ + pad_size) * padded_width + (y_ + pad_size)) * channel;
            num++;
            for (int c = 0; c < channel; c++) {
                out[c] += src[c];
            }
        }
    }

    // Without any valid vote there is nothing to normalise; dividing by zero
    // would write NaN/Inf into the output, so leave the pixel untouched.
    if (num == 0) return;
    for (int c = 0; c < channel; c++) {
        out[c] /= num;
    }
}
''', 'remap')
# CUDA kernel `patch_error`: for every pixel, computes the sum of squared
# differences between the target patch centred on that pixel and the source
# patch centred on the location stored in the NNF.
#
# Memory layout implied by the indexing inside the kernel:
#   source / target : (batch, height + 2*pad_size, width + 2*pad_size, channel), float32
#   nnf             : (batch, height, width, 2), int32, stored as (row, col)
#   error           : (batch, height, width), float32 (overwritten)
# Launch layout: grid/block x covers rows (height), y covers columns (width),
# blockIdx.z selects the batch item.
#
# The patch loop is not clamped: reads stay inside the padded images only when
# pad_size >= (patch_size - 1) / 2 and the NNF coordinates lie inside the image.
patch_error_kernel = cp.RawKernel(r'''
extern "C" __global__
void patch_error(
    const int height,
    const int width,
    const int channel,
    const int patch_size,
    const int pad_size,
    const float* source,
    const int* nnf,
    const float* target,
    float* error
) {
    // 64-bit offsets: batch * padded_height * padded_width * channel can
    // exceed the range of a 32-bit int for large batches / resolutions.
    typedef long long idx_t;

    const int r = (patch_size - 1) / 2;
    const int x = blockDim.x * blockIdx.x + threadIdx.x;  // row
    const int y = blockDim.y * blockIdx.y + threadIdx.y;  // column
    if (x >= height || y >= width) return;

    const idx_t padded_height = (idx_t)height + 2 * (idx_t)pad_size;
    const idx_t padded_width = (idx_t)width + 2 * (idx_t)pad_size;
    const idx_t image_base = (idx_t)blockIdx.z * padded_height * padded_width * channel;

    // Flat index of pixel (x, y) in an unpadded (batch, height, width) array;
    // the NNF stores two ints per pixel, the error map one float.
    const idx_t pixel = (idx_t)blockIdx.z * height * width + (idx_t)x * width + y;
    const int x_ = nnf[pixel * 2 + 0];
    const int y_ = nnf[pixel * 2 + 1];

    float e = 0.0f;
    for (int px = -r; px <= r; px++) {
        // Row offsets do not depend on py, so compute them once per row.
        const idx_t target_row = (idx_t)(x + pad_size + px) * padded_width;
        const idx_t source_row = (idx_t)(x_ + pad_size + px) * padded_width;
        for (int py = -r; py <= r; py++) {
            const idx_t pid = target_row + (y + pad_size + py);
            const idx_t pid_ = source_row + (y_ + pad_size + py);
            const float* t = target + image_base + pid * channel;
            const float* s = source + image_base + pid_ * channel;
            for (int c = 0; c < channel; c++) {
                const float diff = t[c] - s[c];
                e += diff * diff;
            }
        }
    }
    error[pixel] = e;
}
''', 'patch_error')
# CUDA kernel `pairwise_patch_error`: for every pixel, computes the sum of
# squared differences between the patch of `source_a` centred on the location
# given by `nnf_a` and the patch of `source_b` centred on the location given by
# `nnf_b`.
#
# Memory layout implied by the indexing inside the kernel:
#   source_a / source_b : (batch, height + 2*pad_size, width + 2*pad_size, channel), float32
#   nnf_a / nnf_b       : (batch, height, width, 2), int32, stored as (row, col)
#   error               : (batch, height, width), float32 (overwritten)
# Launch layout: grid/block x covers rows (height), y covers columns (width),
# blockIdx.z selects the batch item.
#
# The patch loop is not clamped: reads stay inside the padded images only when
# pad_size >= (patch_size - 1) / 2 and the NNF coordinates lie inside the image.
pairwise_patch_error_kernel = cp.RawKernel(r'''
extern "C" __global__
void pairwise_patch_error(
    const int height,
    const int width,
    const int channel,
    const int patch_size,
    const int pad_size,
    const float* source_a,
    const int* nnf_a,
    const float* source_b,
    const int* nnf_b,
    float* error
) {
    // 64-bit offsets: batch * padded_height * padded_width * channel can
    // exceed the range of a 32-bit int for large batches / resolutions.
    typedef long long idx_t;

    const int r = (patch_size - 1) / 2;
    const int x = blockDim.x * blockIdx.x + threadIdx.x;  // row
    const int y = blockDim.y * blockIdx.y + threadIdx.y;  // column
    if (x >= height || y >= width) return;

    const idx_t padded_height = (idx_t)height + 2 * (idx_t)pad_size;
    const idx_t padded_width = (idx_t)width + 2 * (idx_t)pad_size;
    const idx_t image_base = (idx_t)blockIdx.z * padded_height * padded_width * channel;

    // Flat index of pixel (x, y) in an unpadded (batch, height, width) array;
    // each NNF stores two ints per pixel, the error map one float.
    const idx_t pixel = (idx_t)blockIdx.z * height * width + (idx_t)x * width + y;
    const idx_t z_nnf = pixel * 2;
    const int x_a = nnf_a[z_nnf + 0];
    const int y_a = nnf_a[z_nnf + 1];
    const int x_b = nnf_b[z_nnf + 0];
    const int y_b = nnf_b[z_nnf + 1];

    float e = 0.0f;
    for (int px = -r; px <= r; px++) {
        // Row offsets do not depend on py, so compute them once per row.
        const idx_t row_a = (idx_t)(x_a + pad_size + px) * padded_width;
        const idx_t row_b = (idx_t)(x_b + pad_size + px) * padded_width;
        for (int py = -r; py <= r; py++) {
            const idx_t pid_a = row_a + (y_a + pad_size + py);
            const idx_t pid_b = row_b + (y_b + pad_size + py);
            const float* a = source_a + image_base + pid_a * channel;
            const float* b = source_b + image_base + pid_b * channel;
            for (int c = 0; c < channel; c++) {
                const float diff = a[c] - b[c];
                e += diff * diff;
            }
        }
    }
    error[pixel] = e;
}
''', 'pairwise_patch_error')