MatthiasC commited on
Commit
b46e9dc
·
1 Parent(s): 0de5eaa

Changes, 1 worker on server and remove clip usage from server

Browse files
Files changed (2) hide show
  1. app.py +1 -1
  2. server.py +7 -4
app.py CHANGED
@@ -11,7 +11,7 @@ import requests
11
 
12
 
13
  def start_server():
14
- os.system("uvicorn server:app --port 8080 --host 0.0.0.0 --workers 2")
15
 
16
 
17
  def load_models():
 
11
 
12
 
13
  def start_server():
14
+ os.system("uvicorn server:app --port 8080 --host 0.0.0.0 --workers 1")
15
 
16
 
17
  def load_models():
server.py CHANGED
@@ -12,6 +12,7 @@ from PIL import Image
12
 
13
  #import clip
14
  from dalle.models import Dalle
 
15
  from dalle.utils.utils import clip_score, download
16
 
17
  print("Loading models...")
@@ -78,18 +79,20 @@ def generate(prompt):
78
 
79
  def sample(prompt):
80
  # Sampling
 
81
  images = (
82
  model.sampling(prompt=prompt, top_k=96, top_p=None, softmax_temperature=1.0, num_candidates=9, device=device)
83
  .cpu()
84
  .numpy()
85
  )
 
86
  images = np.transpose(images, (0, 2, 3, 1))
87
 
88
  # CLIP Re-ranking
89
- rank = clip_score(
90
- prompt=prompt, images=images, model_clip=model_clip, preprocess_clip=preprocess_clip, device=device
91
- )
92
- images = images[rank]
93
 
94
  pil_images = []
95
  for i in range(len(images)):
 
12
 
13
  #import clip
14
  from dalle.models import Dalle
15
+ import logging
16
  from dalle.utils.utils import clip_score, download
17
 
18
  print("Loading models...")
 
79
 
80
  def sample(prompt):
81
  # Sampling
82
+ logging.info("starting sampling")
83
  images = (
84
  model.sampling(prompt=prompt, top_k=96, top_p=None, softmax_temperature=1.0, num_candidates=9, device=device)
85
  .cpu()
86
  .numpy()
87
  )
88
+ logging.info("sampling succeeded")
89
  images = np.transpose(images, (0, 2, 3, 1))
90
 
91
  # CLIP Re-ranking
92
+ # rank = clip_score(
93
+ # prompt=prompt, images=images, model_clip=model_clip, preprocess_clip=preprocess_clip, device=device
94
+ # )
95
+ # images = images[rank]
96
 
97
  pil_images = []
98
  for i in range(len(images)):