File size: 5,615 Bytes
628dd10
99caaea
 
 
 
 
2335e48
99caaea
 
83798fc
99caaea
2335e48
 
 
 
 
 
 
 
 
99caaea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1a67055
99caaea
 
 
4a10f8f
 
99caaea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4a10f8f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99caaea
 
 
 
 
 
4a10f8f
 
 
 
 
 
 
 
 
 
 
99caaea
 
 
 
 
 
4a10f8f
 
99caaea
4a10f8f
99caaea
 
 
 
4a10f8f
 
 
 
99caaea
4a10f8f
 
99caaea
 
 
 
 
 
 
4a10f8f
 
 
 
 
 
 
 
 
 
 
 
 
99caaea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4a10f8f
 
 
 
99caaea
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
import io
import json
import os
import uuid

import gradio as gr
import pinecone
import torch
from cryptography.fernet import Fernet
from diffusers import StableDiffusionPipeline
from google.cloud import storage
from PIL import Image

# Decrypt the Cloud Storage credentials. The Fernet key is read from the
# DECRYPTION_KEY environment variable (KeyError at startup if it is unset).
fernet = Fernet(os.environ['DECRYPTION_KEY'])

with open('cloud-storage.encrypted', 'rb') as fp:
    encrypted = fp.read()
    # the decrypted payload is decoded as text and parsed as JSON
    creds = json.loads(fernet.decrypt(encrypted).decode())
# Write the decrypted credentials back out as JSON so they can be referenced
# by file path below.
# NOTE(review): this leaves the decrypted credentials in plaintext in
# 'cloud-storage.json' in the working directory - confirm this is acceptable.
with open('cloud-storage.json', 'w', encoding='utf-8') as fp:
    fp.write(json.dumps(creds, indent=4))
# Point GOOGLE_APPLICATION_CREDENTIALS at that file before creating the client
# (presumably storage.Client() reads its credentials from this variable).
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'cloud-storage.json'
storage_client = storage.Client()
# bucket used by the upload, download and delete calls in this file
bucket = storage_client.get_bucket('hf-diffusion-images')
    
# Pinecone API key is read from the PINECONE_KEY environment variable
# (KeyError at startup if it is unset).
PINECONE_KEY = os.environ['PINECONE_KEY']

# name of the Pinecone index this app reads from and writes to
index_id = "hf-diffusion"

# init connection to pinecone
pinecone.init(
    api_key=PINECONE_KEY,
    environment="us-west1-gcp"
)
# fail at import time, not on the first query, if the index is missing
if index_id not in pinecone.list_indexes():
    raise ValueError(f"Index '{index_id}' not found")

# handle used for the query / upsert / delete calls in this file
index = pinecone.Index(index_id)

# device the pipeline and the tokenized inputs are moved to
device = 'cpu'

# Load the Stable Diffusion v1.4 pipeline, authenticating with the token in
# the HF_AUTH environment variable, then move it to `device` (currently the
# CPU, not a GPU).
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", use_auth_token=os.environ['HF_AUTH']
)
pipe.to(device)

# placeholder image shown in place of a stored image that fails validation
missing_im = Image.open('missing.png')

def encode_text(text: str):
    """Embed *text* with the pipeline's tokenizer and text encoder.

    The prompt is tokenized (truncated to the tokenizer's maximum length so
    that over-long prompts do not overflow the encoder), passed through
    ``pipe.text_encoder`` without gradient tracking, and the first row of the
    pooled output is returned.

    Args:
        text: Prompt to embed.

    Returns:
        The first row of ``pooler_output`` converted to a plain Python list.
    """
    text_inputs = pipe.tokenizer(
        text,
        truncation=True,
        max_length=pipe.tokenizer.model_max_length,
        return_tensors='pt',
    ).to(device)
    # inference only: no need to record an autograd graph
    with torch.no_grad():
        text_embeds = pipe.text_encoder(**text_inputs)
    return text_embeds.pooler_output.cpu().tolist()[0]

def prompt_query(text: str):
    """Return up to five distinct stored prompts nearest to *text*.

    The text is embedded, the 30 nearest index entries are fetched with their
    metadata, and their ``prompt`` values are de-duplicated in match order.

    Returns:
        A list of single-element lists, one per suggested prompt.
    """
    query_vector = encode_text(text)
    response = index.query(query_vector, top_k=30, include_metadata=True)
    # keep the first occurrence of every prompt, in the order matches arrived
    seen = set()
    unique_prompts = []
    for match in response['matches']:
        candidate = match['metadata']['prompt']
        if candidate in seen:
            continue
        seen.add(candidate)
        unique_prompts.append(candidate)
    suggestions_rows = []
    for candidate in unique_prompts[:5]:
        suggestions_rows.append([candidate])
    return suggestions_rows

def diffuse(text: str):
    """Generate an image for *text* and store it in Cloud Storage and Pinecone.

    The image is written to a scratch PNG named after a fresh UUID, uploaded
    to ``images/<uuid>.png`` in the bucket, and the prompt embedding plus
    metadata (``prompt``, ``image_url``) are upserted into the index under the
    same id.

    Args:
        text: Prompt to run through the pipeline.

    Returns:
        The first generated image, or an empty dict when the pipeline flagged
        the output as NSFW (nothing is stored in that case).
    """
    # diffuse
    out = pipe(text)
    # nsfw_content_detected may be None when no safety check was performed;
    # treat that as "nothing flagged" instead of raising TypeError in any()
    if any(out.nsfw_content_detected or ()):
        return {}

    _id = str(uuid.uuid4())
    im = out.images[0]
    local_path = f'{_id}.png'
    try:
        # add image to Cloud Storage via a scratch file
        im.save(local_path, format='png')
        blob = bucket.blob(f'images/{_id}.png')
        blob.upload_from_filename(local_path)
    finally:
        # always remove the scratch file, even when save or upload failed
        if os.path.exists(local_path):
            os.remove(local_path)

    # add embedding and metadata to Pinecone
    embeds = encode_text(text)
    meta = {
        'prompt': text,
        'image_url': f'images/{_id}.png'
    }
    index.upsert([(_id, embeds, meta)])
    return im

def get_image(url: str):
    """Download the blob named *url* from the bucket and open it as an image."""
    raw = bucket.blob(url).download_as_string()
    return Image.open(io.BytesIO(raw))

def test_image(_id, image):
    """Check that *image* can be encoded; purge it from storage if it cannot.

    The image is encoded as PNG into an in-memory buffer. Encoding to memory
    (rather than to a shared file on disk) means an ``OSError`` here comes
    from the image itself and not from the filesystem, so a full disk or a
    read-only directory cannot cause a valid image to be deleted, and
    concurrent calls do not write to the same path.

    Args:
        _id: Index id of the image; also used to build its blob name
            ``images/<_id>.png``.
        image: Object exposing ``save(fp, format=...)``.

    Returns:
        True if the image encoded successfully. False if encoding raised
        ``OSError``, in which case the entry has been deleted from the index
        and the bucket.
    """
    try:
        image.save(io.BytesIO(), format='png')
    except OSError:
        # delete corrupted file from pinecone and cloud
        index.delete(ids=[_id])
        bucket.blob(f"images/{_id}.png").delete()
        print(f"DELETED '{_id}'")
        return False
    return True

def prompt_image(text: str):
    """Fetch the stored images nearest to *text* together with their scores.

    The text is embedded and the 9 nearest index entries are queried. Each
    entry's image is downloaded with ``get_image`` and validated with
    ``test_image``. An entry whose image cannot be opened or validated is
    represented by ``missing_im`` so that the returned images stay aligned
    one-to-one with the returned scores.

    Args:
        text: Prompt to search for.

    Returns:
        A ``(images, scores)`` tuple of equal-length lists in match order.
    """
    embeds = encode_text(text)
    xc = index.query(embeds, top_k=9, include_metadata=True)
    matches = xc['matches']
    image_urls = [match['metadata']['image_url'] for match in matches]
    scores = [match['score'] for match in matches]
    ids = [match['id'] for match in matches]
    images = []
    for _id, image_url in zip(ids, image_urls):
        try:
            im = get_image(image_url)
            usable = test_image(_id, im)
        except (ValueError, OSError) as err:
            # Image.open raises OSError for data it cannot identify; show the
            # placeholder for this entry instead of aborting the whole search
            print(f"{type(err).__name__}: '{image_url}'")
            usable = False
        images.append(im if usable else missing_im)
    return images, scores

# __APP FUNCTIONS__

def set_suggestion(text: str):
    """Build a TextArea update whose value is the first element of *text*.

    NOTE(review): the parameter is annotated ``str`` but is indexed with
    ``[0]``; presumably it receives the clicked suggestion row - confirm.
    """
    chosen = text[0]
    return gr.TextArea.update(value=chosen)

def set_images(text: str):
    """Return a Gallery update for *text*, generating an image if needed.

    Stored images are looked up first. If none of the returned scores exceeds
    0.85, a new image is generated with ``diffuse`` and the lookup is repeated
    before the gallery update is built.
    """
    images, scores = prompt_image(text)
    if any(score > 0.85 for score in scores):
        print("MATCH FOUND")
    else:
        print("NO MATCH FOUND")
        # nothing close enough is stored: generate, then search again
        diffuse(text)
        images, scores = prompt_image(text)
    return gr.Gallery.update(value=images)

# __CREATE APP__
# Build the Blocks UI: a prompt column (text area, search button, suggestion
# list) next to a results column (image gallery), then launch it.
demo = gr.Blocks()

with demo:
    gr.Markdown(
        """
        # Dream Cacher
        """
    )
    with gr.Row():
        # prompt column
        with gr.Column():
            prompt = gr.TextArea(
                value="A dream about a cat",
                placeholder="Enter a prompt to dream about",
                interactive=True
            )
            search = gr.Button(value="Search!")
            # suggestion list, seeded with two sample rows
            suggestions = gr.Dataset(
                components=[prompt],
                samples=[
                    ["Something"],
                    ["something else"]
                ]
            )
            # event listener for change in prompt: prompt_query's result is
            # sent to the suggestions component
            prompt.change(prompt_query, prompt, suggestions)
            # event listener for click on suggestion: set_suggestion's result
            # is sent to the dataset's components (the prompt text area)
            suggestions.click(
                set_suggestion,
                suggestions,
                suggestions.components
            )
            
        # results column
        with gr.Column():
            pics = gr.Gallery()
            # presumably a 3-column grid layout - confirm against the
            # installed gradio version
            pics.style(grid=3)
            # search event listening: set_images' result is sent to the gallery
            # NOTE(review): this try only wraps registering the click handler;
            # presumably the intent was to catch OSError raised by set_images
            # when the button is clicked - confirm, since that would not be
            # caught here.
            try:
                search.click(set_images, prompt, pics)
            except OSError:
                print("OSError")

demo.launch()