import io | |
from pathlib import Path | |
from typing import Union | |
import torch | |
import torch.nn.functional as F | |
from PIL import Image | |
from rembg import remove | |
import internals.util.image as ImageUtil | |
from carvekit.api.high import HiInterface | |
from internals.util.commons import download_image, read_url | |
class RemoveBackground:
    """Background removal backed by the ``rembg`` library."""

    def remove(self, image: Union[str, Image.Image]) -> Image.Image:
        """Return ``image`` with its background removed.

        Args:
            image: A PIL image, or a URL string. A string is fetched with
                ``read_url`` and decoded with ``PIL.Image.open`` first.

        Returns:
            Whatever ``rembg.remove`` returns for the (decoded) image.
        """
        # isinstance (not an exact type() identity check) also accepts str
        # subclasses, per PEP 8's recommendation for type checks.
        if isinstance(image, str):
            image = Image.open(io.BytesIO(read_url(image)))
        # The bare name ``remove`` resolves to the module-level import from
        # ``rembg``; a method is not visible as a bare name inside its class.
        return remove(image)


class RemoveBackgroundV2:
    """Background removal backed by carvekit's high-level ``HiInterface``."""

    def __init__(self):
        # Build the carvekit pipeline once; every remove() call reuses it.
        self.interface = HiInterface(
            object_type="object",  # Can be "object" or "hairs-like".
            batch_size_seg=5,
            batch_size_matting=1,
            device="cuda" if torch.cuda.is_available() else "cpu",
            seg_mask_size=640,  # Use 640 for Tracer B7 and 320 for U2Net
            matting_mask_size=2048,
            trimap_prob_threshold=231,
            trimap_dilation=30,
            trimap_erosion_iters=5,
            fp16=False,
        )

    def remove(self, image: Union[str, Image.Image]) -> Image.Image:
        """Return ``image`` with its background removed.

        Args:
            image: A PIL image, or a URL string that is fetched with
                ``download_image`` first.

        Returns:
            The first (and only) image produced by the carvekit interface.
        """
        import tempfile

        # isinstance (not an exact type() identity check) also accepts str
        # subclasses, per PEP 8's recommendation for type checks.
        if isinstance(image, str):
            image = download_image(image)

        # Inputs whose longer side exceeds 1536 px are passed through
        # ImageUtil.resize_image(dimension=1024) before segmentation.
        # NOTE(review): exact resize semantics live in internals.util.image.
        if max(image.size) > 1536:
            image = ImageUtil.resize_image(image, dimension=1024)

        # carvekit is given a file path. Use a private per-call temporary
        # directory instead of a fixed ``~/.cache/rm_bg.png``: the fixed path
        # fails when ``~/.cache`` does not exist, lets concurrent calls
        # clobber each other's input, and leaves the file behind.
        with tempfile.TemporaryDirectory() as tmp_dir:
            img_path = Path(tmp_dir) / "rm_bg.png"
            image.save(img_path)
            images_without_background = self.interface([img_path])
        return images_without_background[0]