File size: 2,910 Bytes
19b3da3
 
 
 
 
 
 
 
 
 
 
 
22df957
19b3da3
 
 
 
 
 
 
 
99a0484
 
19b3da3
99a0484
 
19b3da3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99a0484
 
22df957
 
 
 
 
 
19b3da3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import os
from pathlib import Path
from typing import List

import cv2
import numpy as np
import torch
import tqdm
from omegaconf import OmegaConf
from PIL import Image
from torch.utils.data._utils.collate import default_collate

from internals.util.cache import clear_cuda_and_gc
from internals.util.commons import download_file, download_image
from internals.util.config import get_root_dir
from saicinpainting.evaluation.utils import move_to_device
from saicinpainting.training.data.datasets import make_default_val_dataset
from saicinpainting.training.trainers import load_checkpoint


class ObjectRemoval:
    """Remove masked objects from an image by inpainting with the LaMa model.

    Call ``load`` once to fetch the checkpoint and move the model to CUDA,
    then ``process`` per request. ``unload`` drops the model reference.
    """

    # Class-level default so a fresh instance reports "not loaded"; load() and
    # unload() set the per-instance value (name-mangled to
    # _ObjectRemoval__loaded).
    __loaded = False

    def load(self, model_dir):
        """Fetch the LaMa checkpoint and load it onto the GPU.

        Returns immediately if this instance is already loaded.

        Args:
            model_dir: Not used by this implementation; kept so existing
                callers keep working. The checkpoint is always stored at
                ``~/.cache/lama/models/best.ckpt``.
        """
        if self.__loaded:
            return
        print("Downloading LAMA model...")

        self.lama_path = Path.home() / ".cache" / "lama"

        out_file = self.lama_path / "models" / "best.ckpt"
        out_file.parent.mkdir(parents=True, exist_ok=True)
        # NOTE(review): whether download_file skips an already-present file is
        # not visible here — confirm, otherwise every load() re-downloads.
        download_file(
            "https://huggingface.co/akhaliq/lama/resolve/main/best.ckpt", out_file
        )
        config = OmegaConf.load(get_root_dir() + "/config.yml")
        # NOTE(review): presumably makes the trainer skip training-only
        # components — confirm in saicinpainting.
        config.training_model.predict_only = True
        self.model = load_checkpoint(
            config, str(out_file), strict=False, map_location="cuda"
        )
        self.model.freeze()
        self.model.to("cuda")

        # Flag is set only after every step above succeeded, so a failed
        # load() can simply be retried.
        self.__loaded = True

    def unload(self):
        """Drop the model and call ``clear_cuda_and_gc()``.

        ``clear_cuda_and_gc`` presumably frees cached CUDA memory and runs
        the garbage collector — confirm in internals.util.cache.
        """
        self.__loaded = False
        self.model = None

        clear_cuda_and_gc()

    @staticmethod
    def _to_rgb_uint8(inpainted, unpad_to_size=None) -> np.ndarray:
        """Convert one CHW float tensor into a contiguous HxWx3 uint8 array.

        Values are scaled by 255 and clipped to [0, 255], so the tensor is
        presumably in [0, 1]. When ``unpad_to_size`` is a ``(height, width)``
        pair (plain ints or single-element tensors, as produced by
        ``default_collate``), the result is cropped to that size to undo the
        dataset's modulo padding. Channel order is left unchanged.
        """
        arr = inpainted.permute(1, 2, 0).detach().cpu().numpy()
        if unpad_to_size is not None:
            orig_height, orig_width = (int(v) for v in unpad_to_size)
            arr = arr[:orig_height, :orig_width]
        arr = np.clip(arr * 255, 0, 255).astype("uint8")
        # PIL's fromarray is happiest with a C-contiguous buffer.
        return np.ascontiguousarray(arr)

    @torch.no_grad()
    def process(
        self,
        image_url: str,
        mask_image_url: str,
        seed: int,
        width: int,
        height: int,
    ) -> List:
        """Inpaint the masked region of an image.

        Both images are downloaded, resized to ``(width, height)`` and saved
        to ``~/.cache/lama/images/input`` as ``data.png`` / ``data_mask.png``,
        from which the LaMa dataset loader reads them.

        Args:
            image_url: URL of the source image.
            mask_image_url: URL of the mask; any loaded mask value > 0 marks
                a pixel to inpaint.
            seed: Passed to ``torch.manual_seed`` before inference.
            width: Target width in pixels.
            height: Target height in pixels.

        Returns:
            A list of RGB ``PIL.Image.Image`` objects, one per dataset entry
            found in the input folder (this method writes a single pair).

        Raises:
            RuntimeError: If called before ``load()`` or after ``unload()``.
        """
        if not self.__loaded:
            raise RuntimeError("ObjectRemoval.process() called before load()")

        torch.manual_seed(seed)

        indir = self.lama_path / "images" / "input"
        indir.mkdir(parents=True, exist_ok=True)

        download_image(image_url).resize((width, height)).save(indir / "data.png")
        download_image(mask_image_url).resize((width, height)).save(
            indir / "data_mask.png"
        )

        # Inputs are padded up to a multiple of 8 for the network.
        dataset = make_default_val_dataset(
            indir, img_suffix=".png", pad_out_to_modulo=8
        )

        out_images = []
        for img_i in tqdm.trange(len(dataset)):
            batch = move_to_device(default_collate([dataset[img_i]]), "cuda")
            # Binarize the mask: any non-zero value counts as "inpaint here".
            batch["mask"] = (batch["mask"] > 0) * 1
            batch = self.model(batch)
            # NOTE(review): assumes the dataset records the pre-padding size
            # under "unpad_to_size" (as LaMa's own predict loop expects);
            # without it the padded output is returned uncropped.
            rgb = self._to_rgb_uint8(
                batch["inpainted"][0], batch.get("unpad_to_size")
            )
            # Build the PIL image in memory; the former PNG write/read
            # round-trip was lossless, so pixel values are unchanged.
            out_images.append(Image.fromarray(rgb))

        return out_images