Spaces:
Runtime error
Runtime error
import io
import json
import os

import gradio as gr
import pinecone
import torch
from cryptography.fernet import Fernet
from diffusers import StableDiffusionPipeline
from google.cloud import storage
from PIL import Image
# Decrypt the Cloud Storage service-account credentials stored in
# 'cloud-storage.encrypted' using the Fernet key from $DECRYPTION_KEY.
fernet = Fernet(os.environ['DECRYPTION_KEY'])
with open('cloud-storage.encrypted', 'rb') as fp:
    encrypted = fp.read()
creds = json.loads(fernet.decrypt(encrypted).decode())
# Save the decrypted credentials to disk so the Google client library can be
# pointed at this file through GOOGLE_APPLICATION_CREDENTIALS.
# NOTE(review): this leaves a plaintext secret in the working directory;
# storage.Client.from_service_account_info(creds) would avoid that.
with open('cloud-storage.json', 'w', encoding='utf-8') as fp:
    json.dump(creds, fp, indent=4)
# (Removed: a second write of an undefined `G_API` name, which raised
# NameError at import time and would otherwise have clobbered the creds file.)
# Connect to Cloud Storage. The Google client picks up its service-account
# credentials from the GOOGLE_APPLICATION_CREDENTIALS environment variable,
# so point it at the decrypted 'cloud-storage.json' before building the client.
os.environ.update({'GOOGLE_APPLICATION_CREDENTIALS': 'cloud-storage.json'})
storage_client = storage.Client()
# Bucket holding the images referenced by the Pinecone 'image_url' metadata.
bucket = storage_client.get_bucket('hf-diffusion-images')
# Pinecone vector index. It is queried with embeddings from encode_text and
# its matches carry 'prompt' and 'image_url' metadata.
index_id = "hf-diffusion"
PINECONE_KEY = os.environ['PINECONE_KEY']  # API key for Pinecone auth
pinecone.init(api_key=PINECONE_KEY, environment="us-west1-gcp")
# Fail fast at startup instead of on the first search request.
if index_id not in pinecone.list_indexes():
    raise ValueError(f"Index '{index_id}' not found")
index = pinecone.Index(index_id)
# Use the GPU when one is available and fall back to the CPU otherwise
# (previously hard-coded to 'cpu' despite the comment promising a GPU).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Load the Stable Diffusion v1.4 pipeline and move all of its models to the
# chosen device. In this file only its tokenizer and text encoder are used,
# to embed search prompts.
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", use_auth_token=True
)
pipe.to(device)
def encode_text(text: str):
    """Embed *text* with the pipeline's text encoder.

    The prompt is tokenized with ``pipe.tokenizer`` and truncated to the
    tokenizer's ``model_max_length``. The text encoder cannot accept longer
    sequences, and this function runs on every edit of the prompt box, so an
    over-long prompt would otherwise raise.

    Args:
        text: Prompt to embed.

    Returns:
        list[float]: The encoder's pooled output for *text* as a plain Python
        list, suitable as a Pinecone query vector.
    """
    text_inputs = pipe.tokenizer(
        text,
        truncation=True,
        max_length=pipe.tokenizer.model_max_length,
        return_tensors='pt',
    ).to(device)
    # Inference only: don't build an autograd graph on every keystroke.
    with torch.no_grad():
        text_embeds = pipe.text_encoder(**text_inputs)
    # One input string was passed, so take the first (only) pooled row.
    return text_embeds.pooler_output[0].cpu().tolist()
def prompt_query(text: str):
    """Suggest up to five distinct stored prompts similar to *text*.

    Embeds *text*, fetches the 30 nearest matches from the Pinecone index and
    keeps the first five unique ``prompt`` metadata values in match order.
    Each suggestion is wrapped in a one-element list, the row format of the
    suggestions ``gr.Dataset``.
    """
    response = index.query(encode_text(text), top_k=30, include_metadata=True)
    seen = set()
    unique_prompts = []
    for match in response['matches']:
        candidate = match['metadata']['prompt']
        # Order-preserving de-duplication.
        if candidate not in seen:
            seen.add(candidate)
            unique_prompts.append(candidate)
    return [[candidate] for candidate in unique_prompts[:5]]
def get_image(url: str):
    """Download the object *url* from the Cloud Storage bucket as an image.

    Args:
        url: Object name (blob path) inside ``bucket``, e.g. an ``image_url``
            value from Pinecone match metadata.

    Returns:
        PIL.Image.Image: The opened image. ``Image.open`` is lazy, so decoding
        errors may only surface when the pixel data is first accessed.
    """
    # download_as_bytes() replaces the deprecated download_as_string().
    data = bucket.blob(url).download_as_bytes()
    return Image.open(io.BytesIO(data))
def prompt_image(text: str):
    """Return the images attached to the index matches nearest to *text*.

    Embeds *text*, fetches the 9 nearest matches from the Pinecone index and
    downloads each match's ``image_url`` object from Cloud Storage. Images
    that cannot be fetched or decoded are logged and skipped, so one bad
    entry does not break the whole gallery.

    Args:
        text: Search prompt.

    Returns:
        list[PIL.Image.Image]: The successfully loaded images, in match order.
    """
    embeds = encode_text(text)
    xc = index.query(embeds, top_k=9, include_metadata=True)
    image_urls = [match['metadata']['image_url'] for match in xc['matches']]
    images = []
    for image_url in image_urls:
        try:
            im = get_image(image_url)
            # Image.open is lazy; decode now so corrupt data fails inside this
            # try rather than later while the gallery is being rendered.
            im.load()
        except Exception as err:
            # Best-effort per image: a missing blob (GCS NotFound) or
            # undecodable bytes (PIL UnidentifiedImageError, an OSError) are
            # not ValueErrors, so the previous handler let them crash the UI.
            print(f"error for '{image_url}': {err}")
            continue
        images.append(im)
    return images
# __APP FUNCTIONS__
def set_suggestion(text: list):
    """Copy a clicked suggestion into the prompt box.

    Args:
        text: The clicked ``gr.Dataset`` sample, i.e. the list of component
            values for that row (not a ``str``, despite the old annotation).
            The suggestions dataset has a single component, the prompt, so
            the prompt string is ``text[0]``.

    Returns:
        A ``gr.TextArea`` update whose value is the selected prompt.
    """
    return gr.TextArea.update(value=text[0])
def set_images(text: str):
    """Search for *text* and refill the results gallery with the matches."""
    gallery_items = prompt_image(text)
    return gr.Gallery.update(value=gallery_items)
# __CREATE APP__
# Build the Gradio UI. The left column holds the prompt box, a search button
# and a list of similar-prompt suggestions; the right column holds a gallery
# of matching images. Components render in creation order, so statement order
# here defines the layout.
demo = gr.Blocks()
with demo:
    gr.Markdown(
        """
# Dream Cacher
"""
    )
    with gr.Row():
        with gr.Column():
            prompt = gr.TextArea(
                value="A dream about a cat",
                placeholder="Enter a prompt to dream about",
                interactive=True
            )
            search = gr.Button(value="Search!")
            # Placeholder samples shown until the first suggestion query runs.
            suggestions = gr.Dataset(
                components=[prompt],
                samples=[
                    ["Something"],
                    ["something else"]
                ]
            )
            # event listener for change in prompt
            # Each change re-runs prompt_query (text embedding + Pinecone
            # query) and sends its rows to the suggestions dataset.
            # NOTE(review): prompt_query returns raw sample rows; confirm this
            # gradio version accepts that as a Dataset output (some versions
            # expect gr.Dataset.update(samples=...)).
            prompt.change(prompt_query, prompt, suggestions)
            # event listener for click on suggestion
            # suggestions.components is [prompt], so set_suggestion's update
            # is applied to the prompt box.
            suggestions.click(
                set_suggestion,
                suggestions,
                suggestions.components
            )
        # results column
        with gr.Column():
            pics = gr.Gallery()
            pics.style(grid=3)
    # search event listening
    # Clicking "Search!" sends the prompt text to set_images, which fills the
    # gallery.
    search.click(set_images, prompt, pics)
# Start the app server.
demo.launch()