import cupy as cp

# Conventions shared by the three kernels below (all derived from how the
# kernels index their arguments):
#   * image buffers (source_style / target_style / source / target /
#     source_a / source_b) are flat float arrays addressed as
#     [batch][row + pad_size][col + pad_size][channel] with row stride
#     (width + 2 * pad_size) * channel;
#   * NNF buffers are flat int arrays addressed as [batch][row][col][2], where
#     component 0 is a row coordinate and component 1 is a column coordinate;
#   * the thread x-dimension walks rows (0..height-1), the thread y-dimension
#     walks columns (0..width-1) and blockIdx.z selects the batch item.
#     Threads outside the image return immediately, so the grid may overshoot.
#   * batch offsets are computed in 64-bit so that large batches do not wrap
#     around 32-bit integer arithmetic.


# remap: for every target pixel, average the source pixels "voted" for by the
# NNF entries of all neighbours inside the patch window.
remapping_kernel = cp.RawKernel(r'''
extern "C" __global__
void remap(
    const int height,
    const int width,
    const int channel,
    const int patch_size,
    const int pad_size,
    const float* source_style,
    const int* nnf,
    float* target_style
) {
    // Patch radius; the window spans [-r, r] in both directions.
    const int r = (patch_size - 1) / 2;
    const int x = blockDim.x * blockIdx.x + threadIdx.x;  // row
    const int y = blockDim.y * blockIdx.y + threadIdx.y;  // column
    if (x >= height || y >= width) return;

    const int padded_height = height + pad_size * 2;
    const int padded_width = width + pad_size * 2;

    // Per-batch base offsets, promoted to 64-bit before multiplying.
    const long long z = (long long)blockIdx.z * padded_height * padded_width * channel;
    const long long z_nnf = (long long)blockIdx.z * height * width * 2;

    // Output pixel inside the padded target image.
    const int pid = (x + pad_size) * padded_width + (y + pad_size);
    float* out = target_style + z + (long long)pid * channel;

    // Clip the patch window so that (x + px, y + py) stays inside the
    // un-padded image, because the NNF has no padding.
    const int min_px = x < r ? -x : -r;
    const int max_px = x + r > height - 1 ? height - 1 - x : r;
    const int min_py = y < r ? -y : -r;
    const int max_py = y + r > width - 1 ? width - 1 - y : r;

    int num = 0;  // number of neighbours that contributed a valid vote
    for (int px = min_px; px <= max_px; px++) {
        for (int py = min_py; py <= max_py; py++) {
            // The neighbour's match, shifted back by the neighbour offset,
            // is the source pixel this neighbour proposes for (x, y).
            const int nid = (x + px) * width + (y + py);
            const int x_ = nnf[z_nnf + nid * 2 + 0] - px;
            const int y_ = nnf[z_nnf + nid * 2 + 1] - py;
            if (x_ < 0 || y_ < 0 || x_ >= height || y_ >= width) continue;

            const int pid_ = (x_ + pad_size) * padded_width + (y_ + pad_size);
            const float* src = source_style + z + (long long)pid_ * channel;
            num++;
            // Accumulates on top of whatever target_style already holds;
            // presumably the caller zero-fills it first -- TODO confirm.
            for (int c = 0; c < channel; c++) {
                out[c] += src[c];
            }
        }
    }

    // No valid vote (only possible when the NNF holds out-of-range entries):
    // leave the pixel untouched instead of dividing by zero and writing
    // NaN/Inf into the output.
    if (num == 0) return;

    for (int c = 0; c < channel; c++) {
        out[c] /= num;
    }
}
''', 'remap')


# patch_error: sum of squared differences between the target patch centred at
# each pixel and the source patch centred at that pixel's NNF match.
patch_error_kernel = cp.RawKernel(r'''
extern "C" __global__
void patch_error(
    const int height,
    const int width,
    const int channel,
    const int patch_size,
    const int pad_size,
    const float* source,
    const int* nnf,
    const float* target,
    float* error
) {
    const int r = (patch_size - 1) / 2;
    const int x = blockDim.x * blockIdx.x + threadIdx.x;  // row
    const int y = blockDim.y * blockIdx.y + threadIdx.y;  // column
    if (x >= height || y >= width) return;

    const int padded_height = height + pad_size * 2;
    const int padded_width = width + pad_size * 2;

    // Per-batch base offsets, promoted to 64-bit before multiplying.
    const long long z = (long long)blockIdx.z * padded_height * padded_width * channel;
    const long long z_nnf = (long long)blockIdx.z * height * width * 2;
    const long long z_err = (long long)blockIdx.z * height * width;

    // Matched source coordinate for this pixel. It is not range-checked here:
    // the reads below stay in bounds only if it lies inside the image and
    // pad_size >= r.
    const int x_ = nnf[z_nnf + (x * width + y) * 2 + 0];
    const int y_ = nnf[z_nnf + (x * width + y) * 2 + 1];

    float e = 0.0f;
    for (int px = -r; px <= r; px++) {
        for (int py = -r; py <= r; py++) {
            // The full (unclipped) window is used; border pixels read padding.
            const int pid = (x + pad_size + px) * padded_width + y + pad_size + py;
            const int pid_ = (x_ + pad_size + px) * padded_width + y_ + pad_size + py;
            const float* tgt = target + z + (long long)pid * channel;
            const float* src = source + z + (long long)pid_ * channel;
            for (int c = 0; c < channel; c++) {
                const float diff = tgt[c] - src[c];
                e += diff * diff;
            }
        }
    }
    error[z_err + x * width + y] = e;
}
''', 'patch_error')


# pairwise_patch_error: sum of squared differences between the patch that
# nnf_a selects in source_a and the patch that nnf_b selects in source_b,
# evaluated for every pixel position.
pairwise_patch_error_kernel = cp.RawKernel(r'''
extern "C" __global__
void pairwise_patch_error(
    const int height,
    const int width,
    const int channel,
    const int patch_size,
    const int pad_size,
    const float* source_a,
    const int* nnf_a,
    const float* source_b,
    const int* nnf_b,
    float* error
) {
    const int r = (patch_size - 1) / 2;
    const int x = blockDim.x * blockIdx.x + threadIdx.x;  // row
    const int y = blockDim.y * blockIdx.y + threadIdx.y;  // column
    if (x >= height || y >= width) return;

    const int padded_height = height + pad_size * 2;
    const int padded_width = width + pad_size * 2;

    // Per-batch base offsets, promoted to 64-bit before multiplying.
    const long long z = (long long)blockIdx.z * padded_height * padded_width * channel;
    const long long z_nnf = (long long)blockIdx.z * height * width * 2 + (x * width + y) * 2;
    const long long z_err = (long long)blockIdx.z * height * width;

    // Matched coordinates in each source. They are not range-checked here:
    // the reads below stay in bounds only if they lie inside the image and
    // pad_size >= r.
    const int x_a = nnf_a[z_nnf + 0];
    const int y_a = nnf_a[z_nnf + 1];
    const int x_b = nnf_b[z_nnf + 0];
    const int y_b = nnf_b[z_nnf + 1];

    float e = 0.0f;
    for (int px = -r; px <= r; px++) {
        for (int py = -r; py <= r; py++) {
            const int pid_a = (x_a + pad_size + px) * padded_width + y_a + pad_size + py;
            const int pid_b = (x_b + pad_size + px) * padded_width + y_b + pad_size + py;
            const float* pa = source_a + z + (long long)pid_a * channel;
            const float* pb = source_b + z + (long long)pid_b * channel;
            for (int c = 0; c < channel; c++) {
                const float diff = pa[c] - pb[c];
                e += diff * diff;
            }
        }
    }
    error[z_err + x * width + y] = e;
}
''', 'pairwise_patch_error')