# dream-cacher / app.py
# Author: jamescalam. Commit 1a67055 ("Update app.py"), 3.92 kB.
from diffusers import StableDiffusionPipeline
import torch
import io
from PIL import Image
import os
from cryptography.fernet import Fernet
from google.cloud import storage
import pinecone
import json
# Decrypt the Cloud Storage credentials. They are kept on disk only in
# Fernet-encrypted form; the key comes from the DECRYPTION_KEY env var.
fernet = Fernet(os.environ['DECRYPTION_KEY'])
with open('cloud-storage.encrypted', 'rb') as fp:
    encrypted = fp.read()
creds = json.loads(fernet.decrypt(encrypted).decode())

# storage.Client() below is constructed without explicit credentials, so the
# decrypted JSON is written to a file and GOOGLE_APPLICATION_CREDENTIALS is
# pointed at it.
# NOTE(review): the plaintext credentials file is never removed here.
with open('cloud-storage.json', 'w', encoding='utf-8') as fp:
    json.dump(creds, fp, indent=4)
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'cloud-storage.json'

# Connect to Cloud Storage; the result images are downloaded from this bucket.
storage_client = storage.Client()
bucket = storage_client.get_bucket('hf-diffusion-images')
# Pinecone connection. The index is queried with text embeddings, and its
# matches carry 'prompt' and 'image_url' metadata (read by prompt_query and
# prompt_image).
PINECONE_KEY = os.environ['PINECONE_KEY']  # api key for pinecone auth
index_id = "hf-diffusion"
pinecone.init(api_key=PINECONE_KEY, environment="us-west1-gcp")

# Fail fast at startup instead of on the first query.
if index_id not in pinecone.list_indexes():
    raise ValueError(f"Index '{index_id}' not found")
index = pinecone.Index(index_id)
# Use the GPU when torch can see one; otherwise stay on CPU, which is the
# previous hard-coded behaviour.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the Stable Diffusion pipeline and move it to `device`. Only
# pipe.tokenizer and pipe.text_encoder are used in this file (to embed query
# text), but the whole pipeline is loaded and moved.
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", use_auth_token=os.environ['HF_AUTH']
)
pipe.to(device)
def encode_text(text: str):
    """Embed *text* with the pipeline's CLIP text encoder.

    Returns the encoder's ``pooler_output`` for the single input as a flat
    Python list of floats. This is the vector passed to ``index.query``.
    """
    # Truncate to the tokenizer's maximum length. The text encoder has a
    # fixed number of position embeddings, so an over-long prompt would
    # otherwise make it raise.
    text_inputs = pipe.tokenizer(
        text, truncation=True, return_tensors='pt'
    ).to(device)
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        text_embeds = pipe.text_encoder(**text_inputs)
    # Batch of one -> take the first (only) row.
    return text_embeds.pooler_output.cpu().tolist()[0]
def prompt_query(text: str):
    """Suggest up to five distinct stored prompts similar to *text*.

    The 30 nearest matches are fetched, and their 'prompt' metadata is
    deduplicated in rank order. The result is a list of one-element rows,
    the same shape as the ``suggestions`` Dataset samples.
    """
    hits = index.query(encode_text(text), top_k=30, include_metadata=True)
    seen = set()
    unique_prompts = []
    for hit in hits['matches']:
        candidate = hit['metadata']['prompt']
        if candidate not in seen:
            seen.add(candidate)
            unique_prompts.append(candidate)
    return [[p] for p in unique_prompts[:5]]
def get_image(url: str):
    """Download one object from ``bucket`` and open it as a PIL image.

    *url* is passed to ``bucket.blob`` as the object name within the bucket.
    """
    raw = bucket.blob(url).download_as_string()
    return Image.open(io.BytesIO(raw))
def prompt_image(text: str):
    """Return PIL images for the stored entries closest to *text*.

    Queries the Pinecone index for the 9 nearest matches and downloads each
    match's 'image_url' object from the Cloud Storage bucket. Images that
    fail to download or decode are skipped and logged, so fewer than 9 images
    may be returned.
    """
    xc = index.query(encode_text(text), top_k=9, include_metadata=True)
    image_urls = [
        match['metadata']['image_url'] for match in xc['matches']
    ]
    images = []
    for image_url in image_urls:
        try:
            im = get_image(image_url)
            # Image.open is lazy. Force the decode here so a corrupt or
            # truncated file fails inside this try, not after we return.
            im.load()
        except Exception as err:
            # Best effort per image. Missing blobs and API errors from GCS,
            # and undecodable data (PIL raises OSError subclasses such as
            # UnidentifiedImageError), should drop only this image rather
            # than fail the whole gallery.
            print(f"error for '{image_url}': {err}")
            continue
        images.append(im)
    return images
# __APP FUNCTIONS__
def set_suggestion(text: list):
    """Copy a clicked suggestion into the prompt box.

    *text* is what Gradio passes for a click on the ``suggestions`` Dataset.
    This is presumably the clicked sample's row of values (the Dataset's
    default ``type``). Rows here are one-element lists such as
    ``["Something"]``, so ``text[0]`` is the prompt string.

    Returns an update that sets the TextArea's value. The click listener
    routes it to ``suggestions.components``, i.e. the prompt box.
    """
    return gr.TextArea.update(value=text[0])
def set_images(text: str):
    """Search-button handler: fill the gallery with images matching *text*."""
    return gr.Gallery.update(value=prompt_image(text))
# __CREATE APP__
# Gradio Blocks UI: a query column (prompt, search button, suggestions) next to
# a results gallery. Blocks lays components out in the order and nesting in
# which they are created inside the `with` contexts, so statement order here is
# the layout.
# NOTE(review): `gr` is used throughout but is never imported in this file;
# presumably `import gradio as gr` is missing from the imports. Confirm.
demo = gr.Blocks()
with demo:
    gr.Markdown(
        """
# Dream Cacher
"""
    )
    with gr.Row():
        # query column: prompt input, search button, similar-prompt suggestions
        with gr.Column():
            prompt = gr.TextArea(
                value="A dream about a cat",
                placeholder="Enter a prompt to dream about",
                interactive=True
            )
            search = gr.Button(value="Search!")
            # Placeholder samples; prompt_query's results replace them once
            # the prompt changes.
            suggestions = gr.Dataset(
                components=[prompt],
                samples=[
                    ["Something"],
                    ["something else"]
                ]
            )
            # event listener for change in prompt: refresh the suggestions.
            # NOTE(review): .change likely fires on every edit, and each call
            # embeds the text and queries Pinecone.
            prompt.change(prompt_query, prompt, suggestions)
            # event listener for click on suggestion: copy it into the prompt
            suggestions.click(
                set_suggestion,
                suggestions,
                suggestions.components
            )
        # results column
        with gr.Column():
            pics = gr.Gallery()
            pics.style(grid=3)
    # search event listening: fill the gallery with images for the prompt
    search.click(set_images, prompt, pics)
demo.launch()