import io
import json
import os

import gradio as gr
import pinecone
import torch
from cryptography.fernet import Fernet
from diffusers import StableDiffusionPipeline
from google.cloud import storage
from PIL import Image

# decrypt Cloud Storage credentials (key supplied via the DECRYPTION_KEY env var)
fernet = Fernet(os.environ['DECRYPTION_KEY'])
with open('cloud-storage.encrypted', 'rb') as fp:
    encrypted = fp.read()
creds = json.loads(fernet.decrypt(encrypted).decode())
# then save creds to file so GOOGLE_APPLICATION_CREDENTIALS can point at it.
# NOTE(review): this leaves the decrypted service-account key on disk in the
# working directory; storage.Client.from_service_account_info(creds) would
# avoid that if nothing else relies on the file.
with open('cloud-storage.json', 'w', encoding='utf-8') as fp:
    fp.write(json.dumps(creds, indent=4))
# (removed: a second write of an undefined `G_API` that raised NameError and
# would have clobbered the credentials file written above)

# connect to Cloud Storage
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'cloud-storage.json'
storage_client = storage.Client()
bucket = storage_client.get_bucket('hf-diffusion-images')

# get api key for pinecone auth
PINECONE_KEY = os.environ['PINECONE_KEY']

index_id = "hf-diffusion"

# init connection to pinecone
pinecone.init(
    api_key=PINECONE_KEY,
    environment="us-west1-gcp"
)
if index_id not in pinecone.list_indexes():
    raise ValueError(f"Index '{index_id}' not found")
index = pinecone.Index(index_id)

# all models run on this device (CPU here; change to e.g. 'cuda' for a GPU)
device = 'cpu'

# load the Stable Diffusion pipeline; only its tokenizer and text encoder are
# used below, to embed prompts for vector search
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    use_auth_token=True
)
pipe.to(device)


def encode_text(text: str):
    """Embed ``text`` with the pipeline's text encoder.

    Returns the encoder's ``pooler_output`` for the single input as a flat
    list of floats, used as the Pinecone query vector.
    """
    # Truncate to the tokenizer's limit: longer prompts would otherwise exceed
    # the text encoder's position embeddings and raise.
    text_inputs = pipe.tokenizer(
        text,
        truncation=True,
        max_length=pipe.tokenizer.model_max_length,
        return_tensors='pt'
    ).to(device)
    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        text_embeds = pipe.text_encoder(**text_inputs)
    return text_embeds.pooler_output.cpu().tolist()[0]


def prompt_query(text: str):
    """Return up to 5 distinct stored prompts most similar to ``text``.

    Each prompt is wrapped in a one-element list, matching the sample format
    of the ``suggestions`` Dataset below.
    """
    embeds = encode_text(text)
    xc = index.query(embeds, top_k=30, include_metadata=True)
    prompts = [match['metadata']['prompt'] for match in xc['matches']]
    # deduplicate while preserving order
    prompts = list(dict.fromkeys(prompts))
    return [[x] for x in prompts[:5]]


def get_image(url: str):
    """Download the blob named ``url`` from the bucket and open it with PIL."""
    blob = bucket.blob(url).download_as_bytes()
    return Image.open(io.BytesIO(blob))


def prompt_image(text: str):
    """Return the images for the 9 nearest matches to ``text``.

    Images that cannot be opened or decoded are skipped and reported on stdout.
    """
    embeds = encode_text(text)
    xc = index.query(embeds, top_k=9, include_metadata=True)
    image_urls = [match['metadata']['image_url'] for match in xc['matches']]
    images = []
    for image_url in image_urls:
        try:
            im = get_image(image_url)
            # Image.open is lazy; decode now so corrupt/truncated data fails
            # here (and is skipped) rather than later inside the Gallery.
            im.load()
        except (ValueError, OSError):
            # OSError covers PIL.UnidentifiedImageError and truncated files
            print(f"error for '{image_url}'")
        else:
            images.append(im)
    return images


# __APP FUNCTIONS__

def set_suggestion(text: str):
    """Copy a clicked suggestion into the prompt box.

    ``text`` is the clicked Dataset sample; its first element is the prompt.
    """
    return gr.TextArea.update(value=text[0])


def set_images(text: str):
    """Search images for the prompt ``text`` and show them in the gallery."""
    images = prompt_image(text)
    return gr.Gallery.update(value=images)


# __CREATE APP__

demo = gr.Blocks()

with demo:
    gr.Markdown(
        """
        # Dream Cacher
        """
    )
    with gr.Row():
        with gr.Column():
            prompt = gr.TextArea(
                value="A dream about a cat",
                placeholder="Enter a prompt to dream about",
                interactive=True
            )
            search = gr.Button(value="Search!")
            suggestions = gr.Dataset(
                components=[prompt],
                samples=[
                    ["Something"],
                    ["something else"]
                ]
            )
            # event listener for change in prompt
            prompt.change(prompt_query, prompt, suggestions)
            # event listener for click on suggestion
            suggestions.click(
                set_suggestion,
                suggestions,
                suggestions.components
            )

        # results column
        with gr.Column():
            pics = gr.Gallery()
            pics.style(grid=3)

    # search event listening
    search.click(set_images, prompt, pics)

demo.launch()