Ashrafb commited on
Commit
ef1a321
·
verified ·
1 Parent(s): a48eda7

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +15 -7
main.py CHANGED
@@ -19,19 +19,27 @@ client = Client("Ashrafb/moondream_captioning")
19
  # Create the "uploads" directory if it doesn't exist
20
  os.makedirs("uploads", exist_ok=True)
21
 
22
- # Define a function to save uploaded file
 
 
 
23
  async def save_upload_file(upload_file: UploadFile) -> str:
24
- file_path = os.path.join("uploads", upload_file.filename)
25
- with open(file_path, "wb") as buffer:
 
 
 
 
26
  buffer.write(await upload_file.read())
27
- return file_path
 
28
 
29
  @app.post("/get_caption")
30
  async def get_caption(image: UploadFile = File(...), context: str = None):
31
  # Save the uploaded image to a temporary file
32
- file_path = await save_upload_file(image)
33
- # Pass the file path to the client for prediction
34
- result = client.predict(file_path, context, api_name="/get_caption")
35
  return {"caption": result}
36
 
37
  app.mount("/", StaticFiles(directory="static", html=True), name="static")
 
19
  # Create the "uploads" directory if it doesn't exist
20
  os.makedirs("uploads", exist_ok=True)
21
 
22
+ import os
23
+ import tempfile
24
+
25
+ # Define a function to save uploaded file to a temporary file
26
  async def save_upload_file(upload_file: UploadFile) -> str:
27
+ # Create a temporary directory if it doesn't exist
28
+ os.makedirs("temp_uploads", exist_ok=True)
29
+ # Create a temporary file path
30
+ temp_file_path = os.path.join("temp_uploads", tempfile.NamedTemporaryFile().name)
31
+ # Save the uploaded file to the temporary file
32
+ with open(temp_file_path, "wb") as buffer:
33
  buffer.write(await upload_file.read())
34
+ return temp_file_path
35
+
36
 
37
  @app.post("/get_caption")
38
  async def get_caption(image: UploadFile = File(...), context: str = None):
39
  # Save the uploaded image to a temporary file
40
+ temp_file_path = await save_upload_file(image)
41
+ # Pass the temporary file path to the client for prediction
42
+ result = client.predict(temp_file_path, context, api_name="/get_caption")
43
  return {"caption": result}
44
 
45
  app.mount("/", StaticFiles(directory="static", html=True), name="static")