File size: 2,538 Bytes
5cb1539
3a6f1f2
 
3015442
 
 
3a6f1f2
782da61
3a6f1f2
 
 
 
 
 
 
 
e332358
99c661e
 
 
3015442
 
 
 
 
99c661e
782da61
e332358
782da61
3015442
 
 
99c661e
3a6f1f2
 
e332358
3a6f1f2
ef928a1
 
df195bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
782da61
 
 
 
df195bf
5ee83f7
782da61
ef928a1
3015442
3a6f1f2
 
 
782da61
3a6f1f2
e332358
8f57daf
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import gradio as gr
import os
import cv2
from rembg import new_session, remove
from PIL import Image
from io import BytesIO

def inference(file, mask, model, alpha_influence, segmentation_strength):
    """Remove the background from an image with rembg and return the result.

    Args:
        file: Path of the input image on disk.
        mask: Output mode; ``"Mask only"`` asks rembg for just the
            segmentation mask, any other value yields the cutout image.
        model: Name of the rembg model handed to ``new_session``.
        alpha_influence: Float expected in [0, 1]. Scaled to rembg's
            alpha-matting foreground threshold; the background threshold
            is its complement.
        segmentation_strength: Float expected in [0, 1]. Scaled to an
            alpha-matting erosion size between 0 and 10.

    Returns:
        A PIL image decoded from the bytes returned by ``remove``.

    Raises:
        ValueError: If no file was supplied, or OpenCV cannot decode or
            re-encode the image.
    """
    if not file:
        raise ValueError("No input image was provided.")

    # Decode as 3-channel colour; imread signals failure by returning None
    # rather than raising, so check it explicitly.
    bgr_image = cv2.imread(file, cv2.IMREAD_COLOR)
    if bgr_image is None:
        raise ValueError(f"Could not read an image from {file!r}.")

    # Re-encode to PNG in memory. Writing to fixed 'input.png'/'output.png'
    # paths would let concurrent requests overwrite each other's data.
    encoded_ok, png_buffer = cv2.imencode(".png", bgr_image)
    if not encoded_ok:
        raise ValueError(f"Could not encode {file!r} as PNG.")

    # rembg compares these thresholds against an 8-bit (0-255) mask and its
    # own defaults are 240 / 10, so the 0-1 slider value is scaled to 0-255.
    foreground_threshold = int(round(alpha_influence * 255))
    background_threshold = 255 - foreground_threshold

    output = remove(
        png_buffer.tobytes(),
        only_mask=(mask == "Mask only"),
        alpha_matting=True,  # enable alpha matting mode
        alpha_matting_foreground_threshold=foreground_threshold,
        alpha_matting_background_threshold=background_threshold,
        # Segmentation strength controls the erosion size (0-10).
        alpha_matting_erode_size=int(segmentation_strength * 10),
        session=new_session(model),
    )

    return Image.open(BytesIO(output))

# Static text rendered around the demo widget.
title = "RemBG"
description = "Gradio demo for RemBG. To use it, simply upload your image and adjust the alpha influence and segmentation strength."
article = "<p style='text-align: center;'><a href='https://github.com/danielgatis/rembg' target='_blank'>Github Repo</a></p>"

# Output modes offered by the radio input; inference() checks for "Mask only".
mode_choices = ["Default", "Mask only"]

# Model names offered in the dropdown; the selected one is passed to inference().
model_choices = [
    "u2net",
    "u2netp",
    "u2net_human_seg",
    "u2net_cloth_seg",
    "silueta",
    "isnet-general-use",
    "isnet-anime",
    "sam",
]

# Input widgets, in the same order as the parameters of inference():
# file, mask, model, alpha_influence, segmentation_strength.
input_components = [
    gr.inputs.Image(type="filepath", label="Input"),
    gr.inputs.Radio(
        mode_choices,
        type="value",
        default="Default",
        label="Choices",
    ),
    gr.inputs.Dropdown(
        model_choices,
        type="value",
        default="isnet-general-use",
        label="Models",
    ),
    gr.inputs.Slider(minimum=0.0, maximum=1.0, default=0.5, label="Alpha Influence"),
    gr.inputs.Slider(minimum=0.0, maximum=1.0, default=0.5, label="Segmentation Strength"),
]

output_component = gr.outputs.Image(type="PIL", label="Output")

# One row per example, with values in the same order as input_components.
example_rows = [
    ["lion.png", "Default", "u2net", 0.5, 0.5],
    ["girl.jpg", "Default", "u2net", 0.5, 0.5],
    ["anime-girl.jpg", "Default", "isnet-anime", 0.5, 0.5],
]

# Build the interface and start serving it as soon as the module runs.
demo = gr.Interface(
    fn=inference,
    inputs=input_components,
    outputs=output_component,
    title=title,
    description=description,
    article=article,
    examples=example_rows,
    enable_queue=True,
)
demo.launch()