weibo1903 committed on
Commit
0b30541
·
verified ·
1 Parent(s): 3c326f9

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +24 -23
main.py CHANGED
@@ -25,26 +25,27 @@ def read_root():
25
  return {"message": "API is live. Use the /predict endpoint."}
26
 
27
@app.get("/predict")
def predict():
    """Placeholder predict endpoint (pre-implementation stub).

    Returns:
        dict: a simple acknowledgement payload, confirming the endpoint
        received the request.
    """
    # BUG FIX: the original `return {"got query"}` was a *set* literal,
    # which FastAPI serializes as a JSON array (["got query"]). A dict
    # matches the shape of the other endpoints' responses.
    # (The commented-out inference pipeline that used to live here was
    # dead code and has been removed.)
    return {"message": "got query"}
 
 
25
  return {"message": "API is live. Use the /predict endpoint."}
26
 
27
@app.get("/predict")
def predict(
    image_url: str = Query(..., description="URL of the image"),
    prompt: str = Query(..., description="Prompt for the image"),
    max_new_tokens: int = Query(128, ge=1, le=2048, description="Maximum number of tokens to generate"),
):
    """Run the vision-language model on an image plus a text prompt.

    Args:
        image_url: URL of the image to analyze (passed to the chat template
            as an image content item).
        prompt: Text instruction/question about the image.
        max_new_tokens: Generation budget; defaults to 128 so existing
            callers see unchanged behavior.

    Returns:
        dict: ``{"response": <generated text>}`` with the echoed prompt
        tokens trimmed off, so only the model's completion is returned.
    """
    messages = [
        {"role": "system", "content": "You are a helpful assistant with vision abilities."},
        {"role": "user", "content": [{"type": "image", "image": image_url}, {"type": "text", "text": prompt}]},
    ]
    # Render the chat messages into the prompt string the model expects.
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # Extract image/video tensors referenced by the messages (downloads the URL).
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(model.device)
    # Inference only — disable autograd to save memory and time.
    with torch.no_grad():
        generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # Drop the prompt tokens that generate() echoes back, keeping only the
    # newly generated continuation for each sequence in the batch.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_texts = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return {"response": output_texts[0]}