Spaces:
Runtime error
Runtime error
File size: 1,430 Bytes
be37091 f35c84d be37091 f35c84d be37091 e8b13db be37091 8439859 be37091 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 |
import os
import random

import jax
import numpy as np
from dalle_mini import DalleBart, DalleBartProcessor
from PIL import Image
from vqgan_jax.modeling_flax_vqgan import VQModel
# --- Checkpoint references -------------------------------------------------
# DALL·E mini checkpoint. The reference can be a wandb artifact, a 🤗 Hub repo,
# a local folder or a Google bucket. The bigger
# "dalle-mini/dalle-mini/mega-1-fp16:latest" checkpoint is too large to use here.
DALLE_MODEL = "dalle-mini/dalle-mini/mini-1:v0"
DALLE_COMMIT_ID = None  # no specific revision requested

# VQGAN checkpoint, pinned to a fixed commit.
VQGAN_REPO = "dalle-mini/vqgan_imagenet_f16_16384"
VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"

# --- Load once at import time ----------------------------------------------
# All three objects are module-level so generate_image() can reuse them
# across calls instead of reloading the checkpoints.
model = DalleBart.from_pretrained(
    DALLE_MODEL,
    revision=DALLE_COMMIT_ID,
)
vqgan = VQModel.from_pretrained(
    VQGAN_REPO,
    revision=VQGAN_COMMIT_ID,
)
processor = DalleBartProcessor.from_pretrained(
    DALLE_MODEL,
    revision=DALLE_COMMIT_ID,
)
def generate_image(text: str):
    """Generate a single image for a text prompt.

    The prompt is tokenized with the module-level ``processor``, image tokens
    are sampled from the module-level DALL·E ``model`` using a freshly drawn
    random seed, and the tokens are decoded to pixels with the module-level
    ``vqgan``.

    Args:
        text: The prompt describing the image to generate.

    Returns:
        A ``PIL.Image.Image`` built from the first decoded sample, converted
        to a 256x256x3 ``uint8`` array.
    """
    # The processor works on a batch of prompts, so wrap the single prompt.
    tokenized_prompt = processor([text])

    # Sampling configuration (None leaves top-k / top-p filtering unset).
    gen_top_k = None
    gen_top_p = None
    temperature = 0.85
    cond_scale = 3.0

    # randint() requires integer bounds; a float bound such as 1e7 is
    # rejected with TypeError on Python 3.12+. The sampler takes a JAX PRNG
    # key rather than a bare integer, so wrap the seed.
    seed = random.randint(0, 10**7)
    prng_key = jax.random.PRNGKey(seed)

    # DalleBart.generate takes the tokenizer output as keyword arguments and
    # the sampling options by name; passing them positionally would bind them
    # to unrelated parameters (attention mask, max length, ...).
    generated = model.generate(
        **tokenized_prompt,
        prng_key=prng_key,
        params=model.params,
        top_k=gen_top_k,
        top_p=gen_top_p,
        temperature=temperature,
        condition_scale=cond_scale,
    )

    # Drop the first position of every sequence before decoding
    # (presumably the start-of-sequence token -- confirm against dalle-mini).
    encoded_images = generated.sequences[..., 1:]

    # Image tokens are turned into pixels by the VQGAN, not by the BART
    # decoder of the DALL·E model.
    decoded_images = vqgan.decode_code(encoded_images, params=vqgan.params)

    # Clamp to [0, 1] and normalise the layout to (batch, 256, 256, 3).
    decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))

    # Only one prompt was submitted, so keep the first sample.
    img = decoded_images[0]
    return Image.fromarray(np.asarray(img * 255, dtype=np.uint8))
|