# pic_search / app.py
import chromadb
from chromadb.api.types import Documents, EmbeddingFunction, Embeddings
from chromadb.config import Settings
from paddlenlp import Taskflow
import requests
from io import BytesIO
from PIL import Image
import gradio as gr

# ERNIE-ViL 2.0 is a CLIP-style dual-encoder model that embeds Chinese text
# and images into a shared feature space, so text queries can be matched
# directly against image embeddings.
vision_language = Taskflow("feature_extraction", model='PaddlePaddle/ernie_vil-2.0-base-zh')
def getImageTextFeature(content):
    """Embed either an image URL or a plain text string with ERNIE-ViL."""
    if content.startswith("http"):
        # Image branch: download the image and extract its visual features.
        response = requests.get(content)
        x = BytesIO(response.content)
        f_embeds = vision_language(Image.open(x))
    else:
        # Text branch: extract text features directly.
        f_embeds = vision_language(content)
    features = f_embeds["features"][0]
    return features
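# Example usage (hypothetical inputs, not part of this app; any image URL or
# Chinese text string works):
#   img_vec  = getImageTextFeature("https://example.com/cat.jpg")  # image embedding
#   text_vec = getImageTextFeature("一只猫")                        # text embedding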
class MyEmbeddingFunction(EmbeddingFunction):
    def __call__(self, texts: Documents) -> Embeddings:
        # chromadb calls this for every add() and query(); each "document"
        # here is either an image URL (at indexing time) or a text query.
        embeddings = []
        for doc in texts:
            embeds = getImageTextFeature(doc)
            embeddings.append(embeds.tolist())
        return embeddings
client = chromadb.Client(Settings(
    chroma_db_impl="duckdb+parquet",  # legacy (pre-0.4) chromadb persistence backend
    persist_directory="x/"  # optional; defaults to .chromadb/ in the current directory
))
collection = client.get_or_create_collection(
    name="pics",
    metadata={"hnsw:space": "cosine"},  # use cosine distance for the HNSW index
    embedding_function=MyEmbeddingFunction(),
)
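# Note: this app only queries the "pics" collection; it assumes the images were
# indexed elsewhere. A minimal ingestion sketch (the id and URL below are
# placeholders, not part of this app) would look like:
#
#   collection.add(
#       ids=["img-0001"],
#       documents=["https://example.com/cat.jpg"],  # MyEmbeddingFunction embeds the URL
#   )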
def queryImage(text):
    """Query the collection with a text string and render results as an HTML table."""
    html = ("<table border='1'>"
            "<tr><th>img</th><th>score</th></tr>")
    results = collection.query(
        query_texts=[text],
        n_results=20,
    )
    ids = results['ids'][0]
    documents = results['documents'][0]  # each document is an image URL
    distances = results['distances'][0]  # cosine distance: lower = more similar
    xcount = len(ids)
    for i in range(xcount):
        # Rows are emitted in reverse order, so the closest match ends up last.
        html += ("<tr>"
                 "<td><img src='" + documents[xcount - 1 - i] + "' width=640></td>"
                 "<td>" + str(distances[xcount - 1 - i]) + "</td>"
                 "</tr>")
    html += "</table>"
    return html
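# Quick sanity check outside Gradio (hypothetical query; requires a populated
# collection):
#   print(queryImage("猫")[:200])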
demo = gr.Interface(
    queryImage,
    gr.Textbox(placeholder="请输入文本"),  # "Please enter text"
    ["html"],
)
demo.launch()