clip / classifier.py
fmegahed's picture
Update classifier.py
a1e1c29 verified
import torch
import torch.nn.functional as F
from datetime import datetime
import os
import csv
_CSV_FIELDNAMES = [
    "datetime_of_operation", "num_few_shot_nominal_imgs", "image_path", "image_name",
    "classification_result", "non_defect_prob", "defect_prob",
    "nominal_description", "defective_description",
    "max_nominal_similarity", "max_defective_similarity",
]


def _as_list(value):
    """Return *value* as a list: sequences are converted, a lone item is wrapped."""
    if isinstance(value, (list, tuple)):
        return list(value)
    return [value]


def _encode_normalized(model, image, device):
    """Encode one image and scale its features to unit L2 norm along the last axis."""
    features = model.encode_image(image.to(device))
    # Out-of-place division so a tensor handed back by the model is never mutated.
    return features / features.norm(dim=-1, keepdim=True)


def _encode_reference_set(model, images, device):
    """Encode every reference image into one (num_images, feature_dim) matrix."""
    stacked = torch.stack([_encode_normalized(model, img, device) for img in images])
    return stacked.flatten(start_dim=1)


def _best_match(query, references):
    """Return (similarity, row index) of the reference most similar to *query*."""
    similarities = references @ query
    # argmax reports the first maximal entry, matching a strict ">" linear scan.
    best_idx = int(torch.argmax(similarities))
    return similarities[best_idx].item(), best_idx


def few_shot_fault_classification(
    model,
    test_images,
    test_image_filenames,
    nominal_images,
    nominal_descriptions,
    defective_images,
    defective_descriptions,
    num_few_shot_nominal_imgs: int,
    device="cpu",
    file_path: str = '.',
    file_name: str = 'image_classification_results.csv',
    print_one_liner: bool = False
):
    """
    Classify test images as nominal or defective by nearest-reference similarity.

    Every image is encoded with ``model.encode_image`` and L2-normalized. For each
    test image the most similar nominal reference and the most similar defective
    reference are found (dot product of the normalized features); a softmax over
    those two similarities yields the reported probabilities, and the larger one
    decides the label. All rows are appended to a CSV file and also returned.

    Args:
        model: Object exposing ``encode_image(tensor)``.
        test_images: Image tensor, or list/tuple of image tensors, to classify.
            Each entry is assumed to hold a single preprocessed image so that
            its encoding flattens to one feature vector -- TODO confirm with
            the preprocessing used by the caller.
        test_image_filenames: Path (or list/tuple of paths) for each test image,
            in the same order as ``test_images``.
        nominal_images: Reference image(s) of non-defective items.
        nominal_descriptions: Description(s) indexed like ``nominal_images``.
        defective_images: Reference image(s) of defective items.
        defective_descriptions: Description(s) indexed like ``defective_images``.
        num_few_shot_nominal_imgs: Recorded verbatim in every output row; it is
            not used in the computation.
        device: Device each image tensor is moved to before encoding.
        file_path: Output directory; created if it does not exist.
        file_name: Name of the CSV file inside ``file_path``.
        print_one_liner: When true, print a one-line summary per test image.

    Returns:
        A list with one dict per test image, keyed by the CSV column names
        (timestamp, paths, label, probabilities, matched descriptions and the
        two maximum similarities, the numeric values rounded to 3 decimals).

    Raises:
        IndexError: If there are fewer filenames than test images, or the
            best-matching reference has no description at its index.
        RuntimeError: Propagated from torch, e.g. when a reference set is empty
            or feature sizes do not line up.
    """
    test_images = _as_list(test_images)
    test_image_filenames = _as_list(test_image_filenames)
    nominal_images = _as_list(nominal_images)
    nominal_descriptions = _as_list(nominal_descriptions)
    defective_images = _as_list(defective_images)
    defective_descriptions = _as_list(defective_descriptions)

    # Ensure the output directory exists before any expensive work is done.
    os.makedirs(file_path, exist_ok=True)
    csv_file = os.path.join(file_path, file_name)

    results = []

    with torch.no_grad():
        nominal_features = _encode_reference_set(model, nominal_images, device)
        defective_features = _encode_reference_set(model, defective_images, device)

        for idx, test_img in enumerate(test_images):
            test_features = _encode_normalized(model, test_img, device).flatten()

            max_nominal_similarity, max_nominal_idx = _best_match(
                test_features, nominal_features)
            max_defective_similarity, max_defective_idx = _best_match(
                test_features, defective_features)

            # Softmax over the two best similarities gives the reported probabilities.
            similarities = torch.tensor([max_nominal_similarity, max_defective_similarity])
            prob_not_defective, prob_defective = F.softmax(similarities, dim=0).tolist()

            # Ties resolve to "Nominal".
            classification = "Defective" if prob_defective > prob_not_defective else "Nominal"

            image_path = test_image_filenames[idx]
            results.append({
                "datetime_of_operation": datetime.now().isoformat(),
                "num_few_shot_nominal_imgs": num_few_shot_nominal_imgs,
                "image_path": image_path,
                # basename also copes with platform-specific separators.
                "image_name": os.path.basename(image_path),
                "classification_result": classification,
                "non_defect_prob": round(prob_not_defective, 3),
                "defect_prob": round(prob_defective, 3),
                "nominal_description": nominal_descriptions[max_nominal_idx],
                "defective_description": defective_descriptions[max_defective_idx],
                "max_nominal_similarity": round(max_nominal_similarity, 3),
                "max_defective_similarity": round(max_defective_similarity, 3),
            })

            if print_one_liner:
                print(f"{image_path}: {classification} "
                      f"(Nominal: {prob_not_defective:.3f}, Defective: {prob_defective:.3f})")

    # Append mode creates the file when missing. A header is needed both for a
    # new file and for an existing but still empty one.
    needs_header = not os.path.isfile(csv_file) or os.path.getsize(csv_file) == 0
    with open(csv_file, mode='a', newline='') as file:
        writer = csv.DictWriter(file, fieldnames=_CSV_FIELDNAMES)
        if needs_header:
            writer.writeheader()
        writer.writerows(results)

    return results