from PIL import Image
import gradio as gr
import torch

from models.zhclip import ZhCLIPProcessor, ZhCLIPModel  # From https://www.github.com/yue-gang/ZH-CLIP

# Load the pretrained ZH-CLIP checkpoint and its paired processor.
version = 'nlpcver/zh-clip-vit-roberta-large-patch14'
model = ZhCLIPModel.from_pretrained(version)
processor = ZhCLIPProcessor.from_pretrained(version)
model.eval()  # inference only: disable dropout and other training-time behavior
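
# Optional GPU inference -- a hedged sketch (assumes a CUDA device is available;
# the demo itself stays on CPU, and the inputs built in get_result would then
# also need a matching .to(device)):
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   model = model.to(device)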



def get_result(image, text, text1):
    """Score two candidate captions against one image with ZH-CLIP."""
    inputs = processor(text=[text, text1], images=image, return_tensors="pt", padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    image_features = outputs.image_features
    text_features = outputs.text_features
    # Image-text matching: softmax over the image's similarity to each caption.
    text_probs = (image_features @ text_features.T).softmax(dim=-1)
    return str(text_probs.squeeze().tolist())
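
# A quick smoke test without the UI -- a hedged sketch (hypothetical file name;
# the Chinese captions mean "a cat" / "a dog"):
#   probs = get_result(Image.open("example.jpg"), "一只猫", "一只狗")
#   print(probs)  # e.g. "[0.93, 0.07]" -- the first caption matches better
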
with gr.Blocks(
    css="""
    .message.svelte-w6rprc.svelte-w6rprc.svelte-w6rprc {font-size: 20px; margin-top: 20px}
    #component-21 > div.wrap.svelte-w6rprc {height: 600px;}
    """
) as iface:
    state = gr.State([])

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Image Input")
            with gr.Row():
                with gr.Column(scale=1):
                    chat_input = gr.Textbox(lines=1, label="Caption 0 Input")
                    chat_input1 = gr.Textbox(lines=1, label="Caption 1 Input")
                    with gr.Row():
                        clear_button = gr.Button(value="Clear", interactive=True)
                        submit_button = gr.Button(
                            value="Submit", interactive=True, variant="primary"
                        )

        with gr.Column():
            caption_output = gr.Textbox(lines=1, label="ITM probabilities")

    # Clear empties both caption boxes, the session state, and the output box.
    clear_button.click(
        lambda: ("", "", [], ""),
        [],
        [chat_input, chat_input1, state, caption_output],
        queue=False,
    )
    submit_button.click(
        get_result,
        [image_input, chat_input, chat_input1],
        [caption_output],
    )
# Process one request at a time and cap the queue at 10 waiting requests;
# queue() already enables queuing, so the deprecated enable_queue flag is dropped.
iface.queue(concurrency_count=1, api_open=False, max_size=10)
iface.launch()
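
# Usage: run this script with Python and open the printed local URL
# (http://127.0.0.1:7860 by default). Upload an image, enter two candidate
# Chinese captions, and press Submit; the ITM box shows each caption's
# softmax probability of matching the image.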