File size: 1,394 Bytes
19b3da3 cd51d32 19b3da3 a3d6c18 19b3da3 a3d6c18 19b3da3 a3d6c18 cd51d32 a3d6c18 cd51d32 a3d6c18 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 |
import io
from pathlib import Path
from typing import Union
import torch
import torch.nn.functional as F
from PIL import Image
from rembg import remove
from carvekit.api.high import HiInterface
from internals.util.commons import read_url
class RemoveBackground:
    """Background remover backed by rembg's ``remove`` function."""

    def remove(self, image: Union[str, Image.Image]) -> Image.Image:
        """Return *image* with its background removed by rembg.

        Args:
            image: Either a PIL image, or a string that is passed to
                ``read_url``. The returned bytes are decoded with
                ``PIL.Image.open``. Presumably the string is a URL;
                confirm against ``read_url``.

        Returns:
            The result of ``rembg.remove`` applied to the PIL image.
        """
        # isinstance (not `type(...) is str`) so str subclasses are accepted too.
        if isinstance(image, str):
            image = Image.open(io.BytesIO(read_url(image)))
        # `remove` here is the module-level rembg import; a method name is not
        # in scope inside its own body, so there is no recursion.
        output = remove(image)
        return output
class RemoveBackgroundV2:
    """Background remover backed by carvekit's ``HiInterface``.

    The carvekit interface (segmentation + matting models) is constructed once
    in ``__init__`` and reused by every ``remove`` call.
    """

    def __init__(self):
        self.interface = HiInterface(
            object_type="object",  # Can be "object" or "hairs-like".
            batch_size_seg=5,
            batch_size_matting=1,
            # Run on the GPU when torch can see one, otherwise fall back to CPU.
            device="cuda" if torch.cuda.is_available() else "cpu",
            seg_mask_size=640,  # Use 640 for Tracer B7 and 320 for U2Net
            matting_mask_size=2048,
            trimap_prob_threshold=231,
            trimap_dilation=30,
            trimap_erosion_iters=5,
            fp16=False,
        )

    def remove(self, image: Union[str, Image.Image]) -> Image.Image:
        """Return *image* with its background removed by carvekit.

        Args:
            image: Either a PIL image, or a string that is passed to
                ``read_url``. The returned bytes are decoded with
                ``PIL.Image.open``. Presumably the string is a URL;
                confirm against ``read_url``.

        Returns:
            The first (and only) image produced by the carvekit interface for
            the single input.
        """
        if isinstance(image, str):
            image = Image.open(io.BytesIO(read_url(image)))
        # Hand the PIL image straight to carvekit, which accepts
        # PIL.Image.Image as well as paths. The previous round-trip through a
        # fixed ~/.cache/rm_bg.png had two problems:
        # - that file was shared by every call and instance, so concurrent
        #   calls could process each other's input;
        # - saving failed whenever ~/.cache did not exist.
        images_without_background = self.interface([image])
        return images_without_background[0]
|