Spaces:
Runtime error
Runtime error
jamescalam
commited on
Commit
•
4a10f8f
1
Parent(s):
628dd10
added diffusion functionality
Browse files
app.py
CHANGED
@@ -46,6 +46,8 @@ pipe = StableDiffusionPipeline.from_pretrained(
|
|
46 |
)
|
47 |
pipe.to(device)
|
48 |
|
|
|
|
|
49 |
def encode_text(text: str):
|
50 |
text_inputs = pipe.tokenizer(
|
51 |
text, return_tensors='pt'
|
@@ -64,28 +66,68 @@ def prompt_query(text: str):
|
|
64 |
prompts = list(dict.fromkeys(prompts))
|
65 |
return [[x] for x in prompts[:5]]
|
66 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
def get_image(url: str):
|
68 |
blob = bucket.blob(url).download_as_string()
|
69 |
blob_bytes = io.BytesIO(blob)
|
70 |
im = Image.open(blob_bytes)
|
71 |
return im
|
72 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
def prompt_image(text: str):
|
74 |
embeds = encode_text(text)
|
75 |
xc = index.query(embeds, top_k=9, include_metadata=True)
|
76 |
image_urls = [
|
77 |
match['metadata']['image_url'] for match in xc['matches']
|
78 |
]
|
|
|
|
|
79 |
images = []
|
80 |
-
for image_url in image_urls:
|
81 |
try:
|
82 |
blob = bucket.blob(image_url).download_as_string()
|
83 |
blob_bytes = io.BytesIO(blob)
|
84 |
im = Image.open(blob_bytes)
|
85 |
-
|
|
|
|
|
|
|
86 |
except ValueError:
|
87 |
-
print(f"
|
88 |
-
return images
|
89 |
|
90 |
# __APP FUNCTIONS__
|
91 |
|
@@ -93,8 +135,19 @@ def set_suggestion(text: str):
|
|
93 |
return gr.TextArea.update(value=text[0])
|
94 |
|
95 |
def set_images(text: str):
|
96 |
-
images = prompt_image(text)
|
97 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
98 |
|
99 |
# __CREATE APP__
|
100 |
demo = gr.Blocks()
|
@@ -129,12 +182,14 @@ with demo:
|
|
129 |
suggestions.components
|
130 |
)
|
131 |
|
132 |
-
|
133 |
# results column
|
134 |
with gr.Column():
|
135 |
pics = gr.Gallery()
|
136 |
pics.style(grid=3)
|
137 |
# search event listening
|
138 |
-
|
|
|
|
|
|
|
139 |
|
140 |
demo.launch()
|
|
|
46 |
)
|
47 |
pipe.to(device)
|
48 |
|
49 |
+
missing_im = Image.open('missing.png')
|
50 |
+
|
51 |
def encode_text(text: str):
|
52 |
text_inputs = pipe.tokenizer(
|
53 |
text, return_tensors='pt'
|
|
|
66 |
prompts = list(dict.fromkeys(prompts))
|
67 |
return [[x] for x in prompts[:5]]
|
68 |
|
69 |
+
def diffuse(text: str):
|
70 |
+
# diffuse
|
71 |
+
out = pipe(text)
|
72 |
+
if any(out.nsfw_content_detected):
|
73 |
+
return {}
|
74 |
+
else:
|
75 |
+
_id = str(uuid.uuid4())
|
76 |
+
# add image to Cloud Storage
|
77 |
+
im = out.images[0]
|
78 |
+
im.save(f'{_id}.png', format='png')
|
79 |
+
# push to storage
|
80 |
+
blob = bucket.blob(f'images/{_id}.png')
|
81 |
+
blob.upload_from_filename(f'{_id}.png')
|
82 |
+
# delete local file
|
83 |
+
os.remove(f'{_id}.png')
|
84 |
+
# add embedding and metadata to Pinecone
|
85 |
+
embeds = encode_text(text)
|
86 |
+
meta = {
|
87 |
+
'prompt': text,
|
88 |
+
'image_url': f'images/{_id}.png'
|
89 |
+
}
|
90 |
+
index.upsert([(_id, embeds, meta)])
|
91 |
+
return out.images[0]
|
92 |
+
|
93 |
def get_image(url: str):
|
94 |
blob = bucket.blob(url).download_as_string()
|
95 |
blob_bytes = io.BytesIO(blob)
|
96 |
im = Image.open(blob_bytes)
|
97 |
return im
|
98 |
|
99 |
+
def test_image(_id, image):
|
100 |
+
try:
|
101 |
+
image.save('tmp.png')
|
102 |
+
return True
|
103 |
+
except OSError:
|
104 |
+
# delete corrupted file from pinecone and cloud
|
105 |
+
index.delete(ids=[_id])
|
106 |
+
bucket.blob(f"images/{_id}.png").delete()
|
107 |
+
print(f"DELETED '{_id}'")
|
108 |
+
return False
|
109 |
+
|
110 |
def prompt_image(text: str):
|
111 |
embeds = encode_text(text)
|
112 |
xc = index.query(embeds, top_k=9, include_metadata=True)
|
113 |
image_urls = [
|
114 |
match['metadata']['image_url'] for match in xc['matches']
|
115 |
]
|
116 |
+
scores = [match['score'] for match in xc['matches']]
|
117 |
+
ids = [match['id'] for match in xc['matches']]
|
118 |
images = []
|
119 |
+
for _id, image_url in zip(ids, image_urls):
|
120 |
try:
|
121 |
blob = bucket.blob(image_url).download_as_string()
|
122 |
blob_bytes = io.BytesIO(blob)
|
123 |
im = Image.open(blob_bytes)
|
124 |
+
if test_image(_id, im):
|
125 |
+
images.append(im)
|
126 |
+
else:
|
127 |
+
images.append(missing_im)
|
128 |
except ValueError:
|
129 |
+
print(f"ValueError: '{image_url}'")
|
130 |
+
return images, scores
|
131 |
|
132 |
# __APP FUNCTIONS__
|
133 |
|
|
|
135 |
return gr.TextArea.update(value=text[0])
|
136 |
|
137 |
def set_images(text: str):
|
138 |
+
images, scores = prompt_image(text)
|
139 |
+
match_found = False
|
140 |
+
for score in scores:
|
141 |
+
if score > 0.85:
|
142 |
+
match_found = True
|
143 |
+
if match_found:
|
144 |
+
print("MATCH FOUND")
|
145 |
+
return gr.Gallery.update(value=images)
|
146 |
+
else:
|
147 |
+
print("NO MATCH FOUND")
|
148 |
+
diffuse(text)
|
149 |
+
images, scores = prompt_image(text)
|
150 |
+
return gr.Gallery.update(value=images)
|
151 |
|
152 |
# __CREATE APP__
|
153 |
demo = gr.Blocks()
|
|
|
182 |
suggestions.components
|
183 |
)
|
184 |
|
|
|
185 |
# results column
|
186 |
with gr.Column():
|
187 |
pics = gr.Gallery()
|
188 |
pics.style(grid=3)
|
189 |
# search event listening
|
190 |
+
try:
|
191 |
+
search.click(set_images, prompt, pics)
|
192 |
+
except OSError:
|
193 |
+
print("OSError")
|
194 |
|
195 |
demo.launch()
|