|
from typing import Optional |
|
|
|
import numpy as np |
|
import paddle |
|
import paddle.nn.functional as F |
|
|
|
|
|
def reverse_transform(alpha, trans_info):
    """Undo the recorded preprocessing steps so the prediction regains its original size.

    The entries of ``trans_info`` are replayed in reverse order, i.e. the
    transform that was recorded last is undone first.

    Args:
        alpha: Prediction to restore. It is indexed as ``alpha[:, :, 0:h, 0:w]``
            in the padding branch, so it must support 4-axis slicing
            (presumably an NCHW tensor -- confirm against the model output).
        trans_info: Sequence of records where ``item[0]`` is the transform tag
            (``"resize"`` or ``"padding"``) and ``item[1][0]``, ``item[1][1]``
            are the height and width to restore.

    Returns:
        The prediction after every recorded transform has been undone. When
        ``trans_info`` is empty, ``alpha`` is returned unchanged.

    Raises:
        ValueError: If a record carries a tag other than ``"resize"`` or
            ``"padding"``.
    """
    for item in trans_info[::-1]:
        kind = item[0]
        if kind == "resize":
            # Scale back to the size stored for this resize step.
            h, w = item[1][0], item[1][1]
            alpha = F.interpolate(alpha, [h, w], mode="bilinear")
        elif kind == "padding":
            # Keep only the top-left h x w window of the last two axes.
            h, w = item[1][0], item[1][1]
            alpha = alpha[:, :, 0:h, 0:w]
        else:
            # ValueError is still caught by callers using `except Exception`.
            raise ValueError(f"Unexpected info '{item[0]}' in im_info")

    return alpha
|
|
|
|
|
def preprocess(img, transforms, trimap=None):
    """Assemble the sample dict, run ``transforms`` on it and batch the results.

    Args:
        img: Input image, stored under ``"img"`` and handed to ``transforms``.
        transforms: Callable that takes the sample dict and returns the
            processed dict (presumably it also appends to ``"trans_info"`` --
            confirm against the transform implementation).
        trimap: Optional trimap. When given it is stored under ``"trimap"``
            and ``"gt_fields"`` is set to ``["trimap"]``.

    Returns:
        The dict returned by ``transforms`` with ``"img"`` converted to a
        paddle tensor with a new leading axis and, when a trimap was supplied,
        ``"trimap"`` converted to a tensor with two new leading axes.
    """
    has_trimap = trimap is not None

    sample = {"img": img}
    if has_trimap:
        sample.update(trimap=trimap, gt_fields=["trimap"])
    sample["trans_info"] = []

    sample = transforms(sample)

    # Add the batch axis expected by the model.
    sample["img"] = paddle.to_tensor(sample["img"]).unsqueeze(0)
    if has_trimap:
        # Two new leading axes are inserted at positions 0 and 1.
        sample["trimap"] = paddle.to_tensor(sample["trimap"]).unsqueeze((0, 1))

    return sample
|
|
|
|
|
def predict(
    model,
    transforms,
    image: np.ndarray,
    trimap: Optional[np.ndarray] = None,
):
    """Run the matting model on one image and return the alpha matte.

    Args:
        model: Callable invoked with the preprocessed sample dict.
        transforms: Preprocessing callable forwarded to ``preprocess``.
        image: Input image.
        trimap: Optional trimap used only to post-process the result: entries
            equal to 0 force alpha to 0 and entries equal to 255 force alpha
            to 1.

    Returns:
        The prediction as a squeezed numpy array, restored to the original
        size via ``reverse_transform``.
    """
    with paddle.no_grad():
        # NOTE(review): the trimap is deliberately NOT forwarded here
        # (``trimap=None``), so the model never sees it -- confirm this is
        # intended for trimap-free models and not a leftover.
        batch = preprocess(img=image, transforms=transforms, trimap=None)

        raw = model(batch)
        matte = reverse_transform(raw, batch["trans_info"]).numpy().squeeze()

        if trimap is not None:
            # Pin the regions the trimap marks as 0 and 255 to 0 and 1.
            for level, value in ((0, 0), (255, 1.)):
                matte[trimap == level] = value

    return matte
|
|