Chaerin5 commited on
Commit
55fa8f8
·
1 Parent(s): a366fc8

enable zerogpu

Browse files
Files changed (1) hide show
  1. app.py +33 -10
app.py CHANGED
@@ -256,17 +256,40 @@ hands = mp_hands.Hands(
256
  min_detection_confidence=0.1,
257
  )
258
 
 
 
 
 
 
 
 
 
 
 
 
 
259
  @spaces.GPU(duration=60)
260
- def make_ref_cond(
261
- image
262
- ):
263
- print("ready to run autoencoder")
264
- # print(f"image.device: {image.device}, type(image): {type(image)}")
265
- os.environ["CUDA_VISIBLE_DEVICES"] = "0"
266
- torch.cuda.set_device(0)
267
- image = image.to("cuda")
268
- latent = opts.latent_scaling_factor * autoencoder.encode(image[None, ...]).sample()
269
- return image[None, ...], latent
 
 
 
 
 
 
 
 
 
 
 
270
 
271
  def get_ref_anno(ref):
272
  print("inside get_ref_anno")
 
256
  min_detection_confidence=0.1,
257
  )
258
 
259
+ # @spaces.GPU(duration=60)
260
+ # def make_ref_cond(
261
+ # image
262
+ # ):
263
+ # print("ready to run autoencoder")
264
+ # # print(f"image.device: {image.device}, type(image): {type(image)}")
265
+ # os.environ["CUDA_VISIBLE_DEVICES"] = "0"
266
+ # torch.cuda.set_device(0)
267
+ # image = image.to("cuda")
268
+ # latent = opts.latent_scaling_factor * autoencoder.encode(image[None, ...]).sample()
269
+ # return image[None, ...], latent
270
+
271
  @spaces.GPU(duration=60)
272
+ def make_ref_cond(image):
273
+ def initialize_and_process(image):
274
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
275
+ torch.cuda.set_device(0)
276
+ print("Initializing autoencoder in worker process")
277
+ image = image.to("cuda")
278
+ latent = opts.latent_scaling_factor * autoencoder.encode(image[None, ...]).sample()
279
+ return image[None, ...], latent
280
+
281
+ from multiprocessing import Process, Queue
282
+ queue = Queue()
283
+
284
+ def worker(image, queue):
285
+ result = initialize_and_process(image)
286
+ queue.put(result)
287
+
288
+ process = Process(target=worker, args=(image, queue))
289
+ process.start()
290
+ process.join()
291
+
292
+ return queue.get()
293
 
294
  def get_ref_anno(ref):
295
  print("inside get_ref_anno")