|
import os |
|
from pathlib import Path |
|
from typing import List |
|
|
|
import cv2 |
|
import numpy as np |
|
import torch |
|
import tqdm |
|
from omegaconf import OmegaConf |
|
from PIL import Image |
|
from torch.utils.data._utils.collate import default_collate |
|
|
|
from internals.util.commons import download_file, download_image |
|
from internals.util.config import get_root_dir |
|
from saicinpainting.evaluation.utils import move_to_device |
|
from saicinpainting.training.data.datasets import make_default_val_dataset |
|
from saicinpainting.training.trainers import load_checkpoint |
|
|
|
|
|
class ObjectRemoval:
    """Object removal (inpainting) backed by the LaMa model, run on CUDA.

    Call :meth:`load` once before :meth:`process`; ``process`` relies on the
    ``lama_path`` and ``model`` attributes that ``load`` sets.
    """

    def load(self, model_dir):
        """Download the LaMa ``best.ckpt`` checkpoint and load it onto CUDA.

        The checkpoint is stored at ``~/.cache/lama/models/best.ckpt`` and the
        model config is read from ``<root dir>/config.yml``.

        Args:
            model_dir: Not used by this implementation. NOTE(review): presumably
                kept so the signature matches other model loaders — confirm.
        """
        print("Downloading LAMA model...")

        self.lama_path = Path.home() / ".cache" / "lama"

        out_file = self.lama_path / "models" / "best.ckpt"
        out_file.parent.mkdir(parents=True, exist_ok=True)
        # NOTE(review): whether download_file skips an already-downloaded file
        # is not visible from here.
        download_file(
            "https://huggingface.co/akhaliq/lama/resolve/main/best.ckpt", out_file
        )

        config = OmegaConf.load(get_root_dir() + "/config.yml")
        # Load the module for inference only.
        config.training_model.predict_only = True

        self.model = load_checkpoint(
            config, str(out_file), strict=False, map_location="cuda"
        )
        self.model.freeze()
        self.model.to("cuda")

    @staticmethod
    def _to_uint8_rgb(image: torch.Tensor, height: int, width: int) -> np.ndarray:
        """Turn one ``(C, H, W)`` float image into an ``(H, W, C)`` uint8 array.

        The array is cropped to at most ``height`` x ``width`` (top-left corner
        kept), then scaled by 255 and clipped to ``[0, 255]``. Channel order is
        preserved.
        """
        arr = image.permute(1, 2, 0).detach().cpu().numpy()
        # Crop before scaling so the result is a fresh contiguous array.
        arr = arr[:height, :width]
        return np.clip(arr * 255, 0, 255).astype("uint8")

    @torch.no_grad()
    def process(
        self,
        image_url: str,
        mask_image_url: str,
        seed: int,
        width: int,
        height: int,
    ) -> List:
        """Inpaint the masked region of an image with LaMa.

        Args:
            image_url: URL of the source image.
            mask_image_url: URL of the mask image; pixels whose loaded mask
                value is > 0 are inpainted.
            seed: Seed passed to ``torch.manual_seed``.
            width: Width the image and mask are resized to; also the output width.
            height: Height the image and mask are resized to; also the output height.

        Returns:
            A list of RGB ``PIL.Image.Image`` results, one per dataset item
            (normally one, since only ``data.png``/``data_mask.png`` are written).
        """
        torch.manual_seed(seed)

        # NOTE(review): this is a fixed on-disk location shared by every call, so
        # concurrent calls on the same machine would overwrite each other's inputs.
        img_folder = self.lama_path / "images"
        indir = img_folder / "input"
        indir.mkdir(parents=True, exist_ok=True)

        size = (width, height)  # PIL's resize takes (width, height)
        download_image(image_url).resize(size).save(indir / "data.png")
        download_image(mask_image_url).resize(size).save(indir / "data_mask.png")

        dataset = make_default_val_dataset(
            indir, img_suffix=".png", pad_out_to_modulo=8
        )

        out_images = []
        for img_i in tqdm.trange(len(dataset)):
            batch = move_to_device(default_collate([dataset[img_i]]), "cuda")
            # Binarise the mask: any positive value marks a pixel to inpaint.
            batch["mask"] = (batch["mask"] > 0) * 1
            batch = self.model(batch)

            # The dataset pads inputs up to a multiple of 8 (pad_out_to_modulo=8),
            # so crop the prediction back to the requested size. Assumes padding
            # is appended at the bottom/right, as LaMa's pad_img_to_modulo does.
            result = self._to_uint8_rgb(batch["inpainted"][0], height, width)
            out_images.append(Image.fromarray(result).convert("RGB"))

        return out_images
|
|