import torch
import torch.nn.functional as F
from datetime import datetime
import os
import csv

# Column order of the results CSV; also the key set of every result dict.
_CSV_FIELDNAMES = [
    "datetime_of_operation",
    "num_few_shot_nominal_imgs",
    "image_path",
    "image_name",
    "classification_result",
    "non_defect_prob",
    "defect_prob",
    "nominal_description",
    "defective_description",
    "max_nominal_similarity",
    "max_defective_similarity",
]


def _as_list(value):
    """Return *value* as a list, wrapping a single non-sequence item."""
    if isinstance(value, (list, tuple)):
        return list(value)
    return [value]


def _encode_normalized(model, images, device):
    """Encode *images* with the model and L2-normalize along the last dim."""
    features = torch.stack(
        [model.encode_image(img.to(device)) for img in images]
    )
    # Out-of-place division: never write into tensors owned by the model.
    return features / features.norm(dim=-1, keepdim=True)


def _best_match(reference_matrix, query_vector):
    """Return ``(index, similarity)`` of the reference most similar to the query.

    All dot products are computed in one matmul and transferred to the host
    once. Ties resolve to the lowest index (strict ``>`` comparison).
    """
    similarities = (reference_matrix @ query_vector).tolist()
    best_idx, best_similarity = -1, -float("inf")
    for idx, similarity in enumerate(similarities):
        if similarity > best_similarity:
            best_idx, best_similarity = idx, similarity
    return best_idx, best_similarity


def _append_rows_to_csv(csv_file, rows):
    """Append *rows* to *csv_file*, writing the header if the file is new or empty."""
    needs_header = (
        not os.path.isfile(csv_file) or os.path.getsize(csv_file) == 0
    )
    # Mode 'a' creates the file when it does not exist yet.
    with open(csv_file, mode="a", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=_CSV_FIELDNAMES)
        if needs_header:
            writer.writeheader()
        writer.writerows(rows)


def few_shot_fault_classification(
    model,
    test_images,
    test_image_filenames,
    nominal_images,
    nominal_descriptions,
    defective_images,
    defective_descriptions,
    num_few_shot_nominal_imgs: int,
    device="cpu",
    file_path: str = '.',
    file_name: str = 'image_classification_results.csv',
    print_one_liner: bool = False
) -> list:
    """Classify test images as nominal or defective by nearest-reference similarity.

    Each test image is encoded with ``model.encode_image``, L2-normalized, and
    compared (dot product) against every nominal and every defective reference
    embedding. The best nominal and best defective similarity are turned into
    two probabilities with a softmax; the larger one decides the label
    ("Nominal" wins ties). All results are appended to a CSV file.

    Args:
        model: Object exposing ``encode_image(tensor)``.
            NOTE(review): assumed to return one embedding per call (e.g.
            shape ``(D,)`` or ``(1, D)``) -- confirm against the model used.
        test_images: One image or a list/tuple of images; each must support
            ``.to(device)``.
        test_image_filenames: Path(s) of the test images, index-aligned with
            ``test_images``.
        nominal_images: Reference image(s) of defect-free parts.
        nominal_descriptions: Description(s), index-aligned with
            ``nominal_images``.
        defective_images: Reference image(s) of defective parts.
        defective_descriptions: Description(s), index-aligned with
            ``defective_images``.
        num_few_shot_nominal_imgs: Recorded verbatim in every result row; not
            used in the computation.
        device: Device the images are moved to before encoding.
        file_path: Output directory; created if missing.
        file_name: Name of the CSV file inside ``file_path``.
        print_one_liner: If true, print a one-line summary per test image.

    Returns:
        A list with one dict per test image, keyed by the CSV column names.

    Raises:
        ValueError: If a reference set is empty, or if fewer filenames or
            descriptions than corresponding images are supplied.
    """
    test_images = _as_list(test_images)
    test_image_filenames = _as_list(test_image_filenames)
    nominal_images = _as_list(nominal_images)
    nominal_descriptions = _as_list(nominal_descriptions)
    defective_images = _as_list(defective_images)
    defective_descriptions = _as_list(defective_descriptions)

    # Fail before any encoding work instead of with an IndexError mid-run.
    if not nominal_images or not defective_images:
        raise ValueError(
            "At least one nominal and one defective reference image is required."
        )
    if len(test_image_filenames) < len(test_images):
        raise ValueError(
            f"Got {len(test_images)} test images but only "
            f"{len(test_image_filenames)} filenames."
        )
    if len(nominal_descriptions) < len(nominal_images):
        raise ValueError(
            f"Got {len(nominal_images)} nominal images but only "
            f"{len(nominal_descriptions)} descriptions."
        )
    if len(defective_descriptions) < len(defective_images):
        raise ValueError(
            f"Got {len(defective_images)} defective images but only "
            f"{len(defective_descriptions)} descriptions."
        )

    # Ensure the output directory exists
    os.makedirs(file_path, exist_ok=True)
    csv_file = os.path.join(file_path, file_name)

    results = []

    with torch.no_grad():
        # One row per reference image, so a single matmul scores them all.
        nominal_matrix = _encode_normalized(
            model, nominal_images, device
        ).reshape(len(nominal_images), -1)
        defective_matrix = _encode_normalized(
            model, defective_images, device
        ).reshape(len(defective_images), -1)

        for idx, test_img in enumerate(test_images):
            test_features = model.encode_image(test_img.to(device))
            test_features = test_features / test_features.norm(
                dim=-1, keepdim=True
            )
            test_vector = test_features.reshape(-1)

            max_nominal_idx, max_nominal_similarity = _best_match(
                nominal_matrix, test_vector
            )
            max_defective_idx, max_defective_similarity = _best_match(
                defective_matrix, test_vector
            )

            # Convert the two best similarities to probabilities
            similarities = torch.tensor(
                [max_nominal_similarity, max_defective_similarity]
            )
            prob_not_defective, prob_defective = F.softmax(
                similarities, dim=0
            ).tolist()

            classification = (
                "Defective" if prob_defective > prob_not_defective else "Nominal"
            )

            image_path = test_image_filenames[idx]
            result = {
                "datetime_of_operation": datetime.now().isoformat(),
                "num_few_shot_nominal_imgs": num_few_shot_nominal_imgs,
                "image_path": image_path,
                # basename handles the platform's path separators.
                "image_name": os.path.basename(image_path),
                "classification_result": classification,
                "non_defect_prob": round(prob_not_defective, 3),
                "defect_prob": round(prob_defective, 3),
                "nominal_description": nominal_descriptions[max_nominal_idx],
                "defective_description": defective_descriptions[max_defective_idx],
                "max_nominal_similarity": round(max_nominal_similarity, 3),
                "max_defective_similarity": round(max_defective_similarity, 3),
            }
            results.append(result)

            # Optionally print one-liner summary for each test image
            if print_one_liner:
                print(f"{test_image_filenames[idx]} → {classification} "
                      f"(Nominal: {prob_not_defective:.3f}, Defective: {prob_defective:.3f})")

    _append_rows_to_csv(csv_file, results)

    return results