eyupipler commited on
Commit
a5050fa
·
verified ·
1 Parent(s): 9f13538

Added Vbai-1.1 Dementia

Browse files
Main Models/Vbai-1.1 Dementia/README.md ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Vbai-1.1 Dementia (11178564 parametre)
2
+
3
+ ## "Vbai-1.1 Dementia" modeli, bir önceki modele göre daha fazla veriyle eğitilmiş olup üzerinde ince ayar yapılmış versiyonudur.
4
+
5
+ ## -----------------------------------------------------------------------------------
6
+
7
+ # Vbai-1.1 Dementia (11178564 parameters)
8
+
9
+ ## The "Vbai-1.1 Dementia" model is a fine-tuned version of the previous model, trained with more data.
10
+
11
+ # Kullanım / Usage
12
+
13
+ ```python
14
+ import torch
15
+ import torch.nn as nn
16
+ from torchvision import transforms, models
17
+ from PIL import Image
18
+ import matplotlib.pyplot as plt
19
+ import os
20
+ from torchsummary import summary
21
+
22
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
+
24
+ model = models.resnet18(pretrained=False)
25
+ num_ftrs = model.fc.in_features
26
+ model.fc = nn.Linear(num_ftrs, 4)
27
+ model.load_state_dict(torch.load('Vbai-1.1 Dementia/path'))
28
+ model = model.to(device)
29
+ model.eval()
30
+ summary(model, (3, 224, 224))
31
+
32
+ transform = transforms.Compose([
33
+ transforms.Resize((224, 224)),
34
+ transforms.ToTensor(),
35
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
36
+ ])
37
+
38
+ class_names = ['No Dementia', 'Mild Dementia', 'Avarage Dementia', 'Very Mild Dementia']
39
+
40
+ def predict(image_path, model, transform):
41
+ image = Image.open(image_path).convert('RGB')
42
+ image = transform(image).unsqueeze(0).to(device)
43
+ model.eval()
44
+ with torch.no_grad():
45
+ outputs = model(image)
46
+ probs = torch.nn.functional.softmax(outputs, dim=1)
47
+ _, preds = torch.max(outputs, 1)
48
+ return preds.item(), probs[0][preds.item()].item()
49
+
50
+ def show_image_with_prediction(image_path, prediction, confidence, class_names):
51
+ image = Image.open(image_path)
52
+ plt.imshow(image)
53
+ plt.title(f"Prediction: {class_names[prediction]} (%{confidence * 100:.2f})")
54
+ plt.axis('off')
55
+ plt.show()
56
+
57
+ test_image_path = 'image-path'
58
+ prediction, confidence = predict(test_image_path, model, transform)
59
+ print(f'Prediction: {class_names[prediction]} (%{confidence * 100})')
60
+
61
+ show_image_with_prediction(test_image_path, prediction, confidence, class_names)
62
+ ```
63
+
64
+ # Uygulama / As App
65
+
66
+ ```python
67
+ import sys
68
+ import torch
69
+ import torch.nn as nn
70
+ from torchvision import transforms, models
71
+ from PIL import Image
72
+ import matplotlib.pyplot as plt
73
+ from PyQt5.QtWidgets import QApplication, QWidget, QPushButton, QLabel, QFileDialog, QVBoxLayout, QMessageBox
74
+ from PyQt5.QtGui import QPixmap, QIcon
75
+ from PyQt5.QtCore import Qt
76
+
77
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
78
+
79
+ transform = transforms.Compose([
80
+ transforms.Resize((224, 224)),
81
+ transforms.ToTensor(),
82
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
83
+ ])
84
+
85
+ class_names = ['No Dementia', 'Mild Dementia', 'Avarage Dementia', 'Very Mild Dementia']
86
+
87
+
88
+ class DementiaApp(QWidget):
89
+ def __init__(self):
90
+ super().__init__()
91
+ self.initUI()
92
+ self.model = None
93
+ self.image_path = None
94
+
95
+ def initUI(self):
96
+ self.setWindowTitle('Prediction App by Neurazum')
97
+ self.setWindowIcon(QIcon('C:/Users/eyupi/PycharmProjects/Neurazum/NeurAI/Assets/neurazumicon.ico'))
98
+ self.setGeometry(2500, 300, 400, 200)
99
+
100
+ self.loadModelButton = QPushButton('Upload Model', self)
101
+ self.loadModelButton.clicked.connect(self.loadModel)
102
+
103
+ self.loadImageButton = QPushButton('Upload Image', self)
104
+ self.loadImageButton.clicked.connect(self.loadImage)
105
+
106
+ self.predictButton = QPushButton('Make a Prediction', self)
107
+ self.predictButton.clicked.connect(self.predict)
108
+ self.predictButton.setEnabled(False)
109
+
110
+ self.resultLabel = QLabel('', self)
111
+ self.resultLabel.setAlignment(Qt.AlignCenter)
112
+
113
+ self.imageLabel = QLabel('', self)
114
+ self.imageLabel.setAlignment(Qt.AlignCenter)
115
+
116
+ layout = QVBoxLayout()
117
+ layout.addWidget(self.loadModelButton)
118
+ layout.addWidget(self.loadImageButton)
119
+ layout.addWidget(self.imageLabel)
120
+ layout.addWidget(self.predictButton)
121
+ layout.addWidget(self.resultLabel)
122
+
123
+ self.setLayout(layout)
124
+
125
+ def loadModel(self):
126
+ options = QFileDialog.Options()
127
+ fileName, _ = QFileDialog.getOpenFileName(self, "Choose Model Path", "",
128
+ "PyTorch Model Files (*.pt);;All Files (*)", options=options)
129
+ if fileName:
130
+ self.model = models.resnet18(pretrained=False)
131
+ num_ftrs = self.model.fc.in_features
132
+ self.model.fc = nn.Linear(num_ftrs, 4)
133
+ self.model.load_state_dict(torch.load(fileName, map_location=device))
134
+ self.model = self.model.to(device)
135
+ self.model.eval()
136
+ self.predictButton.setEnabled(True)
137
+ QMessageBox.information(self, "Model Uploaded", "Model successfully uploaded!")
138
+
139
+ def loadImage(self):
140
+ options = QFileDialog.Options()
141
+ fileName, _ = QFileDialog.getOpenFileName(self, "Choose Image File", "",
142
+ "Image Files (*.jpg *.jpeg *.png);;All Files (*)", options=options)
143
+ if fileName:
144
+ self.image_path = fileName
145
+ pixmap = QPixmap(self.image_path)
146
+ self.imageLabel.setPixmap(pixmap.scaled(224, 224, Qt.KeepAspectRatio))
147
+
148
+ def predict(self):
149
+ if self.model and self.image_path:
150
+ prediction, confidence = self.predictImage(self.image_path, self.model, transform)
151
+ self.resultLabel.setText(f'Prediction: {class_names[prediction]} (%{confidence * 100:.2f})')
152
+ else:
153
+ QMessageBox.warning(self, "Missing Information", "Model and picture must be uploaded.")
154
+
155
+ def predictImage(self, image_path, model, transform):
156
+ image = Image.open(image_path).convert('RGB')
157
+ image = transform(image).unsqueeze(0).to(device)
158
+ model.eval()
159
+ with torch.no_grad():
160
+ outputs = model(image)
161
+ probs = torch.nn.functional.softmax(outputs, dim=1)
162
+ _, preds = torch.max(outputs, 1)
163
+ return preds.item(), probs[0][preds.item()].item()
164
+
165
+
166
+ if __name__ == '__main__':
167
+ app = QApplication(sys.argv)
168
+ ex = DementiaApp()
169
+ ex.show()
170
+ sys.exit(app.exec_())
171
+ ```
172
+
173
+ # Python Sürümü / Python Version
174
+
175
+ ### 3.9 <=> 3.13
176
+
177
+ # Modüller / Modules
178
+
179
+ ```bash
180
+ matplotlib==3.8.0
181
+ Pillow==10.0.1
182
+ torch==2.3.1
183
+ torchsummary==1.5.1
184
+ torchvision==0.18.1
185
+ ```
Main Models/Vbai-1.1 Dementia/Vbai-1.1 Dementia.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1a918a7617b157be56975fc9354bc62b83e289d1aacdba77996057dcfdd9f983
3
+ size 44792618