hsalih01 commited on
Commit
48dbf61
·
verified ·
1 Parent(s): f643417

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -142
app.py CHANGED
@@ -1,143 +1,46 @@
1
- import os
 
 
2
  import numpy as np
3
- from tensorflow.keras.preprocessing.image import ImageDataGenerator
4
- from tensorflow.keras.applications import ResNet50
5
- from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout, BatchNormalization
6
- from tensorflow.keras.models import Model
7
- from tensorflow.keras.optimizers import Adam
8
- from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
9
- import matplotlib.pyplot as plt
10
- from collections import Counter
11
-
12
- # Pfad zum Verzeichnis der Bilddaten
13
- base_dir = 'pokemon-images-first-generation17000-files/pokemon' # Pfad anpassen
14
-
15
- # Überprüfen, ob das Verzeichnis existiert
16
- if not os.path.exists(base_dir):
17
- raise Exception("Der angegebene Basispfad existiert nicht. Bitte überprüfen Sie den Pfad.")
18
-
19
- # Definiere die Pokémon-Klassen, die klassifiziert werden sollen
20
- classes = ['Dragonite', 'Bulbasaur', 'Golbat']
21
-
22
- # Data Augmentation und Daten-Generatoren
23
- train_datagen = ImageDataGenerator(
24
- rescale=1./255,
25
- rotation_range=45,
26
- width_shift_range=0.25,
27
- height_shift_range=0.25,
28
- shear_range=0.25,
29
- zoom_range=0.3,
30
- horizontal_flip=True,
31
- fill_mode='nearest',
32
- validation_split=0.2)
33
-
34
- validation_datagen = ImageDataGenerator(
35
- rescale=1./255,
36
- validation_split=0.2)
37
-
38
- train_generator = train_datagen.flow_from_directory(
39
- base_dir,
40
- target_size=(150, 150),
41
- batch_size=32,
42
- classes=classes,
43
- class_mode='categorical',
44
- subset='training')
45
-
46
- validation_generator = validation_datagen.flow_from_directory(
47
- base_dir,
48
- target_size=(150, 150),
49
- batch_size=32,
50
- classes=classes,
51
- class_mode='categorical',
52
- subset='validation')
53
-
54
- # Basis-Modell ResNet50 laden und anpassen
55
- base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(150, 150, 3))
56
- base_model.trainable = False # Initial alle Schichten einfrieren
57
-
58
- x = base_model.output
59
- x = GlobalAveragePooling2D()(x)
60
- x = Dense(1024, activation='relu')(x)
61
- x = Dropout(0.3)(x)
62
- x = BatchNormalization()(x)
63
- predictions = Dense(len(classes), activation='softmax')(x)
64
- model = Model(inputs=base_model.input, outputs=predictions)
65
-
66
- # Callbacks einrichten
67
- callbacks = [
68
- ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=5, min_lr=0.00001, verbose=1),
69
- EarlyStopping(monitor='val_loss', patience=10, verbose=1),
70
- ModelCheckpoint('best_model.keras', monitor='val_loss', save_best_only=True, verbose=1)
71
- ]
72
-
73
- # Modell kompilieren mit Adam Optimizer
74
- model.compile(optimizer=Adam(learning_rate=0.001), loss='categorical_crossentropy', metrics=['accuracy'])
75
-
76
- # Diagnostic output
77
- print(f"Total training samples: {train_generator.samples}")
78
- print(f"Total validation samples: {validation_generator.samples}")
79
- print(f"Training batch size: {train_generator.batch_size}")
80
- print(f"Validation batch size: {validation_generator.batch_size}")
81
-
82
- # Calculate steps per epoch and validation steps
83
- training_steps_per_epoch = int(np.ceil(train_generator.samples / train_generator.batch_size))
84
- validation_steps_per_epoch = int(np.ceil(validation_generator.samples / validation_generator.batch_size))
85
-
86
- print(f"Calculated training steps per epoch: {training_steps_per_epoch}")
87
- print(f"Calculated validation steps per epoch: {validation_steps_per_epoch}")
88
-
89
- # Model training with updated steps per epoch
90
- history = model.fit(
91
- train_generator,
92
- steps_per_epoch=training_steps_per_epoch,
93
- validation_data=validation_generator,
94
- validation_steps=validation_steps_per_epoch,
95
- epochs=15,
96
- callbacks=callbacks
97
- )
98
-
99
- # Modell speichern
100
- model.save('pokemon_classifier_model.h5')
101
-
102
- # Fine-Tuning: Einige Schichten des Basis-Modells freigeben
103
- base_model.trainable = True
104
- for layer in base_model.layers[:165]:
105
- layer.trainable = False
106
-
107
- # Modell neu kompilieren mit Adam für Fine-Tuning
108
- model.compile(optimizer=Adam(learning_rate=0.0001), loss='categorical_crossentropy', metrics=['accuracy'])
109
-
110
- # Fine-Tuning trainieren mit Callbacks
111
- fine_tune_history = model.fit(
112
- train_generator,
113
- steps_per_epoch=train_generator.samples // train_generator.batch_size,
114
- validation_data=validation_generator,
115
- validation_steps=validation_generator.samples // validation_generator.batch_size,
116
- epochs=10,
117
- callbacks=callbacks)
118
-
119
- # Bilder und deren Labels aus dem Validierungsdatengenerator holen
120
- class_indices = train_generator.class_indices
121
- class_names = list(class_indices.keys())
122
-
123
- for images, labels in validation_generator:
124
- # Nur die ersten 5 Bilder verwenden
125
- for i in range(5):
126
- img = images[i]
127
- label = labels[i]
128
- plt.imshow(img)
129
-
130
- # Vorhersage erstellen
131
- img_array = np.expand_dims(img, axis=0) # Modell erwartet einen Batch
132
- predictions = model.predict(img_array)
133
- predicted_class_index = np.argmax(predictions, axis=1)[0]
134
- predicted_class_name = class_names[predicted_class_index]
135
-
136
- # Tatsächliche Klasse herausfinden
137
- true_class_index = np.argmax(label)
138
- true_class_name = class_names[true_class_index]
139
-
140
- plt.title(f'Vorhergesagt: {predicted_class_name}, Wahr: {true_class_name}')
141
- plt.show()
142
-
143
- break # Nur die erste Charge von Bildern verwenden
 
1
+ import gradio as gr
2
+ import tensorflow as tf
3
+ from PIL import Image
4
  import numpy as np
5
+
6
+ # Load the pre-trained Pokémon model
7
+ model_path = "pokemon_classifier_model.keras"
8
+ model = tf.keras.models.load_model(model_path)
9
+
10
+ # Define the Pokémon classes
11
+ classes = ['Doduo', 'Geodude', 'Zubat'] # Adjust these as per your model's classes
12
+
13
+ # Define the image classification function
14
+ def classify_image(image):
15
+ try:
16
+ # Ensure the image is in RGB and normalize it
17
+ if image.ndim == 2: # Check if the image is grayscale
18
+ image = np.stack((image,)*3, axis=-1) # Convert grayscale to RGB by repeating the gray channel
19
+ elif image.shape[2] == 4: # Check if the image has an alpha channel
20
+ image = image[:, :, :3] # Drop the alpha channel
21
+ image = Image.fromarray(image.astype('uint8'), 'RGB') # Convert to PIL Image to resize
22
+ image = image.resize((150, 150)) # Resize to match the model's input size
23
+
24
+ image_array = np.array(image) / 255.0 # Convert to array and normalize
25
+ image_array = np.expand_dims(image_array, axis=0) # Add batch dimension
26
+
27
+ # Predict using the model
28
+ prediction = model.predict(image_array)
29
+ predicted_class = classes[np.argmax(prediction)]
30
+ confidence = np.max(prediction)
31
+
32
+ return f"Predicted Pokémon: {predicted_class}, Confidence: {np.round(confidence * 100, 2)}%"
33
+ except Exception as e:
34
+ return str(e) # Return the error message if something goes wrong
35
+
36
+ # Create Gradio interface
37
+ input_image = gr.Image() # Using Gradio's Image component correctly
38
+ output_label = gr.Label()
39
+
40
+ interface = gr.Interface(fn=classify_image,
41
+ inputs=input_image,
42
+ outputs=output_label,
43
+ examples=["pokemon-images-first-generation17000-files/pokemon/Bulbasaur/00000000.png", "pokemon-images-first-generation17000-files/pokemon/Dragonite/00000000.jpg", "pokemon-images-first-generation17000-files/pokemon/Golbat/00000000.png"],
44
+ description="Upload an image of a Pokémon to classify!")
45
+
46
+ interface.launch()