Upload 3 files
Browse files- image2image.py +86 -0
- text2image.py +78 -0
- utils.py +43 -0
image2image.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import partial
|
2 |
+
import json
|
3 |
+
from multiprocessing.pool import ThreadPool as Pool
|
4 |
+
import gradio as gr
|
5 |
+
import PIL
|
6 |
+
from PIL import Image
|
7 |
+
from utils import *
|
8 |
+
|
9 |
+
from clip_retrieval.clip_client import ClipClient
|
10 |
+
|
11 |
+
def image2text_gr():
|
12 |
+
def clip_api(query_image=None, return_n=8, model_name=clip_base, thumbnail=yes):
|
13 |
+
client = ClipClient(url="http://9.135.121.52:1234//knn-service",
|
14 |
+
indice_name="ltr_cover_index",
|
15 |
+
aesthetic_weight=0,
|
16 |
+
num_images=int(return_n))
|
17 |
+
result = client.query(image=query_image)
|
18 |
+
|
19 |
+
if not result or len(result) == 0:
|
20 |
+
print("no result found")
|
21 |
+
return None
|
22 |
+
|
23 |
+
print(f"get result sucessed, num: {len(result)}")
|
24 |
+
|
25 |
+
cover_urls = [res['cover_url'] for res in result]
|
26 |
+
cover_info = []
|
27 |
+
for res in result:
|
28 |
+
json_info = {"cover_url": res['cover_url'],
|
29 |
+
"similarity": round(res['similarity'], 6),
|
30 |
+
"docid": res['docids']}
|
31 |
+
cover_info.append(str(json_info))
|
32 |
+
pool = Pool()
|
33 |
+
new_url2image = partial(url2img, thumbnail=thumbnail)
|
34 |
+
ret_imgs = pool.map(new_url2image, cover_urls)
|
35 |
+
pool.close()
|
36 |
+
pool.join()
|
37 |
+
|
38 |
+
new_ret = []
|
39 |
+
for i in range(len(ret_imgs)):
|
40 |
+
new_ret.append([ret_imgs[i], cover_info[i]])
|
41 |
+
return new_ret
|
42 |
+
|
43 |
+
examples = [
|
44 |
+
["https://xingchen-data.oss-cn-zhangjiakou.aliyuncs.com/coco/2014/test2014/COCO_test2014_000000000069.jpg", 20,
|
45 |
+
clip_base, "是"],
|
46 |
+
["https://xingchen-data.oss-cn-zhangjiakou.aliyuncs.com/coco/2014/test2014/COCO_test2014_000000000080.jpg", 20,
|
47 |
+
clip_base, "是"],
|
48 |
+
["https://xingchen-data.oss-cn-zhangjiakou.aliyuncs.com/coco/2014/train2014/COCO_train2014_000000000009.jpg",
|
49 |
+
20, clip_base, "是"],
|
50 |
+
["https://xingchen-data.oss-cn-zhangjiakou.aliyuncs.com/coco/2014/train2014/COCO_train2014_000000000308.jpg",
|
51 |
+
20, clip_base, "是"]
|
52 |
+
]
|
53 |
+
|
54 |
+
title = "<h1 align='center'>CLIP图到图搜索应用</h1>"
|
55 |
+
|
56 |
+
with gr.Blocks() as demo:
|
57 |
+
gr.Markdown(title)
|
58 |
+
gr.Markdown(description)
|
59 |
+
with gr.Row():
|
60 |
+
with gr.Column(scale=1):
|
61 |
+
with gr.Column(scale=2):
|
62 |
+
img = gr.Textbox(value="https://xingchen-data.oss-cn-zhangjiakou.aliyuncs.com/coco/2014/test2014/COCO_test2014_000000000069.jpg", label="图片地址", elem_id=0, interactive=True)
|
63 |
+
num = gr.components.Slider(minimum=0, maximum=50, step=1, value=8, label="返回图片数(可能被过滤部分)", elem_id=2)
|
64 |
+
model = gr.components.Radio(label="模型选择", choices=[clip_base],
|
65 |
+
value=clip_base, elem_id=3)
|
66 |
+
tn = gr.components.Radio(label="是否返回缩略图", choices=[yes, no],
|
67 |
+
value=yes, elem_id=4)
|
68 |
+
btn = gr.Button("搜索", )
|
69 |
+
with gr.Column(scale=100):
|
70 |
+
out = gr.Gallery(label="检索结果为:", columns=4, height="auto")
|
71 |
+
inputs = [img, num, model, tn]
|
72 |
+
btn.click(fn=clip_api, inputs=inputs, outputs=out)
|
73 |
+
gr.Examples(examples, inputs=inputs)
|
74 |
+
return demo
|
75 |
+
|
76 |
+
|
77 |
+
if __name__ == "__main__":
|
78 |
+
with gr.TabbedInterface(
|
79 |
+
[image2text_gr()],
|
80 |
+
["图到图搜索"],
|
81 |
+
) as demo:
|
82 |
+
demo.launch(
|
83 |
+
#enable_queue=True,
|
84 |
+
server_name='127.0.0.1',
|
85 |
+
share=False
|
86 |
+
)
|
text2image.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import partial
|
2 |
+
import json
|
3 |
+
from multiprocessing.pool import ThreadPool as Pool
|
4 |
+
import gradio as gr
|
5 |
+
from utils import *
|
6 |
+
|
7 |
+
from clip_retrieval.clip_client import ClipClient
|
8 |
+
|
9 |
+
|
10 |
+
def text2image_gr():
|
11 |
+
def clip_api(query_text='', return_n=8, model_name=clip_base, thumbnail="是"):
|
12 |
+
client = ClipClient(url="http://9.135.121.52:1234//knn-service",
|
13 |
+
indice_name="ltr_cover_index",
|
14 |
+
aesthetic_weight=0,
|
15 |
+
num_images=int(return_n))
|
16 |
+
#result = client.query(embedding_input=query_emb)
|
17 |
+
result = client.query(text=query_text)
|
18 |
+
|
19 |
+
if not result or len(result) == 0:
|
20 |
+
print("no result found")
|
21 |
+
return None
|
22 |
+
|
23 |
+
print(f"get result sucessed, num: {len(result)}")
|
24 |
+
|
25 |
+
cover_urls = [res['cover_url'] for res in result]
|
26 |
+
cover_info = []
|
27 |
+
for res in result:
|
28 |
+
json_info = {"cover_url": res['cover_url'],
|
29 |
+
"similarity": round(res['similarity'], 6),
|
30 |
+
"docid": res['docids']}
|
31 |
+
cover_info.append(str(json_info))
|
32 |
+
pool = Pool()
|
33 |
+
new_url2image = partial(url2img, thumbnail=thumbnail)
|
34 |
+
ret_imgs = pool.map(new_url2image, cover_urls)
|
35 |
+
pool.close()
|
36 |
+
pool.join()
|
37 |
+
|
38 |
+
new_ret = []
|
39 |
+
for i in range(len(ret_imgs)):
|
40 |
+
new_ret.append([ret_imgs[i], cover_info[i]])
|
41 |
+
return new_ret
|
42 |
+
|
43 |
+
examples = [
|
44 |
+
["cat", 12, clip_base, "是"],
|
45 |
+
["dog", 12, clip_base, "是"],
|
46 |
+
["bag", 12, clip_base, "是"],
|
47 |
+
["a cat is sit on the table", 12, clip_base, "是"]
|
48 |
+
]
|
49 |
+
|
50 |
+
title = "<h1 align='center'>CLIP文到图搜索应用</h1>"
|
51 |
+
|
52 |
+
with gr.Blocks() as demo:
|
53 |
+
gr.Markdown(title)
|
54 |
+
gr.Markdown(description)
|
55 |
+
with gr.Row():
|
56 |
+
with gr.Column(scale=1):
|
57 |
+
with gr.Column(scale=2):
|
58 |
+
text = gr.Textbox(value="cat", label="请填写文本", elem_id=0, interactive=True)
|
59 |
+
num = gr.components.Slider(minimum=0, maximum=50, step=1, value=8, label="返回图片数(可能被过滤部分)", elem_id=2)
|
60 |
+
model = gr.components.Radio(label="模型选择", choices=[clip_base],
|
61 |
+
value=clip_base, elem_id=3)
|
62 |
+
thumbnail = gr.components.Radio(label="是否返回缩略图", choices=[yes, no],
|
63 |
+
value=yes, elem_id=4)
|
64 |
+
btn = gr.Button("搜索", )
|
65 |
+
with gr.Column(scale=100):
|
66 |
+
out = gr.Gallery(label="检索结果为:", columns=4, height="auto") #.style(grid=4, height=200)
|
67 |
+
inputs = [text, num, model, thumbnail]
|
68 |
+
btn.click(fn=clip_api, inputs=inputs, outputs=out)
|
69 |
+
gr.Examples(examples, inputs=inputs)
|
70 |
+
return demo
|
71 |
+
|
72 |
+
if __name__ == "__main__":
|
73 |
+
gr.close_all()
|
74 |
+
with gr.TabbedInterface(
|
75 |
+
[text2image_gr()],
|
76 |
+
["文到图搜索"],
|
77 |
+
) as demo:
|
78 |
+
demo.launch(server_name='127.0.0.1', share=False)
|
utils.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from PIL import Image
|
3 |
+
from PIL import ImageFile
|
4 |
+
import requests
|
5 |
+
import base64
|
6 |
+
from io import BytesIO
|
7 |
+
|
8 |
+
clip_base = "CLIP(Base)"
|
9 |
+
description = "本项目为CLIP模型的DEMO,可用于图文检索和图像、文本的表征提取,应用于搜索、推荐等应用场景。"
|
10 |
+
|
11 |
+
yes = "是"
|
12 |
+
no = "否"
|
13 |
+
|
14 |
+
server_ip = os.environ.get("CLIP_SERVER_IP", "9.135.121.52")
|
15 |
+
|
16 |
+
clip_service_url_d = {
|
17 |
+
clip_base: f'http://{server_ip}/knn-service',
|
18 |
+
}
|
19 |
+
|
20 |
+
|
21 |
+
def pil_base64(image, img_format="JPEG"):
|
22 |
+
Image.MAX_IMAGE_PIXELS = 1000000000
|
23 |
+
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
24 |
+
img_buffer = BytesIO()
|
25 |
+
image.save(img_buffer, format=img_format)
|
26 |
+
byte_data = img_buffer.getvalue()
|
27 |
+
base64_str = base64.b64encode(byte_data)
|
28 |
+
return base64_str.decode("utf-8")
|
29 |
+
|
30 |
+
|
31 |
+
def url2img(img_url, thumbnail=yes):
|
32 |
+
try:
|
33 |
+
#print(img_url, thumbnail)
|
34 |
+
#image = Image.open(requests.get(img_url, stream=True).raw)
|
35 |
+
path = img_url.split("9.22.26.31")[1]
|
36 |
+
image = Image.open(path).convert("RGB")
|
37 |
+
max_ = max(image.size)
|
38 |
+
if max_ > 224 and thumbnail == yes:
|
39 |
+
ratio = max_ // 224
|
40 |
+
image.thumbnail(size=(image.width // ratio, image.height // ratio))
|
41 |
+
return image
|
42 |
+
except Exception as e:
|
43 |
+
print(e)
|