import gradio as gr
import torch
from transformers import AutoProcessor, CLIPModel

# Load the CLIP model and its processor once at startup.
clip_path = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(clip_path).eval()
processor = AutoProcessor.from_pretrained(clip_path)


async def predict(init_image, labels_level1):
    """Score the uploaded image against the comma-separated candidate labels."""
    if init_image is None:
        return "", ""

    split_labels = [label.strip() for label in labels_level1.split(",") if label.strip()]
    if not split_labels:
        return "", ""

    inputs = processor(
        text=split_labels, images=init_image, return_tensors="pt", padding=True
    )

    # Run inference without tracking gradients.
    with torch.no_grad():
        outputs = model(**inputs)
    logits_per_image = outputs.logits_per_image  # image-text similarity scores, shape (1, num_labels)

    ret_str = ""
    for i, label in enumerate(split_labels):
        ret_str += f"{label}: {logits_per_image[0][i].item():.4f}\n"

    return ret_str, ret_str


css = """
#container{
    margin: 0 auto;
    max-width: 80rem;
}
#intro{
    max-width: 100%;
    text-align: center;
    margin: 0 auto;
}
"""
with gr.Blocks(css=css) as demo:
    init_image_state = gr.State()
    with gr.Column(elem_id="container"):
        gr.Markdown(
            """# CLIP Demo
            """,
            elem_id="intro",
        )
        with gr.Row():
            txt_input = gr.Textbox(
                value="cartoon,painting,screenshot",
                interactive=True, label="Top-level categories", scale=5)
            txt = gr.Textbox(value="", label="Output:", scale=5)
            generate_bt = gr.Button("Start classification", scale=1)
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(
                    sources=["upload", "clipboard"],
                    label="User Image",
                    type="pil",
                )
        with gr.Row():
            prob_label = gr.Textbox(value="", label="Level-1 classification")

        inputs = [image_input, txt_input]
        generate_bt.click(fn=predict, inputs=inputs, outputs=[txt, prob_label], show_progress=True)
        image_input.change(
            fn=predict,
            inputs=inputs,
            outputs=[txt, prob_label],
            show_progress=True,
            queue=False,
        )

demo.queue().launch()
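
For a quick offline sanity check of the same scoring path, the sketch below runs the identical CLIP call on a blank placeholder image and converts the logits to softmax probabilities; the blank image and the example label list are illustrative assumptions, not part of the app above.

import torch
from PIL import Image
from transformers import AutoProcessor, CLIPModel

# Same checkpoint as the demo; weights are downloaded on first use.
clip_path = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(clip_path).eval()
processor = AutoProcessor.from_pretrained(clip_path)

labels = ["cartoon", "painting", "screenshot"]  # example label set
image = Image.new("RGB", (224, 224), "white")   # stand-in for an uploaded image

inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
with torch.no_grad():
    logits = model(**inputs).logits_per_image   # shape (1, len(labels))

# Softmax over the label dimension turns the similarity logits into probabilities.
probs = logits.softmax(dim=1)[0]
for label, prob in zip(labels, probs):
    print(f"{label}: {prob.item():.3f}")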