Spaces:
Sleeping
Sleeping
Commit
·
61c0885
1
Parent(s):
1b92f63
fix: inference time
Browse files- src/api/batch_api.py +24 -16
src/api/batch_api.py
CHANGED
@@ -273,11 +273,12 @@ async def rt_cto_nto(
|
|
273 |
)
|
274 |
|
275 |
async def generate():
|
276 |
-
|
277 |
|
278 |
try:
|
279 |
mask, _, _ = await pipeline.shoulderPointMaskGeneration_(image=source_image)
|
280 |
-
|
|
|
281 |
except Exception as e:
|
282 |
logger.error(f">>> MASK GENERATION ERROR: {str(e)} <<<")
|
283 |
yield json.dumps({"error": "Error generating mask", "code": 500}) + "\n"
|
@@ -299,8 +300,10 @@ async def rt_cto_nto(
|
|
299 |
return
|
300 |
|
301 |
for idx, clothing_type in enumerate(clothing_list):
|
|
|
302 |
try:
|
303 |
# Perform CTO
|
|
|
304 |
cto_output = replicate_run_cto({
|
305 |
"mask": mask_data_uri,
|
306 |
"image": image_data_uri,
|
@@ -308,18 +311,17 @@ async def rt_cto_nto(
|
|
308 |
"negative_prompt": "necklaces, jewellery, jewelry, necklace, neckpiece, garland, chain, neck wear, jewelled neck, jeweled neck, necklace on neck, jewellery on neck, accessories, watermark, text, changed background, wider body, narrower body, bad proportions, extra limbs, mutated hands, changed sizes, altered proportions, unnatural body proportions, blury, ugly",
|
309 |
"num_inference_steps": 25
|
310 |
})
|
311 |
-
|
312 |
-
|
313 |
-
raise ValueError("Invalid output from clothing try-on")
|
314 |
|
315 |
-
# Get CTO result
|
|
|
316 |
async with aiohttp.ClientSession() as session:
|
317 |
async with session.get(str(cto_output[0])) as response:
|
318 |
if response.status != 200:
|
319 |
raise ValueError("Failed to fetch CTO output")
|
320 |
cto_result_bytes = await response.read()
|
321 |
|
322 |
-
# Process NTO
|
323 |
with BytesIO(cto_result_bytes) as buf:
|
324 |
cto_result_image = Image.open(buf).convert("RGB")
|
325 |
result, headerText, mask = await pipeline.necklaceTryOn_(
|
@@ -327,27 +329,33 @@ async def rt_cto_nto(
|
|
327 |
jewellery=jewellery,
|
328 |
storename=storename
|
329 |
)
|
330 |
-
|
331 |
-
|
332 |
-
raise ValueError("Failed to process necklace try-on")
|
333 |
|
334 |
# Upload result
|
335 |
-
|
336 |
-
|
337 |
-
|
|
|
|
|
|
|
338 |
|
339 |
-
# Stream result
|
340 |
output_result = {
|
341 |
"code": 200,
|
342 |
"output": result_url,
|
343 |
-
"
|
|
|
|
|
|
|
|
|
|
|
344 |
"clothing_type": clothing_type,
|
345 |
"progress": f"{idx + 1}/{len(clothing_list)}"
|
346 |
}
|
347 |
yield json.dumps(output_result) + "\n"
|
348 |
await asyncio.sleep(0.1)
|
349 |
|
350 |
-
# Clean up
|
351 |
del result
|
352 |
gc.collect()
|
353 |
|
|
|
273 |
)
|
274 |
|
275 |
async def generate():
|
276 |
+
setup_time = time.time() # Track setup time separately
|
277 |
|
278 |
try:
|
279 |
mask, _, _ = await pipeline.shoulderPointMaskGeneration_(image=source_image)
|
280 |
+
mask_generation_time = round(time.time() - setup_time, 2)
|
281 |
+
logger.info(f">>> MASK GENERATION COMPLETED in {mask_generation_time}s <<<")
|
282 |
except Exception as e:
|
283 |
logger.error(f">>> MASK GENERATION ERROR: {str(e)} <<<")
|
284 |
yield json.dumps({"error": "Error generating mask", "code": 500}) + "\n"
|
|
|
300 |
return
|
301 |
|
302 |
for idx, clothing_type in enumerate(clothing_list):
|
303 |
+
iteration_start_time = time.time()
|
304 |
try:
|
305 |
# Perform CTO
|
306 |
+
cto_start_time = time.time()
|
307 |
cto_output = replicate_run_cto({
|
308 |
"mask": mask_data_uri,
|
309 |
"image": image_data_uri,
|
|
|
311 |
"negative_prompt": "necklaces, jewellery, jewelry, necklace, neckpiece, garland, chain, neck wear, jewelled neck, jeweled neck, necklace on neck, jewellery on neck, accessories, watermark, text, changed background, wider body, narrower body, bad proportions, extra limbs, mutated hands, changed sizes, altered proportions, unnatural body proportions, blury, ugly",
|
312 |
"num_inference_steps": 25
|
313 |
})
|
314 |
+
cto_time = round(time.time() - cto_start_time, 2)
|
315 |
+
logger.info(f">>> CTO COMPLETED for {clothing_type} in {cto_time}s <<<")
|
|
|
316 |
|
317 |
+
# Get CTO result and process NTO
|
318 |
+
nto_start_time = time.time()
|
319 |
async with aiohttp.ClientSession() as session:
|
320 |
async with session.get(str(cto_output[0])) as response:
|
321 |
if response.status != 200:
|
322 |
raise ValueError("Failed to fetch CTO output")
|
323 |
cto_result_bytes = await response.read()
|
324 |
|
|
|
325 |
with BytesIO(cto_result_bytes) as buf:
|
326 |
cto_result_image = Image.open(buf).convert("RGB")
|
327 |
result, headerText, mask = await pipeline.necklaceTryOn_(
|
|
|
329 |
jewellery=jewellery,
|
330 |
storename=storename
|
331 |
)
|
332 |
+
nto_time = round(time.time() - nto_start_time, 2)
|
333 |
+
logger.info(f">>> NTO COMPLETED for {clothing_type} in {nto_time}s <<<")
|
|
|
334 |
|
335 |
# Upload result
|
336 |
+
upload_start_time = time.time()
|
337 |
+
result_url = await supabase_upload_and_return_url(
|
338 |
+
prefix="clothing_necklace_try_on",
|
339 |
+
image=result
|
340 |
+
)
|
341 |
+
upload_time = round(time.time() - upload_start_time, 2)
|
342 |
|
343 |
+
# Stream result with detailed timing
|
344 |
output_result = {
|
345 |
"code": 200,
|
346 |
"output": result_url,
|
347 |
+
"timing": {
|
348 |
+
"cto_inference": cto_time,
|
349 |
+
"nto_inference": nto_time,
|
350 |
+
"upload": upload_time,
|
351 |
+
"total_iteration": round(time.time() - iteration_start_time, 2)
|
352 |
+
},
|
353 |
"clothing_type": clothing_type,
|
354 |
"progress": f"{idx + 1}/{len(clothing_list)}"
|
355 |
}
|
356 |
yield json.dumps(output_result) + "\n"
|
357 |
await asyncio.sleep(0.1)
|
358 |
|
|
|
359 |
del result
|
360 |
gc.collect()
|
361 |
|