"""Dream Cacher: a Gradio app that caches Stable Diffusion generations.

A prompt is embedded with the diffusion pipeline's own text encoder and
looked up in a Pinecone index. If a sufficiently similar prompt has already
been generated, the cached images are fetched from Google Cloud Storage;
otherwise a new image is generated, uploaded to the bucket, and indexed.
"""
import io
import json
import os
import uuid

import gradio as gr
import pinecone
import torch
from cryptography.fernet import Fernet
from diffusers import StableDiffusionPipeline
from google.cloud import storage
from PIL import Image

# Similarity score above which a cached result counts as a match for a prompt.
MATCH_THRESHOLD = 0.85

# --- Cloud Storage -----------------------------------------------------------
# The service-account JSON ships encrypted; DECRYPTION_KEY decrypts it.
fernet = Fernet(os.environ['DECRYPTION_KEY'])
with open('cloud-storage.encrypted', 'rb') as fp:
    encrypted = fp.read()
creds = json.loads(fernet.decrypt(encrypted).decode())

# storage.Client() below picks credentials up from the file named by
# GOOGLE_APPLICATION_CREDENTIALS, so the decrypted JSON is written to disk.
# NOTE(review): this leaves plaintext credentials on disk; consider
# storage.Client.from_service_account_info(creds) instead.
with open('cloud-storage.json', 'w', encoding='utf-8') as fp:
    json.dump(creds, fp, indent=4)

os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'cloud-storage.json'
storage_client = storage.Client()
# Generated images are stored in this bucket under the "images/" prefix.
bucket = storage_client.get_bucket('hf-diffusion-images')

# --- Pinecone ----------------------------------------------------------------
PINECONE_KEY = os.environ['PINECONE_KEY']
index_id = "hf-diffusion"

pinecone.init(
    api_key=PINECONE_KEY,
    environment="us-west1-gcp",
)
# Fail fast at startup rather than on the first query.
if index_id not in pinecone.list_indexes():
    raise ValueError(f"Index '{index_id}' not found")
index = pinecone.Index(index_id)

# --- Model -------------------------------------------------------------------
# Everything runs on the CPU; set this to e.g. 'cuda' to use a GPU instead.
device = 'cpu'
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    use_auth_token=os.environ['HF_AUTH'],
)
pipe.to(device)

# Placeholder shown in place of cached images that cannot be decoded.
missing_im = Image.open('missing.png')


def encode_text(text: str) -> list:
    """Embed *text* with the diffusion pipeline's text encoder.

    Returns the pooled embedding of the single input as a plain list of
    floats, the vector format stored in and queried against Pinecone.
    """
    text_inputs = pipe.tokenizer(
        text,
        # The text encoder has a fixed maximum sequence length; longer
        # prompts would otherwise raise inside the encoder.
        truncation=True,
        max_length=pipe.tokenizer.model_max_length,
        return_tensors='pt',
    ).to(device)
    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        text_embeds = pipe.text_encoder(**text_inputs)
    return text_embeds.pooler_output.cpu().tolist()[0]


def prompt_query(text: str):
    """Suggest up to five distinct stored prompts similar to *text*.

    Returns one single-element row per suggestion, used as the samples of
    the suggestions Dataset.
    """
    embeds = encode_text(text)
    xc = index.query(embeds, top_k=30, include_metadata=True)
    prompts = [match['metadata']['prompt'] for match in xc['matches']]
    # Deduplicate while preserving ranking order (dicts keep insertion order).
    prompts = list(dict.fromkeys(prompts))
    return [[prompt] for prompt in prompts[:5]]


def diffuse(text: str):
    """Generate an image for *text*, upload it, and index its prompt.

    Returns the generated PIL image. If the pipeline flags the output as
    NSFW, nothing is stored and an empty dict is returned instead.
    """
    out = pipe(text)
    # nsfw_content_detected may be None (e.g. with the safety checker off).
    if out.nsfw_content_detected and any(out.nsfw_content_detected):
        return {}

    _id = str(uuid.uuid4())
    im = out.images[0]
    image_url = f'images/{_id}.png'

    # Encode in memory rather than through a local temp file, so nothing is
    # left behind on disk if the upload fails.
    buffer = io.BytesIO()
    im.save(buffer, format='png')
    bucket.blob(image_url).upload_from_string(
        buffer.getvalue(), content_type='image/png'
    )

    # Index the prompt embedding, pointing at the uploaded image.
    meta = {
        'prompt': text,
        'image_url': image_url,
    }
    index.upsert([(_id, encode_text(text), meta)])
    return im


def get_image(url: str):
    """Download the bucket object at path *url* and open it as a PIL image.

    Note that Image.open is lazy: pixel data is only decoded on first use.
    """
    data = bucket.blob(url).download_as_bytes()
    return Image.open(io.BytesIO(data))


def test_image(_id, image) -> bool:
    """Return True if *image* fully decodes, else purge record *_id*.

    On a decode failure (OSError) the vector *_id* is deleted from Pinecone
    and its PNG from Cloud Storage, so it is not served again, and False is
    returned.
    """
    try:
        # Re-encoding forces a full decode. Writing into memory, rather than
        # a shared file on disk, keeps concurrent requests from colliding.
        image.save(io.BytesIO(), format='png')
        return True
    except OSError:
        index.delete(ids=[_id])
        bucket.blob(f"images/{_id}.png").delete()
        print(f"DELETED '{_id}'")
        return False


def prompt_image(text: str):
    """Fetch the cached images whose prompts are most similar to *text*.

    Returns ``(images, scores)``: up to nine images, with ``missing_im``
    standing in for any that cannot be decoded, and the similarity scores of
    all matches returned by the index.
    """
    embeds = encode_text(text)
    xc = index.query(embeds, top_k=9, include_metadata=True)
    matches = xc['matches']
    scores = [match['score'] for match in matches]

    images = []
    for match in matches:
        image_url = match['metadata']['image_url']
        try:
            im = get_image(image_url)
        except ValueError:
            print(f"ValueError: '{image_url}'")
            continue
        except OSError as err:
            # Unreadable blob (e.g. not an image at all): show the
            # placeholder instead of failing the whole gallery.
            print(f"OSError: '{image_url}': {err}")
            images.append(missing_im)
            continue
        images.append(im if test_image(match['id'], im) else missing_im)
    return images, scores


# __APP FUNCTIONS__

def set_suggestion(text):
    """Copy a clicked suggestion into the prompt box.

    *text* is presumably the clicked Dataset sample row (one prompt string).
    """
    return gr.TextArea.update(value=text[0])


def set_images(text: str):
    """Show cached images for *text*, generating a new one on a cache miss."""
    images, scores = prompt_image(text)
    if any(score > MATCH_THRESHOLD for score in scores):
        print("MATCH FOUND")
    else:
        print("NO MATCH FOUND")
        diffuse(text)
        # Query again so a freshly generated image can appear in the results.
        images, _ = prompt_image(text)
    return gr.Gallery.update(value=images)


# __CREATE APP__

demo = gr.Blocks()

with demo:
    gr.Markdown(
        """
        # Dream Cacher
        """
    )
    with gr.Row():
        with gr.Column():
            prompt = gr.TextArea(
                value="A dream about a cat",
                placeholder="Enter a prompt to dream about",
                interactive=True,
            )
            search = gr.Button(value="Search!")
            suggestions = gr.Dataset(
                components=[prompt],
                samples=[
                    ["Something"],
                    ["something else"],
                ],
            )
            # Refresh the suggestions whenever the prompt text changes.
            prompt.change(prompt_query, prompt, suggestions)
            # Clicking a suggestion copies it into the prompt box.
            suggestions.click(
                set_suggestion,
                suggestions,
                suggestions.components,
            )
        # Results column.
        with gr.Column():
            pics = gr.Gallery()
            pics.style(grid=3)
    # Search: serve cached images, or generate on a cache miss.
    search.click(set_images, prompt, pics)

demo.launch()