import torch
import torch.nn.functional as F
from datetime import datetime
import os
import csv


def few_shot_fault_classification(
    model,
    test_images,
    test_image_filenames,
    nominal_images,
    nominal_descriptions,
    defective_images,
    defective_descriptions,
    num_few_shot_nominal_imgs: int,
    device="cpu",
    file_path: str = '.',
    file_name: str = 'image_classification_results.csv',
    print_one_liner: bool = False
):
    """
    Classify test images as nominal or defective based on their similarity to
    few-shot reference sets of nominal and defective images.

    Each test image is encoded with `model.encode_image`, and its highest cosine
    similarity to the nominal references and to the defective references is found.
    A softmax over these two best scores gives the nominal/defective probabilities.
    Results are appended to a CSV file at `os.path.join(file_path, file_name)` and
    also returned as a list of dictionaries.
    """
    # Allow callers to pass single items instead of lists.
    if not isinstance(test_images, list):
        test_images = [test_images]
    if not isinstance(test_image_filenames, list):
        test_image_filenames = [test_image_filenames]
    if not isinstance(nominal_images, list):
        nominal_images = [nominal_images]
    if not isinstance(nominal_descriptions, list):
        nominal_descriptions = [nominal_descriptions]
    if not isinstance(defective_images, list):
        defective_images = [defective_images]
    if not isinstance(defective_descriptions, list):
        defective_descriptions = [defective_descriptions]

    os.makedirs(file_path, exist_ok=True)
    csv_file = os.path.join(file_path, file_name)
    results = []

    with torch.no_grad():
        # Encode and L2-normalize the few-shot reference images once up front.
        nominal_features = torch.stack([model.encode_image(img.to(device)) for img in nominal_images])
        nominal_features /= nominal_features.norm(dim=-1, keepdim=True)

        defective_features = torch.stack([model.encode_image(img.to(device)) for img in defective_images])
        defective_features /= defective_features.norm(dim=-1, keepdim=True)

        csv_data = []

        for idx, test_img in enumerate(test_images):
            # Encode and L2-normalize the test image so the dot products below
            # are cosine similarities.
            test_features = model.encode_image(test_img.to(device))
            test_features /= test_features.norm(dim=-1, keepdim=True)

            # Track the best-matching nominal and defective reference images.
            max_nominal_similarity = -float('inf')
            max_defective_similarity = -float('inf')
            max_nominal_idx = -1
            max_defective_idx = -1

            for i in range(nominal_features.shape[0]):
                similarity = (test_features @ nominal_features[i].T).item()
                if similarity > max_nominal_similarity:
                    max_nominal_similarity = similarity
                    max_nominal_idx = i

            for j in range(defective_features.shape[0]):
                similarity = (test_features @ defective_features[j].T).item()
                if similarity > max_defective_similarity:
                    max_defective_similarity = similarity
                    max_defective_idx = j

            # Convert the two best similarities into class probabilities.
            similarities = torch.tensor([max_nominal_similarity, max_defective_similarity])
            probabilities = F.softmax(similarities, dim=0).tolist()
            prob_not_defective = probabilities[0]
            prob_defective = probabilities[1]

            classification = "Defective" if prob_defective > prob_not_defective else "Nominal"

            result = {
                "datetime_of_operation": datetime.now().isoformat(),
                "num_few_shot_nominal_imgs": num_few_shot_nominal_imgs,
                "image_path": test_image_filenames[idx],
                "image_name": os.path.basename(test_image_filenames[idx]),
                "classification_result": classification,
                "non_defect_prob": round(prob_not_defective, 3),
                "defect_prob": round(prob_defective, 3),
                "nominal_description": nominal_descriptions[max_nominal_idx],
                "defective_description": defective_descriptions[max_defective_idx],
                "max_nominal_similarity": round(max_nominal_similarity, 3),
                "max_defective_similarity": round(max_defective_similarity, 3)
            }

            csv_data.append(result)
            results.append(result)

            if print_one_liner:
                print(f"{test_image_filenames[idx]} → {classification} "
                      f"(Nominal: {prob_not_defective:.3f}, Defective: {prob_defective:.3f})")

    # Append results to the CSV file, writing the header only if the file is new.
    file_exists = os.path.isfile(csv_file)
    with open(csv_file, mode='a' if file_exists else 'w', newline='') as file:
        fieldnames = [
            "datetime_of_operation", "num_few_shot_nominal_imgs", "image_path", "image_name",
            "classification_result", "non_defect_prob", "defect_prob",
            "nominal_description", "defective_description",
            "max_nominal_similarity", "max_defective_similarity"
        ]
        writer = csv.DictWriter(file, fieldnames=fieldnames)

        if not file_exists:
            writer.writeheader()

        for row in csv_data:
            writer.writerow(row)

    return results
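

# Minimal usage sketch (an assumption, not part of the function above): `model`
# is taken to be any CLIP-style encoder exposing `encode_image`, here loaded via
# the OpenAI `clip` package; the file paths and descriptions are illustrative
# placeholders.
if __name__ == "__main__":
    import clip
    from PIL import Image

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load("ViT-B/32", device=device)

    def load(path):
        # Preprocess a single image into a (1, 3, H, W) tensor batch.
        return preprocess(Image.open(path)).unsqueeze(0)

    # Hypothetical image lists; replace with real paths.
    nominal_paths = ["data/nominal_01.png", "data/nominal_02.png"]
    defective_paths = ["data/defective_01.png"]
    test_paths = ["data/test_01.png"]

    results = few_shot_fault_classification(
        model=model,
        test_images=[load(p) for p in test_paths],
        test_image_filenames=test_paths,
        nominal_images=[load(p) for p in nominal_paths],
        nominal_descriptions=["nominal part" for _ in nominal_paths],
        defective_images=[load(p) for p in defective_paths],
        defective_descriptions=["defective part" for _ in defective_paths],
        num_few_shot_nominal_imgs=len(nominal_paths),
        device=device,
        print_one_liner=True,
    )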