"""Command-line client for a Triton Inference Server model.

Sends an image (local path or http/https URL) plus a text prompt to the
``spacellava`` model (by default) and prints the elapsed time and the
first element of the ``RESULTS`` output.
"""

import argparse
import base64
import os
import time
from urllib.parse import urlparse

import numpy as np
import requests
from tritonclient.http import InferenceServerClient, InferInput, InferRequestedOutput

# Seconds to wait for the remote image host; without a timeout
# ``requests.get`` can block indefinitely on an unresponsive server.
DOWNLOAD_TIMEOUT_S = 30

# Used when the URL path has no usable basename (e.g. "https://host/").
_FALLBACK_IMAGE_NAME = "downloaded_image"


def download_image(image_url, timeout=DOWNLOAD_TIMEOUT_S):
    """Download *image_url* into the current directory and return the filename.

    The local file is named after the last component of the URL path, or a
    fallback name when that component is empty / ``.`` / ``..``. An existing
    file with the same name is overwritten.

    Args:
        image_url: http(s) URL of the image.
        timeout: seconds to wait for the server before giving up.

    Returns:
        The local filename the image was written to.

    Raises:
        RuntimeError: if the server does not answer with HTTP 200.
        requests.RequestException: on network errors or timeout.
    """
    parsed_url = urlparse(image_url)
    filename = os.path.basename(parsed_url.path)
    if filename in ("", ".", ".."):
        filename = _FALLBACK_IMAGE_NAME

    response = requests.get(image_url, timeout=timeout)
    if response.status_code != 200:
        raise RuntimeError(
            f"Failed to download image (HTTP {response.status_code}): {image_url}"
        )

    with open(filename, "wb") as img_file:
        img_file.write(response.content)
    return filename


def image_to_base64_data_uri(image_input):
    """Return the contents of the file *image_input* as a base64 string.

    Note: despite the name, this returns plain base64 text with no
    ``data:<mime>;base64,`` prefix; the name is kept for compatibility.
    """
    with open(image_input, "rb") as img_file:
        return base64.b64encode(img_file.read()).decode("utf-8")


def setup_argparse():
    """Parse and return the command-line arguments."""
    parser = argparse.ArgumentParser(description="Client for Triton Inference Server")
    parser.add_argument("--image_path", type=str, required=True,
                        help="Path to the image or URL of the image to process")
    parser.add_argument("--prompt", type=str, required=True,
                        help="Prompt to be used for the inference")
    # Optional overrides; defaults reproduce the previously hard-coded values.
    parser.add_argument("--url", type=str, default="localhost:8000",
                        help="Triton HTTP endpoint (host:port)")
    parser.add_argument("--model_name", type=str, default="spacellava",
                        help="Name of the model to query")
    parser.add_argument("--model_version", type=str, default="1",
                        help="Version of the model to query")
    return parser.parse_args()


def main():
    """Build the request from CLI arguments, run inference, and print the result."""
    args = setup_argparse()

    if args.image_path.startswith(("http://", "https://")):
        image_path = download_image(args.image_path)
    else:
        image_path = args.image_path

    # Both inputs are BYTES tensors of shape [1]: one-element object arrays
    # holding UTF-8 encoded bytes.
    image_data = image_to_base64_data_uri(image_path).encode("utf-8")
    image_data_np = np.array([image_data], dtype=object)
    prompt_np = np.array([args.prompt.encode("utf-8")], dtype=object)

    images_in = InferInput(name="IMAGES", shape=[1], datatype="BYTES")
    images_in.set_data_from_numpy(image_data_np, binary_data=True)

    prompt_in = InferInput(name="PROMPT", shape=[1], datatype="BYTES")
    prompt_in.set_data_from_numpy(prompt_np, binary_data=True)

    # binary_data=False so the output arrives in the JSON body and can be
    # read directly from get_response().
    results_out = InferRequestedOutput(name="RESULTS", binary_data=False)

    triton_client = InferenceServerClient(url=args.url, verbose=False)
    try:
        # perf_counter is monotonic, unlike time.time(), so it is the right
        # clock for measuring an interval.
        start_time = time.perf_counter()
        response = triton_client.infer(
            model_name=args.model_name,
            model_version=args.model_version,
            inputs=[prompt_in, images_in],
            outputs=[results_out],
        )
        results = response.get_response()["outputs"][0]["data"][0]
        print("--- %s seconds ---" % (time.perf_counter() - start_time))
        print(results)
    finally:
        triton_client.close()


if __name__ == "__main__":
    main()