import gradio as gr
from diffusers import StableDiffusionPipeline
import torch
import io
from PIL import Image
import os
from cryptography.fernet import Fernet
from google.cloud import storage
import pinecone
import json
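
# Gradio app: embed a text prompt with Stable Diffusion's CLIP text encoder,
# query a Pinecone index of prompt embeddings, and show similar prompts and
# their images from a Google Cloud Storage bucket. Assumes the DECRYPTION_KEY,
# PINECONE_KEY and HF_AUTH environment variables are set and that the
# encrypted service-account file 'cloud-storage.encrypted' sits next to this script.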

# decrypt the Cloud Storage service-account credentials
fernet = Fernet(os.environ['DECRYPTION_KEY'])

with open('cloud-storage.encrypted', 'rb') as fp:
    encrypted = fp.read()
    creds = json.loads(fernet.decrypt(encrypted).decode())
# then save creds to file
with open('cloud-storage.json', 'w', encoding='utf-8') as fp:
    fp.write(json.dumps(creds, indent=4))
# connect to Cloud Storage
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'cloud-storage.json'
storage_client = storage.Client()
bucket = storage_client.get_bucket('hf-diffusion-images')
    
# get the API key for Pinecone auth
PINECONE_KEY = os.environ['PINECONE_KEY']

index_id = "hf-diffusion"

# init connection to pinecone
pinecone.init(
    api_key=PINECONE_KEY,
    environment="us-west1-gcp"
)
if index_id not in pinecone.list_indexes():
    raise ValueError(f"Index '{index_id}' not found")

index = pinecone.Index(index_id)

# run everything on CPU here; change to 'cuda' to use a GPU
device = 'cpu'

# init the Stable Diffusion pipeline (only its CLIP tokenizer and text
# encoder are used below) and move it to the chosen device
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", use_auth_token=os.environ['HF_AUTH']
)
pipe.to(device)

def encode_text(text: str):
    # tokenize the prompt and embed it with the pipeline's CLIP text encoder
    text_inputs = pipe.tokenizer(
        text, return_tensors='pt'
    ).to(device)
    with torch.no_grad():
        text_embeds = pipe.text_encoder(**text_inputs)
    # return the pooled output as a flat list of floats
    text_embeds = text_embeds.pooler_output.cpu().tolist()[0]
    return text_embeds
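
# example (assuming the Pinecone index was built with this same CLIP
# ViT-L/14 text encoder, so vectors are 768-dimensional):
#   encode_text("a dream about a cat")  # -> list of 768 floats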

def prompt_query(text: str):
    # embed the query and fetch the 30 most similar prompts from Pinecone
    embeds = encode_text(text)
    xc = index.query(embeds, top_k=30, include_metadata=True)
    prompts = [
        match['metadata']['prompt'] for match in xc['matches']
    ]
    # deduplicate while preserving order
    prompts = list(dict.fromkeys(prompts))
    # return the top five as single-item lists, the sample format gr.Dataset expects
    return [[x] for x in prompts[:5]]

def get_image(url: str):
    # download an image blob from the bucket and open it with PIL
    blob = bucket.blob(url).download_as_string()
    blob_bytes = io.BytesIO(blob)
    im = Image.open(blob_bytes)
    return im

def prompt_image(text: str):
    # embed the query and fetch the nine most similar records from Pinecone
    embeds = encode_text(text)
    xc = index.query(embeds, top_k=9, include_metadata=True)
    image_urls = [
        match['metadata']['image_url'] for match in xc['matches']
    ]
    # download each image, skipping any that fail to load
    images = []
    for image_url in image_urls:
        try:
            images.append(get_image(image_url))
        except Exception as e:
            print(f"error for '{image_url}': {e}")
    return images

# __APP FUNCTIONS__

def set_suggestion(text: str):
    # copy the clicked suggestion into the prompt box
    return gr.TextArea.update(value=text[0])

def set_images(text: str):
    # fill the gallery with images matching the prompt
    images = prompt_image(text)
    return gr.Gallery.update(value=images)

# __CREATE APP__
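# layout: a left column with the prompt box, search button and prompt
# suggestions, and a right column with a three-wide image gallery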
demo = gr.Blocks()

with demo:
    gr.Markdown(
        """
        # Dream Cacher
        """
    )
    with gr.Row():
        with gr.Column():
            prompt = gr.TextArea(
                value="A dream about a cat",
                placeholder="Enter a prompt to dream about",
                interactive=True
            )
            search = gr.Button(value="Search!")
            suggestions = gr.Dataset(
                components=[prompt],
                samples=[
                    ["Something"],
                    ["something else"]
                ]
            )
            # event listener for change in prompt
            prompt.change(prompt_query, prompt, suggestions)
            # event listener for click on suggestion
            suggestions.click(
                set_suggestion,
                suggestions,
                suggestions.components
            )
            
        
        # results column
        with gr.Column():
            pics = gr.Gallery()
            pics.style(grid=3)
            # event listener for the search button click
            search.click(set_images, prompt, pics)

demo.launch()