File size: 1,394 Bytes
19b3da3
cd51d32
19b3da3
 
a3d6c18
 
19b3da3
 
 
a3d6c18
19b3da3
 
 
 
 
 
 
 
 
 
a3d6c18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd51d32
a3d6c18
 
 
cd51d32
 
a3d6c18
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import io
import tempfile
from pathlib import Path
from typing import Union

import torch
import torch.nn.functional as F
from PIL import Image
from rembg import remove

from carvekit.api.high import HiInterface
from internals.util.commons import read_url


class RemoveBackground:
    """Background remover backed by ``rembg.remove``."""

    def remove(self, image: Union[str, Image.Image]) -> Image.Image:
        """Return *image* with its background removed.

        Args:
            image: A PIL image, or a string whose bytes are fetched with
                ``read_url`` (presumably a URL) and decoded with PIL.

        Returns:
            The result of ``rembg.remove`` applied to the (possibly decoded)
            image.
        """
        # isinstance (not ``type(...) is str``) so str subclasses are accepted.
        if isinstance(image, str):
            image = Image.open(io.BytesIO(read_url(image)))

        # ``remove`` here resolves to the module-level rembg function, not to
        # this method (method names are not in scope inside the body).
        return remove(image)


class RemoveBackgroundV2:
    """Background remover backed by carvekit's ``HiInterface``.

    The segmentation/matting configuration is fixed at construction time. The
    interface is created with ``device="cuda"`` when
    ``torch.cuda.is_available()`` and ``device="cpu"`` otherwise.
    """

    def __init__(self):
        self.interface = HiInterface(
            object_type="object",  # Can be "object" or "hairs-like".
            batch_size_seg=5,
            batch_size_matting=1,
            device="cuda" if torch.cuda.is_available() else "cpu",
            seg_mask_size=640,  # Use 640 for Tracer B7 and 320 for U2Net
            matting_mask_size=2048,
            trimap_prob_threshold=231,
            trimap_dilation=30,
            trimap_erosion_iters=5,
            fp16=False,
        )

    def remove(self, image: Union[str, Image.Image]) -> Image.Image:
        """Return *image* with its background removed.

        Args:
            image: A PIL image, or a string whose bytes are fetched with
                ``read_url`` (presumably a URL) and decoded with PIL.

        Returns:
            The first element of the list returned by ``self.interface`` for
            the single staged input.
        """
        # isinstance (not ``type(...) is str``) so str subclasses are accepted.
        if isinstance(image, str):
            image = Image.open(io.BytesIO(read_url(image)))

        # The interface is given a file path, so the image is staged on disk
        # as PNG. A private temporary directory replaces the old fixed path
        # (~/.cache/rm_bg.png). That path failed when ~/.cache did not exist,
        # let concurrent calls overwrite each other's input, and was never
        # cleaned up.
        with tempfile.TemporaryDirectory() as tmp_dir:
            img_path = Path(tmp_dir) / "rm_bg.png"
            image.save(img_path)
            images_without_background = self.interface([img_path])
        return images_without_background[0]