jamescalam commited on
Commit
4a10f8f
1 Parent(s): 628dd10

added diffusion functionality

Browse files
Files changed (1) hide show
  1. app.py +63 -8
app.py CHANGED
@@ -46,6 +46,8 @@ pipe = StableDiffusionPipeline.from_pretrained(
46
  )
47
  pipe.to(device)
48
 
 
 
49
  def encode_text(text: str):
50
  text_inputs = pipe.tokenizer(
51
  text, return_tensors='pt'
@@ -64,28 +66,68 @@ def prompt_query(text: str):
64
  prompts = list(dict.fromkeys(prompts))
65
  return [[x] for x in prompts[:5]]
66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  def get_image(url: str):
68
  blob = bucket.blob(url).download_as_string()
69
  blob_bytes = io.BytesIO(blob)
70
  im = Image.open(blob_bytes)
71
  return im
72
 
 
 
 
 
 
 
 
 
 
 
 
73
  def prompt_image(text: str):
74
  embeds = encode_text(text)
75
  xc = index.query(embeds, top_k=9, include_metadata=True)
76
  image_urls = [
77
  match['metadata']['image_url'] for match in xc['matches']
78
  ]
 
 
79
  images = []
80
- for image_url in image_urls:
81
  try:
82
  blob = bucket.blob(image_url).download_as_string()
83
  blob_bytes = io.BytesIO(blob)
84
  im = Image.open(blob_bytes)
85
- images.append(im)
 
 
 
86
  except ValueError:
87
- print(f"error for '{image_url}'")
88
- return images
89
 
90
  # __APP FUNCTIONS__
91
 
@@ -93,8 +135,19 @@ def set_suggestion(text: str):
93
  return gr.TextArea.update(value=text[0])
94
 
95
  def set_images(text: str):
96
- images = prompt_image(text)
97
- return gr.Gallery.update(value=images)
 
 
 
 
 
 
 
 
 
 
 
98
 
99
  # __CREATE APP__
100
  demo = gr.Blocks()
@@ -129,12 +182,14 @@ with demo:
129
  suggestions.components
130
  )
131
 
132
-
133
  # results column
134
  with gr.Column():
135
  pics = gr.Gallery()
136
  pics.style(grid=3)
137
  # search event listening
138
- search.click(set_images, prompt, pics)
 
 
 
139
 
140
  demo.launch()
 
46
  )
47
  pipe.to(device)
48
 
49
+ missing_im = Image.open('missing.png')
50
+
51
  def encode_text(text: str):
52
  text_inputs = pipe.tokenizer(
53
  text, return_tensors='pt'
 
66
  prompts = list(dict.fromkeys(prompts))
67
  return [[x] for x in prompts[:5]]
68
 
69
+ def diffuse(text: str):
70
+ # diffuse
71
+ out = pipe(text)
72
+ if any(out.nsfw_content_detected):
73
+ return {}
74
+ else:
75
+ _id = str(uuid.uuid4())
76
+ # add image to Cloud Storage
77
+ im = out.images[0]
78
+ im.save(f'{_id}.png', format='png')
79
+ # push to storage
80
+ blob = bucket.blob(f'images/{_id}.png')
81
+ blob.upload_from_filename(f'{_id}.png')
82
+ # delete local file
83
+ os.remove(f'{_id}.png')
84
+ # add embedding and metadata to Pinecone
85
+ embeds = encode_text(text)
86
+ meta = {
87
+ 'prompt': text,
88
+ 'image_url': f'images/{_id}.png'
89
+ }
90
+ index.upsert([(_id, embeds, meta)])
91
+ return out.images[0]
92
+
93
  def get_image(url: str):
94
  blob = bucket.blob(url).download_as_string()
95
  blob_bytes = io.BytesIO(blob)
96
  im = Image.open(blob_bytes)
97
  return im
98
 
99
+ def test_image(_id, image):
100
+ try:
101
+ image.save('tmp.png')
102
+ return True
103
+ except OSError:
104
+ # delete corrupted file from pinecone and cloud
105
+ index.delete(ids=[_id])
106
+ bucket.blob(f"images/{_id}.png").delete()
107
+ print(f"DELETED '{_id}'")
108
+ return False
109
+
110
  def prompt_image(text: str):
111
  embeds = encode_text(text)
112
  xc = index.query(embeds, top_k=9, include_metadata=True)
113
  image_urls = [
114
  match['metadata']['image_url'] for match in xc['matches']
115
  ]
116
+ scores = [match['score'] for match in xc['matches']]
117
+ ids = [match['id'] for match in xc['matches']]
118
  images = []
119
+ for _id, image_url in zip(ids, image_urls):
120
  try:
121
  blob = bucket.blob(image_url).download_as_string()
122
  blob_bytes = io.BytesIO(blob)
123
  im = Image.open(blob_bytes)
124
+ if test_image(_id, im):
125
+ images.append(im)
126
+ else:
127
+ images.append(missing_im)
128
  except ValueError:
129
+ print(f"ValueError: '{image_url}'")
130
+ return images, scores
131
 
132
  # __APP FUNCTIONS__
133
 
 
135
  return gr.TextArea.update(value=text[0])
136
 
137
  def set_images(text: str):
138
+ images, scores = prompt_image(text)
139
+ match_found = False
140
+ for score in scores:
141
+ if score > 0.85:
142
+ match_found = True
143
+ if match_found:
144
+ print("MATCH FOUND")
145
+ return gr.Gallery.update(value=images)
146
+ else:
147
+ print("NO MATCH FOUND")
148
+ diffuse(text)
149
+ images, scores = prompt_image(text)
150
+ return gr.Gallery.update(value=images)
151
 
152
  # __CREATE APP__
153
  demo = gr.Blocks()
 
182
  suggestions.components
183
  )
184
 
 
185
  # results column
186
  with gr.Column():
187
  pics = gr.Gallery()
188
  pics.style(grid=3)
189
  # search event listening
190
+ try:
191
+ search.click(set_images, prompt, pics)
192
+ except OSError:
193
+ print("OSError")
194
 
195
  demo.launch()