File size: 2,064 Bytes
93d3aa5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95ca169
 
93d3aa5
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import chromadb
from chromadb.api.types import Documents, EmbeddingFunction, Embeddings
import json
from ast import literal_eval
from chromadb.config import Settings
from paddlenlp import Taskflow
import requests
from io import BytesIO
from PIL import Image
import gradio as gr

# Shared feature extractor, constructed once when the module is loaded.
vision_language = Taskflow(
    "feature_extraction",
    model="PaddlePaddle/ernie_vil-2.0-base-zh",
)

def getImageTestFeture(content):
    """Return the feature vector for an image URL or any other input.

    Args:
        content: A string. If it starts with ``"http"`` it is downloaded
            and decoded as an image before feature extraction; otherwise
            it is handed to ``vision_language`` unchanged.

    Returns:
        The first entry of the ``"features"`` output produced by
        ``vision_language``.

    Raises:
        requests.RequestException: If the download fails, times out, or the
            server answers with an HTTP error status.
    """
    if content.startswith("http"):
        # Bound the wait so an unresponsive host cannot hang the caller, and
        # fail loudly on 4xx/5xx instead of feeding an error page to PIL.
        response = requests.get(content, timeout=30)
        response.raise_for_status()
        # Close the decoded image as soon as the features are extracted.
        with Image.open(BytesIO(response.content)) as image:
            f_embeds = vision_language(image)
    else:
        f_embeds = vision_language(content)
    return f_embeds["features"][0]

class MyEmbeddingFunction(EmbeddingFunction):
    """Embedding function that delegates each entry to ``getImageTestFeture``."""

    def __call__(self, texts: Documents) -> Embeddings:
        """Embed every entry of *texts*.

        Each entry is passed to ``getImageTestFeture`` and the resulting
        vector is converted with ``.tolist()``; the vectors are returned in
        input order.
        """
        return [getImageTestFeture(entry).tolist() for entry in texts]

# Chroma client configured for the DuckDB + Parquet implementation.
client = chromadb.Client(
    Settings(
        chroma_db_impl="duckdb+parquet",
        # Optional; defaults to .chromadb/ in the current directory.
        persist_directory="x/",
    )
)

# Open the "pics" collection (creating it when absent), using cosine space
# and the custom embedding function for queries.
collection = client.get_or_create_collection(
    name="pics",
    metadata={"hnsw:space": "cosine"},
    embedding_function=MyEmbeddingFunction(),
)

def queryimgage(text):
    """Query ``collection`` with *text* and render the hits as an HTML table.

    Args:
        text: The query string; it is sent as the single entry of
            ``query_texts``.

    Returns:
        An HTML ``<table>`` string with one row per hit (at most 20 are
        requested). Each row holds an ``<img>`` whose ``src`` is the hit's
        stored document, and the hit's distance. Rows are emitted in the
        reverse of the order returned by the query.
    """
    # Function-scope import: the stored document is placed inside a
    # single-quoted attribute, so quotes and markup in it must be escaped.
    from html import escape

    results = collection.query(
        query_texts=[text],
        n_results=20,
    )
    documents = results['documents'][0]
    distances = results['distances'][0]

    parts = [
        "<table border='1'>"
        "    <tr>"
        "    <th>img</th>"
        "    <th>score</th>"
        "  </tr>"
    ]
    # Walk both lists back to front to keep the original row order.
    for document, distance in zip(reversed(documents), reversed(distances)):
        parts.append(
            f"<tr>        <td><img src='{escape(document)}' width=640></td>"
            f"        <td>{distance!s}</td></tr>"
        )
    parts.append("</table>")
    return "".join(parts)

# Gradio UI: a single text box feeds queryimgage, whose result is shown as HTML.
demo = gr.Interface(
    fn=queryimgage,
    inputs=gr.Textbox(placeholder="请输入文本"),
    outputs=["html"],
)

demo.launch()