Spaces:
Build error
Build error
File size: 1,758 Bytes
8ce5e2d b46e9dc e65b549 9f824d9 8ce5e2d e65b549 72a955a e65b549 8ce5e2d 9b4f999 e65b549 5cb75cc e65b549 8ce5e2d b46e9dc 8ce5e2d b46e9dc 8ce5e2d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
import os
import sys
import base64
from io import BytesIO
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import torch
from torch import nn
from fastapi import FastAPI
import numpy as np
from PIL import Image
from dalle.models import Dalle
import logging
import streamlit as st
# --- Module-level setup: everything below runs once, at import time. ---
print("Loading models...")
# FastAPI application object; the endpoints in this module register on it via
# @app.get(...).
app = FastAPI()
from huggingface_hub import hf_hub_download
# NOTE(review): no logging configuration is visible in this file; unless it is
# configured elsewhere, the stdlib root logger's default WARNING level means
# these INFO messages are not emitted.
logging.info("Start downloading")
# Fetch the checkpoint file "full_dict_new.ckpt" from the Hugging Face Hub repo
# "ml6team/logo-generator"; the returned local path is fed to torch.load below.
# The access token is read from Streamlit secrets under the key "model_download".
# NOTE(review): recent huggingface_hub releases prefer `token=` over the older
# `use_auth_token=` keyword — confirm against the installed version.
full_dict_path = hf_hub_download(repo_id="ml6team/logo-generator", filename="full_dict_new.ckpt",
use_auth_token=st.secrets["model_download"])
logging.info("End downloading")
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Build the minDALL-E 1.3B model, then overwrite its weights with the
# downloaded state dict. map_location puts the loaded tensors on the CPU first
# (so loading does not require CUDA); the model is moved to `device` afterwards.
# torch.load unpickles the file — only point it at trusted checkpoints.
model = Dalle.from_pretrained("minDALL-E/1.3B")
model.load_state_dict(torch.load(full_dict_path, map_location=torch.device('cpu')))
model.to(device=device)
print("Models loaded !")
@app.get("/")
def read_root():
    """Return a fixed greeting at the API root.

    The response body is the JSON array ``["minDALL-E!"]``.
    """
    # The original returned the set literal {"minDALL-E!"}, which is not a
    # JSON type; FastAPI's encoder converted it into this same one-element
    # list, so returning the list directly keeps the HTTP response identical.
    return ["minDALL-E!"]
# A literal path rather than the template "/{generate}": that placeholder was
# never bound to a function parameter, so it matched every single-segment URL
# (e.g. /favicon.ico) and silently ignored the segment. GET /generate?prompt=...
# keeps working exactly as before.
@app.get("/generate")
def generate(prompt: str):
    """Sample images for *prompt* and return them base64-encoded.

    Args:
        prompt: Text prompt, taken from the ``prompt`` query parameter.

    Returns:
        ``{"images": [...]}`` with one base64 string (from ``to_base64``) per
        image produced by ``sample``.
    """
    images = sample(prompt)
    # to_base64 returns bytes; decode explicitly (base64 output is pure ASCII)
    # instead of relying on FastAPI's implicit bytes-to-str conversion.
    encoded = [to_base64(image).decode("ascii") for image in images]
    return {"images": encoded}
def sample(prompt, *, top_k=96, top_p=None, softmax_temperature=1.0, num_candidates=9):
    """Generate candidate images for *prompt* with the module-level model.

    All keyword arguments are forwarded unchanged to ``model.sampling``; their
    defaults are the values this function previously hard-coded, so existing
    ``sample(prompt)`` calls behave as before.

    Args:
        prompt: Text prompt passed to ``model.sampling``.
        top_k: ``top_k`` argument for ``model.sampling``.
        top_p: ``top_p`` argument for ``model.sampling``.
        softmax_temperature: ``softmax_temperature`` argument for
            ``model.sampling``.
        num_candidates: ``num_candidates`` argument for ``model.sampling``.

    Returns:
        A list of ``PIL.Image.Image``, one per entry along the first axis of
        the sampled array.
    """
    logging.info("starting sampling")
    images = (
        model.sampling(
            prompt=prompt,
            top_k=top_k,
            top_p=top_p,
            softmax_temperature=softmax_temperature,
            num_candidates=num_candidates,
            device=device,
        )
        .cpu()
        .numpy()
    )
    logging.info("sampling succeeded")
    # Move axis 1 to the end; presumably (N, C, H, W) -> (N, H, W, C), the
    # channels-last layout Image.fromarray expects — confirm against Dalle.sampling.
    images = np.transpose(images, (0, 2, 3, 1))
    # Assumes pixel values in [0, 1]. Clip after scaling so any value slightly
    # outside that range saturates at 0/255 instead of wrapping in the uint8 cast.
    pixels = np.clip(images * 255, 0, 255).astype(np.uint8)
    return [Image.fromarray(pixel_array) for pixel_array in pixels]
def to_base64(pil_image, *, image_format="JPEG"):
    """Serialize an image and return its base64 encoding.

    Args:
        pil_image: Object with a PIL-style ``save(fp, format=...)`` method,
            normally a ``PIL.Image.Image``.
        image_format: Format name passed to ``save``; defaults to ``"JPEG"``,
            the format this function always used.

    Returns:
        The base64-encoded serialization as ``bytes``.
    """
    with BytesIO() as buffer:
        pil_image.save(buffer, format=image_format)
        # Read the contents before the buffer is closed on leaving the block.
        return base64.b64encode(buffer.getvalue())