|
import argparse |
|
import time |
|
import base64 |
|
import numpy as np |
|
import requests |
|
import os |
|
from urllib.parse import urlparse |
|
from tritonclient.http import InferenceServerClient, InferInput, InferRequestedOutput |
|
|
|
def download_image(image_url):
    """Download the image at *image_url* into the current directory.

    Returns the local filename the image was saved under.
    Raises Exception if the server does not respond with HTTP 200.
    """
    parsed_url = urlparse(image_url)
    # URLs ending in "/" have an empty basename; fall back to a fixed name
    # so open() below does not fail on an empty path.
    filename = os.path.basename(parsed_url.path) or "downloaded_image"
    # A timeout prevents the client from hanging forever on a stalled server.
    response = requests.get(image_url, timeout=30)
    if response.status_code == 200:
        with open(filename, 'wb') as img_file:
            img_file.write(response.content)
        return filename
    else:
        raise Exception("Failed to download image (HTTP %s)" % response.status_code)
|
|
|
def image_to_base64_data_uri(image_input):
    """Return the contents of the file at *image_input* as a base64 string.

    Note: despite the name, the return value is bare base64 text, not a
    full ``data:`` URI.
    """
    with open(image_input, "rb") as fh:
        raw = fh.read()
    return base64.b64encode(raw).decode('utf-8')
|
|
|
def setup_argparse():
    """Build the CLI parser and return the parsed command-line arguments."""
    parser = argparse.ArgumentParser(description="Client for Triton Inference Server")
    # Both flags are mandatory string options.
    for flag, help_text in (
        ("--image_path", "Path to the image or URL of the image to process"),
        ("--prompt", "Prompt to be used for the inference"),
    ):
        parser.add_argument(flag, type=str, required=True, help=help_text)
    return parser.parse_args()
|
|
|
if __name__ == "__main__":
    args = setup_argparse()

    # Talk to a Triton server on the default local HTTP endpoint.
    triton_client = InferenceServerClient(url="localhost:8000", verbose=False)

    # Remote URLs are fetched to disk first; local paths are used directly.
    is_remote = args.image_path.startswith(('http://', 'https://'))
    local_image = download_image(args.image_path) if is_remote else args.image_path

    # Triton BYTES inputs expect 1-element object arrays of raw bytes.
    encoded_image = image_to_base64_data_uri(local_image).encode('utf-8')
    images_np = np.array([encoded_image], dtype=object)
    prompt_np = np.array([args.prompt.encode('utf-8')], dtype=object)

    images_in = InferInput(name="IMAGES", shape=[1], datatype="BYTES")
    images_in.set_data_from_numpy(images_np, binary_data=True)
    prompt_in = InferInput(name="PROMPT", shape=[1], datatype="BYTES")
    prompt_in.set_data_from_numpy(prompt_np, binary_data=True)

    # Request the answer as plain JSON rather than binary payload.
    results_out = InferRequestedOutput(name="RESULTS", binary_data=False)

    start_time = time.time()
    response = triton_client.infer(
        model_name="spacellava",
        model_version="1",
        inputs=[prompt_in, images_in],
        outputs=[results_out],
    )

    # First (and only) element of the first output tensor is the text result.
    answer = response.get_response()["outputs"][0]["data"][0]
    print("--- %s seconds ---" % (time.time() - start_time))
    print(answer)
|
|
|
|