keyishen committed on
Commit
0f90202
·
verified ·
1 Parent(s): 2f9f177

Create app.py

Browse files

1. Add CLIP demo app

Files changed (1) hide show
  1. app.py +78 -0
app.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoProcessor, CLIPModel
4
+
5
# Checkpoint identifier shared by the model and its matching processor.
clip_path = "openai/clip-vit-base-patch32"

# Loaded once at import time so every request reuses the same objects.
# The two loads are independent of each other.
processor = AutoProcessor.from_pretrained(clip_path)
model = CLIPModel.from_pretrained(clip_path)
8
+
9
+
10
async def predict(init_image, labels_level1):
    """Score an image against a comma-separated list of candidate labels.

    Args:
        init_image: Image forwarded as ``images=`` to the CLIP processor, or
            ``None`` when no image has been provided yet.
        labels_level1: Candidate labels separated by commas. Surrounding
            whitespace is stripped and empty entries are ignored.

    Returns:
        A tuple of two identical strings (one per output component wired to
        this handler). Each string holds one ``"<label>: <score>"`` line per
        label, where the score is the image-text similarity logit. Both
        strings are empty when there is no image or no usable label.
    """
    if init_image is None:
        return "", ""

    # Tolerate input such as "cat, dog," - strip blanks and drop empty entries
    # so they are neither sent to the model nor shown in the result.
    split_labels = [
        label.strip() for label in (labels_level1 or "").split(",")
    ]
    split_labels = [label for label in split_labels if label]
    if not split_labels:
        return "", ""

    # The model is loaded without being moved to a GPU, so CUDA autocast
    # would not affect the result; plain no_grad is enough for inference.
    with torch.no_grad():
        inputs = processor(
            text=split_labels,
            images=init_image,
            return_tensors="pt",
            padding=True,
        )
        outputs = model(**inputs)

    # logits_per_image is the image-text similarity score; row 0 belongs to
    # the single input image. Convert to plain floats so the UI shows numbers
    # rather than the tensor repr.
    scores = outputs.logits_per_image[0].tolist()
    ret_str = "".join(
        f"{label}: {score:.4f}\n"
        for label, score in zip(split_labels, scores)
    )
    return ret_str, ret_str
30
+
31
+
32
# Page-level CSS: centres the main column and the heading.
css = """
#container{
    margin: 0 auto;
    max-width: 80rem;
}
#intro{
    max-width: 100%;
    text-align: center;
    margin: 0 auto;
}
"""

with gr.Blocks(css=css) as demo:
    # NOTE(review): this state is created but not used by any handler below.
    init_image_state = gr.State()

    with gr.Column(elem_id="container"):
        gr.Markdown(
            """# Clip Demo
            """,
            elem_id="intro",
        )

        # Top row: label list, raw output, and the trigger button.
        with gr.Row():
            txt_input = gr.Textbox(
                value="cartoon,painting,screenshot",
                interactive=True,
                label="设定大类别类别",
                scale=5,
            )
            txt = gr.Textbox(value="", label="Output:", scale=5)
            generate_bt = gr.Button("点击开始分类", scale=1)

        # Image upload area.
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(
                    sources=["upload", "clipboard"],
                    label="User Image",
                    type="pil",
                )

        # Second result box.
        with gr.Row():
            prob_label = gr.Textbox(value="", label="一级分类")

    inputs = [image_input, txt_input]

    # Both triggers run the same prediction with the same inputs and outputs.
    handler_kwargs = dict(
        fn=predict,
        inputs=inputs,
        outputs=[txt, prob_label],
        show_progress=True,
    )
    generate_bt.click(**handler_kwargs)
    # The image-change trigger additionally bypasses the queue.
    image_input.change(**handler_kwargs, queue=False)

demo.queue()
demo.launch(server_name='0.0.0.0', server_port=8081, share=False)