"""Run a chat-completion model over ChestAgentBench multiple-choice questions.

Each question's images are attached to the request (read from local files
under ``figures/`` or downloaded from their source URLs). One JSON object per
question is written as a separate line to the log file (JSON Lines content,
even though the file name ends in ``.json``).
"""
import base64
import json
import logging
import mimetypes
import os
import time
from datetime import datetime

import openai
import requests
from datasets import load_dataset
from tenacity import RetryError, retry, stop_after_attempt, wait_exponential

# Initialize global variables (overridden from the command line in main())
logger = logging.getLogger('benchmark')
model_name = 'chatgpt-4o-latest'  # default value
temperature = 0.2  # default value
log_filename = None

# Media type used when it cannot be inferred from a file name / URL; this was
# previously hard-coded for every image.
DEFAULT_IMAGE_MIME = "image/jpeg"
# Seconds to wait for an image download before giving up on that image.
IMAGE_DOWNLOAD_TIMEOUT = 30


def setup_logging(filename):
    """Send ``logger`` output to ``filename``, one bare message per line.

    Args:
        filename: Path of the log file (opened in append mode).

    Returns:
        The configured module-level logger.
    """
    logger.setLevel(logging.INFO)
    # Close (not just drop) handlers from a previous call so their file
    # descriptors are released.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()
    handler = logging.FileHandler(filename, encoding='utf-8')
    handler.setFormatter(logging.Formatter('%(message)s'))
    logger.addHandler(handler)
    return logger


def encode_image(image_path):
    """Encode a local image file as a base64 string.

    Returns:
        The base64 text, or None if the file could not be read (the error is
        printed, not raised).
    """
    try:
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode('utf-8')
    except Exception as e:
        print(f"Error encoding image {image_path}: {str(e)}")
        return None


def encode_image_url(image_url, timeout=IMAGE_DOWNLOAD_TIMEOUT):
    """Download an image and encode it as a base64 string.

    Args:
        image_url: URL of the image.
        timeout: Seconds to wait for the server before giving up; without a
            timeout a stalled server would hang the whole benchmark.

    Returns:
        The base64 text, or None on any download/HTTP error (printed, not
        raised).
    """
    try:
        response = requests.get(image_url, timeout=timeout)
        response.raise_for_status()
        return base64.b64encode(response.content).decode('utf-8')
    except Exception as e:
        print(f"Error encoding image from URL {image_url}: {str(e)}")
        return None


def _normalize_image_refs(value):
    """Flatten an image field into a list of non-empty string references.

    Accepts None, a single string, or a list whose items are strings or
    lists of strings (one level of nesting). Empty / non-string entries are
    dropped. Unlike indexing ``value[0]``, this is safe for empty lists and
    for lists that mix strings with sub-lists.
    """
    if not value:
        return []
    if isinstance(value, str):
        return [value]
    refs = []
    for item in value:
        if isinstance(item, (list, tuple)):
            refs.extend(item)
        else:
            refs.append(item)
    return [ref for ref in refs if ref and isinstance(ref, str)]


def _guess_image_mime(ref):
    """Guess an ``image/*`` media type from a path or URL, else the default."""
    mime, _ = mimetypes.guess_type(ref)
    if mime and mime.startswith("image/"):
        return mime
    return DEFAULT_IMAGE_MIME


def _image_part(base64_image, mime_type):
    """Build one chat-completions ``image_url`` content part from base64 data."""
    return {
        "type": "image_url",
        "image_url": {
            "url": f"data:{mime_type};base64,{base64_image}"
        }
    }


def _load_image_parts(example, use_urls):
    """Load and encode every image of ``example`` as chat content parts.

    With ``use_urls`` the ``image_source_urls`` field is downloaded; otherwise
    the ``images`` field is resolved relative to the local ``figures/``
    directory. Images that cannot be loaded are reported and skipped.
    """
    parts = []
    if use_urls:
        for img_url in _normalize_image_refs(example.get('image_source_urls')):
            base64_image = encode_image_url(img_url)
            if base64_image:
                parts.append(_image_part(base64_image, _guess_image_mime(img_url)))
                print(f"Successfully loaded image from URL: {img_url}")
    else:
        for img_path in _normalize_image_refs(example.get('images')):
            # Dataset paths may or may not already carry the "figures/" prefix.
            img_path = img_path.replace('figures/', '')
            full_path = os.path.join("figures", img_path)
            if not os.path.exists(full_path):
                print(f"Image file not found: {full_path}")
                continue
            base64_image = encode_image(full_path)
            if base64_image:
                parts.append(_image_part(base64_image, _guess_image_mime(full_path)))
                print(f"Successfully loaded image: {full_path}")
    return parts


def _base_log_entry(example):
    """Fields shared by every per-question log record."""
    return {
        "question_id": example.get('question_id', 'unknown'),
        "timestamp": datetime.now().isoformat(),
        "model": model_name,
        "temperature": temperature,
    }


def _request_input(example, messages, use_urls):
    """The ``input`` section logged for a request that reached the API."""
    return {
        "messages": messages,
        "question": example['question'],
        "explanation": example.get('explanation', ''),
        "image_source": "url" if use_urls else "local",
        "images": example.get('image_source_urls' if use_urls else 'images')
    }


@retry(wait=wait_exponential(multiplier=1, min=4, max=10), stop=stop_after_attempt(3))
def _chat_completion_with_retry(client, messages):
    """Call the chat-completions API, retrying failed attempts.

    Only the network call is retried, so images are not re-downloaded and no
    duplicate log records are written per attempt. After the final failed
    attempt tenacity raises ``RetryError``.

    Returns:
        ``(response, duration)`` where ``duration`` is the wall-clock seconds
        of the successful attempt only.
    """
    start_time = time.time()
    response = client.chat.completions.create(
        model=model_name,
        messages=messages,
        max_tokens=50,
        temperature=temperature
    )
    return response, time.time() - start_time


def create_multimodal_request(example, client, use_urls=False, shutdown_event=None):
    """
    Create a multimodal request from a dataset example and log the outcome.

    Args:
        example: Dataset example to process
        client: OpenAI client
        use_urls: Boolean flag to use image URLs instead of local files
        shutdown_event: Optional threading.Event for graceful shutdown
            (accepted for interface compatibility; not consulted here)

    Returns:
        The API response, or None when the example has no loadable images
        (a "skipped" record is logged in that case).

    Raises:
        Whatever the API call raised (``tenacity.RetryError`` once retries are
        exhausted); an "error" record is logged before re-raising.
    """
    prompt = f"""Given the following medical case:
Please answer this multiple choice question:
{example['question']}
Base your answer only on the provided images and case information."""

    content = [{"type": "text", "text": prompt}]
    content.extend(_load_image_parts(example, use_urls))

    # If no images found, log and return None
    if len(content) == 1:  # Only the text prompt exists
        print(f"No images found for question {example.get('question_id', 'unknown')}")
        log_entry = {
            **_base_log_entry(example),
            "status": "skipped",
            "reason": "no_images",
            "input": {
                "question": example['question'],
                "explanation": example.get('explanation', ''),
                "image_paths": example.get('images' if not use_urls else 'image_source_urls')
            }
        }
        logger.info(json.dumps(log_entry))
        return None

    messages = [
        {"role": "system", "content": "You are a medical imaging expert. Provide only the letter corresponding to your answer choice (A/B/C/D/E/F)."},
        {"role": "user", "content": content}
    ]

    try:
        response, duration = _chat_completion_with_retry(client, messages)
        log_entry = {
            **_base_log_entry(example),
            "duration": round(duration, 2),
            "usage": {
                "prompt_tokens": response.usage.prompt_tokens,
                "completion_tokens": response.usage.completion_tokens,
                "total_tokens": response.usage.total_tokens
            },
            "model_answer": response.choices[0].message.content,
            "correct_answer": example['answer'],
            "input": _request_input(example, messages, use_urls)
        }
        logger.info(json.dumps(log_entry))
        return response
    except Exception as e:
        # Record the underlying API error rather than tenacity's wrapper.
        cause = e
        if isinstance(e, RetryError):
            cause = e.last_attempt.exception() or e
        log_entry = {
            **_base_log_entry(example),
            "status": "error",
            "error": str(cause),
            "input": _request_input(example, messages, use_urls)
        }
        logger.info(json.dumps(log_entry))
        print(f"Error processing question {example.get('question_id', 'unknown')}: {str(cause)}")
        raise


def main():
    """Parse CLI options, run every (or the first --max-cases) example, summarize."""
    import signal
    import threading
    import argparse

    global model_name, temperature, log_filename

    # Add command line argument parsing
    parser = argparse.ArgumentParser(description='Run medical image analysis benchmark')
    parser.add_argument('--use-urls', action='store_true', help='Use image URLs instead of local files')
    parser.add_argument('--model', type=str, default='chatgpt-4o-latest', help='Model name to use')
    parser.add_argument('--temperature', type=float, default=0.2, help='Temperature for model inference')
    parser.add_argument('--log-prefix', type=str, help='Prefix for log filename (default: model name)')
    parser.add_argument('--max-cases', type=int, default=None, help='Maximum number of cases to process (default: all)')
    args = parser.parse_args()

    # Set global variables read by the request helpers
    model_name = args.model
    temperature = args.temperature
    log_prefix = args.log_prefix if args.log_prefix is not None else args.model
    log_filename = f"{log_prefix}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"

    # Setup logging
    setup_logging(log_filename)

    # Create an event for handling graceful shutdown
    shutdown_event = threading.Event()

    def signal_handler(signum, frame):
        print("\nShutdown signal received. Completing current task...")
        print("Press Ctrl+C again to abort immediately.")
        shutdown_event.set()
        # Restore the default Ctrl+C behaviour so a second press can abort a
        # request that is blocked on the network or waiting between retries.
        signal.signal(signal.SIGINT, signal.default_int_handler)

    # Register signal handlers
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    # Load the local JSONL metadata through the Hugging Face `datasets` loader
    dataset = load_dataset("json", data_files="chestagentbench/metadata.jsonl")
    train_dataset = dataset["train"]

    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        raise ValueError("OPENAI_API_KEY environment variable is not set.")
    client = openai.OpenAI(api_key=api_key)

    total_examples = len(train_dataset)
    processed = 0  # examples attempted (includes skipped and failed ones)
    skipped = 0
    failed = 0

    print(f"Beginning benchmark evaluation for model {model_name}")
    print(f"Using {'image URLs' if args.use_urls else 'local files'} for images")
    print(f"Temperature: {temperature}")

    # Handle max cases limit
    dataset_to_process = train_dataset
    if args.max_cases is not None:
        dataset_to_process = train_dataset.select(range(min(args.max_cases, len(train_dataset))))
        total_examples = len(dataset_to_process)
        print(f"Processing {total_examples} cases (limited by --max-cases argument)")

    for example in dataset_to_process:
        if shutdown_event.is_set():
            print("\nGraceful shutdown initiated. Saving progress...")
            break

        processed += 1
        question_id = example.get('question_id', 'unknown')
        try:
            response = create_multimodal_request(example, client, args.use_urls, shutdown_event)
        except Exception as e:
            # A single question exhausting its retries must not abort the
            # whole run; the failure is already recorded in the log file.
            failed += 1
            print(f"Failed question: {question_id} ({e})")
            continue

        if response is None:
            skipped += 1
            print(f"Skipped question: {question_id}")
            continue

        print(f"Progress: {processed}/{total_examples}")
        print(f"Question ID: {question_id}")
        print(f"Model Answer: {response.choices[0].message.content}")
        print(f"Correct Answer: {example['answer']}\n")

    print("\nBenchmark Summary:")
    print(f"Total Examples Processed: {processed}")
    print(f"Total Examples Skipped: {skipped}")
    print(f"Total Examples Failed: {failed}")

    # Flush and close the log file before checking it on disk.
    for handler in list(logger.handlers):
        handler.close()
        logger.removeHandler(handler)

    # Verify log file exists and has content
    if os.path.exists(log_filename) and os.path.getsize(log_filename) > 0:
        print(f"\nLog file saved to: {os.path.abspath(log_filename)}")
    else:
        print(f"\nWarning: Log file could not be verified at: {os.path.abspath(log_filename)}")
        print("Please check directory permissions and available disk space.")


if __name__ == "__main__":
    main()