Yougen committed on
Commit
37688e5
·
verified ·
1 Parent(s): be5a0a3

Upload 3 files

Browse files
Files changed (3) hide show
  1. image2image.py +86 -0
  2. text2image.py +78 -0
  3. utils.py +43 -0
image2image.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+ import json
3
+ from multiprocessing.pool import ThreadPool as Pool
4
+ import gradio as gr
5
+ import PIL
6
+ from PIL import Image
7
+ from utils import *
8
+
9
+ from clip_retrieval.clip_client import ClipClient
10
+
11
+ def image2text_gr():
12
+ def clip_api(query_image=None, return_n=8, model_name=clip_base, thumbnail=yes):
13
+ client = ClipClient(url="http://9.135.121.52:1234//knn-service",
14
+ indice_name="ltr_cover_index",
15
+ aesthetic_weight=0,
16
+ num_images=int(return_n))
17
+ result = client.query(image=query_image)
18
+
19
+ if not result or len(result) == 0:
20
+ print("no result found")
21
+ return None
22
+
23
+ print(f"get result sucessed, num: {len(result)}")
24
+
25
+ cover_urls = [res['cover_url'] for res in result]
26
+ cover_info = []
27
+ for res in result:
28
+ json_info = {"cover_url": res['cover_url'],
29
+ "similarity": round(res['similarity'], 6),
30
+ "docid": res['docids']}
31
+ cover_info.append(str(json_info))
32
+ pool = Pool()
33
+ new_url2image = partial(url2img, thumbnail=thumbnail)
34
+ ret_imgs = pool.map(new_url2image, cover_urls)
35
+ pool.close()
36
+ pool.join()
37
+
38
+ new_ret = []
39
+ for i in range(len(ret_imgs)):
40
+ new_ret.append([ret_imgs[i], cover_info[i]])
41
+ return new_ret
42
+
43
+ examples = [
44
+ ["https://xingchen-data.oss-cn-zhangjiakou.aliyuncs.com/coco/2014/test2014/COCO_test2014_000000000069.jpg", 20,
45
+ clip_base, "是"],
46
+ ["https://xingchen-data.oss-cn-zhangjiakou.aliyuncs.com/coco/2014/test2014/COCO_test2014_000000000080.jpg", 20,
47
+ clip_base, "是"],
48
+ ["https://xingchen-data.oss-cn-zhangjiakou.aliyuncs.com/coco/2014/train2014/COCO_train2014_000000000009.jpg",
49
+ 20, clip_base, "是"],
50
+ ["https://xingchen-data.oss-cn-zhangjiakou.aliyuncs.com/coco/2014/train2014/COCO_train2014_000000000308.jpg",
51
+ 20, clip_base, "是"]
52
+ ]
53
+
54
+ title = "<h1 align='center'>CLIP图到图搜索应用</h1>"
55
+
56
+ with gr.Blocks() as demo:
57
+ gr.Markdown(title)
58
+ gr.Markdown(description)
59
+ with gr.Row():
60
+ with gr.Column(scale=1):
61
+ with gr.Column(scale=2):
62
+ img = gr.Textbox(value="https://xingchen-data.oss-cn-zhangjiakou.aliyuncs.com/coco/2014/test2014/COCO_test2014_000000000069.jpg", label="图片地址", elem_id=0, interactive=True)
63
+ num = gr.components.Slider(minimum=0, maximum=50, step=1, value=8, label="返回图片数(可能被过滤部分)", elem_id=2)
64
+ model = gr.components.Radio(label="模型选择", choices=[clip_base],
65
+ value=clip_base, elem_id=3)
66
+ tn = gr.components.Radio(label="是否返回缩略图", choices=[yes, no],
67
+ value=yes, elem_id=4)
68
+ btn = gr.Button("搜索", )
69
+ with gr.Column(scale=100):
70
+ out = gr.Gallery(label="检索结果为:", columns=4, height="auto")
71
+ inputs = [img, num, model, tn]
72
+ btn.click(fn=clip_api, inputs=inputs, outputs=out)
73
+ gr.Examples(examples, inputs=inputs)
74
+ return demo
75
+
76
+
77
if __name__ == "__main__":
    # Single-tab app wrapping the image-to-image search demo.
    app = gr.TabbedInterface([image2text_gr()], ["图到图搜索"])
    with app as demo:
        demo.launch(server_name='127.0.0.1', share=False)
text2image.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+ import json
3
+ from multiprocessing.pool import ThreadPool as Pool
4
+ import gradio as gr
5
+ from utils import *
6
+
7
+ from clip_retrieval.clip_client import ClipClient
8
+
9
+
10
def text2image_gr():
    """Build the Gradio UI for CLIP text-to-image search.

    Returns:
        A ``gr.Blocks`` demo: the user enters a text query, the demo queries
        a remote clip-retrieval KNN service and shows matching covers in a
        gallery together with their similarity info.
    """

    def clip_api(query_text='', return_n=8, model_name=clip_base, thumbnail="是"):
        """Query the KNN service with a text string and build gallery items.

        Args:
            query_text: text query forwarded to the service.
            return_n: number of neighbours to request (service may filter some).
            model_name: model radio value; currently unused by the query.
            thumbnail: yes/no flag — downscale fetched images when ``yes``.

        Returns:
            A list of ``[PIL.Image, info_string]`` pairs for ``gr.Gallery``,
            or None when the service returns nothing.
        """
        client = ClipClient(url="http://9.135.121.52:1234//knn-service",
                            indice_name="ltr_cover_index",
                            aesthetic_weight=0,
                            num_images=int(return_n))
        #result = client.query(embedding_input=query_emb)
        result = client.query(text=query_text)

        if not result:
            print("no result found")
            return None

        print(f"query succeeded, num: {len(result)}")

        cover_urls = [res['cover_url'] for res in result]
        cover_info = [
            str({"cover_url": res['cover_url'],
                 "similarity": round(res['similarity'], 6),
                 "docid": res['docids']})
            for res in result
        ]

        # Fetch images concurrently. The context manager guarantees the pool
        # is torn down even if a worker raises; map() blocks until all done.
        with Pool() as pool:
            ret_imgs = pool.map(partial(url2img, thumbnail=thumbnail), cover_urls)

        # url2img returns None on failure — drop those entries instead of
        # handing a None image to the gallery.
        return [[img, info] for img, info in zip(ret_imgs, cover_info)
                if img is not None]

    examples = [
        ["cat", 12, clip_base, yes],
        ["dog", 12, clip_base, yes],
        ["bag", 12, clip_base, yes],
        ["a cat is sit on the table", 12, clip_base, yes],
    ]

    title = "<h1 align='center'>CLIP文到图搜索应用</h1>"

    with gr.Blocks() as demo:
        gr.Markdown(title)
        gr.Markdown(description)  # shared project blurb from utils
        with gr.Row():
            with gr.Column(scale=1):
                with gr.Column(scale=2):
                    text = gr.Textbox(value="cat", label="请填写文本", elem_id=0, interactive=True)
                    num = gr.components.Slider(minimum=0, maximum=50, step=1, value=8,
                                               label="返回图片数(可能被过滤部分)", elem_id=2)
                    model = gr.components.Radio(label="模型选择", choices=[clip_base],
                                                value=clip_base, elem_id=3)
                    thumbnail = gr.components.Radio(label="是否返回缩略图", choices=[yes, no],
                                                    value=yes, elem_id=4)
                    btn = gr.Button("搜索")
            with gr.Column(scale=100):
                out = gr.Gallery(label="检索结果为:", columns=4, height="auto")
        inputs = [text, num, model, thumbnail]
        btn.click(fn=clip_api, inputs=inputs, outputs=out)
        gr.Examples(examples, inputs=inputs)
    return demo
71
+
72
if __name__ == "__main__":
    # Free any ports/resources held by previously launched demos, then
    # launch a single-tab app wrapping the text-to-image search demo.
    gr.close_all()
    app = gr.TabbedInterface([text2image_gr()], ["文到图搜索"])
    with app as demo:
        demo.launch(server_name='127.0.0.1', share=False)
utils.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from PIL import Image
3
+ from PIL import ImageFile
4
+ import requests
5
+ import base64
6
+ from io import BytesIO
7
+
8
# Display name for the only model offered in the UI model radio buttons.
clip_base = "CLIP(Base)"
# Project blurb rendered at the top of every demo page.
description = "本项目为CLIP模型的DEMO,可用于图文检索和图像、文本的表征提取,应用于搜索、推荐等应用场景。"

# Radio-button values for the "return thumbnail?" toggle ("yes" / "no").
yes = "是"
no = "否"

# KNN service host; overridable via environment variable for deployment.
server_ip = os.environ.get("CLIP_SERVER_IP", "9.135.121.52")

# Model-name -> service-URL map.
# NOTE(review): the demo scripts currently hard-code their own URL (with an
# explicit port) instead of reading this map — confirm which endpoint is
# canonical before relying on it.
clip_service_url_d = {
    clip_base: f'http://{server_ip}/knn-service',
}
19
+
20
+
21
def pil_base64(image, img_format="JPEG"):
    """Serialize a PIL image to a base64-encoded UTF-8 string.

    The image is written into an in-memory buffer in *img_format* and the
    resulting bytes are base64-encoded.
    """
    # Relax PIL's process-wide safety limits so very large or truncated
    # images can still be serialized (global side effect, kept as-is).
    Image.MAX_IMAGE_PIXELS = 1000000000
    ImageFile.LOAD_TRUNCATED_IMAGES = True

    buffer = BytesIO()
    image.save(buffer, format=img_format)
    encoded = base64.b64encode(buffer.getvalue())
    return encoded.decode("utf-8")
29
+
30
+
31
def url2img(img_url, thumbnail=yes):
    """Load the image behind *img_url* from local disk, optionally downscaled.

    Returns a PIL RGB image, or None (implicitly) on any failure — errors are
    printed and swallowed so one bad cover does not break a whole batch fetch.
    """
    try:
        #print(img_url, thumbnail)
        #image = Image.open(requests.get(img_url, stream=True).raw)
        # Assumes the URL embeds the host "9.22.26.31" and everything after it
        # is a valid local filesystem path — TODO confirm this holds for all
        # cover URLs returned by the KNN service; other URLs raise IndexError
        # here and fall into the except branch.
        path = img_url.split("9.22.26.31")[1]
        image = Image.open(path).convert("RGB")
        max_ = max(image.size)
        if max_ > 224 and thumbnail == yes:
            # Integer downscale factor: the longest side ends up in [224, 448),
            # not exactly 224 — presumably intentional cheap shrink; verify.
            ratio = max_ // 224
            image.thumbnail(size=(image.width // ratio, image.height // ratio))
        return image
    except Exception as e:
        # Broad catch on purpose (missing file, bad split, decode error);
        # callers must tolerate the implicit None return.
        print(e)