File size: 8,297 Bytes
b8d3576
 
0ac2665
b8d3576
0ac2665
 
 
 
 
b8d3576
0ac2665
4f8dfb9
0ac2665
3cdb380
 
4f8dfb9
b8d3576
3cdb380
 
 
 
 
 
 
 
 
 
 
 
 
 
0ac2665
3cdb380
0ac2665
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3cdb380
0ac2665
 
3cdb380
0ac2665
 
 
 
 
3cdb380
0ac2665
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3cdb380
0ac2665
 
 
 
 
 
 
 
 
3cdb380
0ac2665
b8d3576
3cdb380
0ac2665
 
 
 
b8d3576
0ac2665
 
 
 
 
 
 
 
 
 
3cdb380
0ac2665
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f8dfb9
 
0ac2665
 
 
3cdb380
0ac2665
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
import csv
import os
from datetime import datetime
from typing import Callable, List, Optional

import clip
import pandas as pd
import streamlit as st
import torch
import torch.nn.functional as F
from PIL import Image

# Device setup
device = "cuda" if torch.cuda.is_available() else "cpu"


@st.cache_resource(show_spinner=False)
def _load_clip_model(model_name: str, device_name: str):
    """Load a CLIP model and its preprocessing transform once per server process.

    Streamlit re-executes this whole script on every widget interaction; caching
    the resource avoids reloading the weights on each rerun. The spinner is
    disabled so this call renders nothing before ``st.set_page_config`` below.

    Args:
        model_name: CLIP architecture name passed to ``clip.load``.
        device_name: Torch device string the model is loaded onto.

    Returns:
        ``(model, preprocess)`` as returned by ``clip.load``, with the model
        switched to eval mode.
    """
    loaded_model, loaded_preprocess = clip.load(model_name, device=device_name)
    loaded_model.eval()
    return loaded_model, loaded_preprocess


# Load CLIP model and preprocessor (ViT-B/32 = small model, CPU-friendly)
model, preprocess = _load_clip_model("ViT-B/32", device)

# --- Page header ---
# Streamlit requires st.set_page_config to run before any other element is rendered.
st.set_page_config(layout="wide", page_title="Few-Shot Fault Detection")
st.title("🛠️ Few-Shot Fault Detection (Industrial Quality Control)")

# Introductory Markdown shown under the title (two trailing spaces = Markdown line break).
_INTRO_MARKDOWN = """
This demo uses the **smaller `ViT-B/32` encoder** from OpenAI's CLIP model to classify test images as **Nominal** or **Defective**, based on few-shot learning using user-provided reference images.

⚠️ **Note**: This app is running on a **free CPU tier** and is meant for demonstration purposes. For more advanced use cases, including GPU acceleration, custom training, and larger models, please refer to:

📄 [Megahed et al. (2025)](https://arxiv.org/abs/2501.12596):  
*Adapting OpenAI's CLIP Model for Few-Shot Image Inspection in Manufacturing Quality Control: An Expository Case Study with Multiple Application Examples*

🔗 [GitHub & Colab links available in the paper](https://arxiv.org/abs/2501.12596)
"""
st.markdown(_INTRO_MARKDOWN)

# --- Few-shot classification logic ---
def few_shot_fault_classification(
    test_images: List[torch.Tensor],
    test_image_filenames: List[str],
    nominal_images: List[torch.Tensor],
    nominal_descriptions: List[str],
    defective_images: List[torch.Tensor],
    defective_descriptions: List[str],
    num_few_shot_nominal_imgs: int,
    file_path: str = '.',
    file_name: str = 'image_classification_results.csv',
    print_one_liner: bool = False,
    image_encoder: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
) -> str:
    """Classify preprocessed test images as "Nominal" or "Defective" by nearest reference.

    Every image is embedded with the encoder and L2-normalised, so a dot product
    is a cosine similarity. For each test image the best similarity against the
    nominal references and against the defective references is taken; those two
    numbers go through a 2-way softmax to give the reported probabilities, and
    the class with the higher probability wins (ties go to "Nominal"). One row
    per test image is appended to a CSV file.

    If one reference set is empty, that class has similarity ``-inf``, receives
    probability 0, and its description column is written as "N/A".

    Args:
        test_images: Preprocessed image tensors without a batch axis (one is
            added before encoding). A single tensor is wrapped in a list.
        test_image_filenames: Name/path recorded for each test image, in the same
            order as ``test_images``.
        nominal_images: Preprocessed reference tensors of the nominal class.
        nominal_descriptions: One label per nominal reference; the label of the
            best match is written to the CSV.
        defective_images: Preprocessed reference tensors of the defective class.
        defective_descriptions: One label per defective reference.
        num_few_shot_nominal_imgs: Written verbatim to the CSV.
        file_path: Directory of the CSV file.
        file_name: Name of the CSV file; appended to if it exists, otherwise
            created with a header row.
        print_one_liner: If True, print one summary line per test image.
        image_encoder: Callable mapping a batch of one image to a batch of one
            embedding. Defaults to the module-level CLIP ``model.encode_image``.

    Returns:
        An empty string (results are delivered through the CSV file).

    Raises:
        ValueError: if both reference sets are empty, or if the number of
            filenames differs from the number of test images.
    """
    def _as_list(value):
        # Accept a single item where a list is expected.
        return value if isinstance(value, list) else [value]

    test_images = _as_list(test_images)
    test_image_filenames = _as_list(test_image_filenames)
    nominal_images = _as_list(nominal_images)
    nominal_descriptions = _as_list(nominal_descriptions)
    defective_images = _as_list(defective_images)
    defective_descriptions = _as_list(defective_descriptions)

    if len(test_image_filenames) != len(test_images):
        raise ValueError(
            f"Got {len(test_images)} test images but {len(test_image_filenames)} filenames."
        )
    if not nominal_images and not defective_images:
        raise ValueError("At least one nominal or defective reference image is required.")

    # Resolve the global model lazily so callers supplying an encoder do not need it.
    encode = model.encode_image if image_encoder is None else image_encoder

    def _embed(img: torch.Tensor) -> torch.Tensor:
        # Add a batch axis for the encoder, drop it again, then L2-normalise.
        features = encode(img.unsqueeze(0)).squeeze(0)
        return features / features.norm(dim=-1, keepdim=True)

    def _embed_all(images: List[torch.Tensor]) -> Optional[torch.Tensor]:
        # None marks an empty reference set (torch.stack rejects empty lists).
        return torch.stack([_embed(img) for img in images]) if images else None

    def _best_match(query: torch.Tensor, references: Optional[torch.Tensor]):
        # Returns (best cosine similarity, index); an empty set can never win.
        if references is None:
            return -float('inf'), None
        sims = references @ query
        best = int(torch.argmax(sims))  # first index among ties, as before
        return sims[best].item(), best

    csv_data = []
    with torch.no_grad():
        nominal_features = _embed_all(nominal_images)
        defective_features = _embed_all(defective_images)

        for test_img, filename in zip(test_images, test_image_filenames):
            test_features = _embed(test_img)
            max_nom_sim, max_nom_idx = _best_match(test_features, nominal_features)
            max_def_sim, max_def_idx = _best_match(test_features, defective_features)

            # Softmax over the raw (unscaled) best similarities; -inf maps to 0.
            similarities = torch.tensor([max_nom_sim, max_def_sim])
            prob_nom, prob_def = F.softmax(similarities, dim=0).tolist()

            classification = "Defective" if prob_def > prob_nom else "Nominal"

            csv_data.append({
                "datetime_of_operation": datetime.now().isoformat(),
                "num_few_shot_nominal_imgs": num_few_shot_nominal_imgs,
                "image_path": filename,
                "image_name": filename.split('/')[-1],
                "classification_result": classification,
                "non_defect_prob": round(prob_nom, 3),
                "defect_prob": round(prob_def, 3),
                "nominal_description": (
                    nominal_descriptions[max_nom_idx] if max_nom_idx is not None else "N/A"
                ),
                "defective_description": (
                    defective_descriptions[max_def_idx] if max_def_idx is not None else "N/A"
                ),
            })

            if print_one_liner:
                print(f"{filename} classified as {classification} "
                      f"(Nominal Prob: {prob_nom:.3f}, Defective Prob: {prob_def:.3f})")

    fieldnames = [
        "datetime_of_operation", "num_few_shot_nominal_imgs", "image_path", "image_name",
        "classification_result", "non_defect_prob", "defect_prob",
        "nominal_description", "defective_description"
    ]
    csv_file = os.path.join(file_path, file_name)
    write_header = not os.path.isfile(csv_file)
    # Append mode creates a missing file. Explicit UTF-8 keeps non-ASCII file names
    # intact regardless of the platform default (pandas reads UTF-8 by default).
    with open(csv_file, mode='a', newline='', encoding='utf-8') as file:
        writer = csv.DictWriter(file, fieldnames=fieldnames)
        if write_header:
            writer.writeheader()
        writer.writerows(csv_data)

    return ""

# --- App state ---
# Streamlit re-executes this script on every interaction while st.session_state
# persists per browser session, so each key is seeded only when missing. The
# description lists are seeded too: Tab 2 reads
# st.session_state.nominal_descriptions / defective_descriptions, which otherwise
# raised AttributeError when no reference images had been uploaded yet.
_SESSION_STATE_KEYS = (
    "nominal_images",
    "nominal_descriptions",
    "defective_images",
    "defective_descriptions",
    "test_images",
    "results",
)
for _state_key in _SESSION_STATE_KEYS:
    if _state_key not in st.session_state:
        st.session_state[_state_key] = []

# --- Tabs ---
# One tab per workflow step: upload references, classify test images, view results.
_TAB_LABELS = ["📥 Upload Reference Images", "🔍 Test Classification", "📊 Results"]
tab1, tab2, tab3 = st.tabs(_TAB_LABELS)

# Tab 1: Upload Reference Images
with tab1:
    st.header("Upload Reference Images")
    nominal_files = st.file_uploader("Upload Nominal Images", accept_multiple_files=True, type=['png', 'jpg', 'jpeg'])
    defective_files = st.file_uploader("Upload Defective Images", accept_multiple_files=True, type=['png', 'jpg', 'jpeg'])

    # Mirror each uploader on every rerun. Previously state was only written when
    # files were present, so removing uploads left stale references in use.
    for ref_kind, ref_files in (("nominal", nominal_files or []), ("defective", defective_files or [])):
        st.session_state[f"{ref_kind}_images"] = [
            preprocess(Image.open(file).convert("RGB")).to(device) for file in ref_files
        ]
        st.session_state[f"{ref_kind}_descriptions"] = [file.name for file in ref_files]
        if ref_files:
            st.success(f"Uploaded {len(ref_files)} {ref_kind} images.")

# Tab 2: Test Classification
with tab2:
    st.header("Upload Test Image(s)")
    test_files = st.file_uploader("Upload Test Images", accept_multiple_files=True, type=['png', 'jpg', 'jpeg'])

    if st.button("🔍 Run Classification"):
        nominal_refs = st.session_state.get("nominal_images", [])
        defective_refs = st.session_state.get("defective_images", [])

        if not test_files:
            # Previously a click without test images did nothing, silently.
            st.warning("Please upload at least one test image.")
        elif not nominal_refs or not defective_refs:
            # The classifier compares against both classes; without both reference
            # sets it used to crash (missing session keys or empty torch.stack).
            st.warning("Please upload both nominal and defective reference images in the first tab.")
        else:
            with st.spinner("Classifying test images..."):
                test_images = [preprocess(Image.open(file).convert("RGB")).to(device) for file in test_files]
                test_filenames = [file.name for file in test_files]

                few_shot_fault_classification(
                    test_images=test_images,
                    test_image_filenames=test_filenames,
                    nominal_images=nominal_refs,
                    nominal_descriptions=st.session_state.get("nominal_descriptions", []),
                    defective_images=defective_refs,
                    defective_descriptions=st.session_state.get("defective_descriptions", []),
                    num_few_shot_nominal_imgs=len(nominal_refs),
                    file_path=".",
                    file_name="streamlit_results.csv",
                    print_one_liner=False
                )

            st.success("Classification complete!")
            st.session_state.results = "streamlit_results.csv"

# Tab 3: View/Download Results
with tab3:
    st.header("Classification Results")
    # Shared CSV written by the classifier in Tab 2.
    results_csv_path = "streamlit_results.csv"
    if not os.path.exists(results_csv_path):
        st.info("No results yet. Please classify some test images.")
    else:
        df = pd.read_csv(results_csv_path)
        st.dataframe(df)
        csv_payload = df.to_csv(index=False)
        st.download_button(
            "📥 Download Results",
            data=csv_payload,
            file_name="classification_results.csv",
            mime="text/csv",
        )