CM2000112 / internals /pipelines /remove_background.py
jayparmr's picture
Upload folder using huggingface_hub
cd51d32
raw
history blame
1.39 kB
import io
from pathlib import Path
from typing import Union
import torch
import torch.nn.functional as F
from PIL import Image
from rembg import remove
from carvekit.api.high import HiInterface
from internals.util.commons import read_url
class RemoveBackground:
    """Background remover backed by ``rembg.remove``."""

    def remove(self, image: Union[str, Image.Image]) -> Image.Image:
        """Return ``image`` with its background removed.

        Args:
            image: A PIL image, or a string that is handed to ``read_url``
                to obtain the encoded image bytes.

        Returns:
            Whatever ``rembg.remove`` produces for the (possibly downloaded)
            image.
        """
        # isinstance rather than an exact type comparison, so str subclasses
        # are also treated as something to download.
        if isinstance(image, str):
            image = Image.open(io.BytesIO(read_url(image)))
        # The bare name ``remove`` resolves to the module-level rembg
        # function here, not to this method (methods are not in scope
        # as bare names inside a class body's functions).
        return remove(image)
class RemoveBackgroundV2:
    """Background remover backed by carvekit's ``HiInterface``."""

    def __init__(self):
        """Create the carvekit interface, on CUDA when available, else CPU."""
        self.interface = HiInterface(
            object_type="object",  # Can be "object" or "hairs-like".
            batch_size_seg=5,
            batch_size_matting=1,
            device="cuda" if torch.cuda.is_available() else "cpu",
            seg_mask_size=640,  # Use 640 for Tracer B7 and 320 for U2Net
            matting_mask_size=2048,
            trimap_prob_threshold=231,
            trimap_dilation=30,
            trimap_erosion_iters=5,
            fp16=False,
        )

    def remove(self, image: Union[str, Image.Image]) -> Image.Image:
        """Return ``image`` with its background removed.

        The image is written to ``~/.cache/rm_bg.png`` and that path is
        passed to the carvekit interface.

        Args:
            image: A PIL image, or a string that is handed to ``read_url``
                to obtain the encoded image bytes.

        Returns:
            The first element of the list returned by the interface.
        """
        # NOTE: fixed scratch path; concurrent calls in the same home
        # directory will overwrite each other's file.
        img_path = Path.home() / ".cache" / "rm_bg.png"

        # isinstance rather than an exact type comparison, so str subclasses
        # are also treated as something to download.
        if isinstance(image, str):
            image = Image.open(io.BytesIO(read_url(image)))

        # ~/.cache may not exist (fresh user, container); without this the
        # save below raises FileNotFoundError.
        img_path.parent.mkdir(parents=True, exist_ok=True)
        image.save(img_path)

        images_without_background = self.interface([img_path])
        return images_without_background[0]