import gradio as gr
import requests
import torch
from PIL import Image

from models.zhclip import ZhCLIPProcessor, ZhCLIPModel  # From https://www.github.com/yue-gang/ZH-CLIP
version = 'nlpcver/zh-clip-vit-roberta-large-patch14'
model = ZhCLIPModel.from_pretrained(version)
processor = ZhCLIPProcessor.from_pretrained(version)
def get_result(image,text,text1):
inputs = processor(text=[text,text1], images=image, return_tensors="pt", padding=True)
outputs = model(**inputs)
image_features = outputs.image_features
text_features = outputs.text_features
text_probs = (image_features @ text_features.T).softmax(dim=-1)
return text_probs
with gr.Blocks(
css="""
.message.svelte-w6rprc.svelte-w6rprc.svelte-w6rprc {font-size: 20px; margin-top: 20px}
#component-21 > div.wrap.svelte-w6rprc {height: 600px;}
"""
) as iface:
state = gr.State([])
with gr.Row():
with gr.Column(scale=1):
image_input = gr.Image(type="pil",label="Image Input")
with gr.Row():
with gr.Column(scale=1):
chat_input = gr.Textbox(lines=1, label="Captions0 Input")
chat_input1 = gr.Textbox(lines=1, label="Captions1 Input")
with gr.Row():
clear_button = gr.Button(value="Clear", interactive=True,width=30)
submit_button = gr.Button(
value="Submit", interactive=True, variant="primary"
)
with gr.Column():
caption_output = gr.Textbox(lines=0, label="ITM")
clear_button.click(
lambda: ("", [],"","",""),
[],
[chat_input, state,caption_output],
queue=False,
)
submit_button.click(
get_result,
[
image_input,
chat_input,
chat_input1,
],
[caption_output],
)
iface.queue(concurrency_count=1, api_open=False, max_size=10)
iface.launch(enable_queue=True)