Update app.py
app.py
CHANGED
@@ -1,143 +1,46 @@
-import
 import numpy as np

[Removed lines 3-45, which held the remaining imports and the data-directory / ImageDataGenerator setup, are not preserved in this view.]

-validation_generator = validation_datagen.flow_from_directory(
-    base_dir,
-    target_size=(150, 150),
-    batch_size=32,
-    classes=classes,
-    class_mode='categorical',
-    subset='validation')
-
-# Load the ResNet50 base model and adapt it
-base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(150, 150, 3))
-base_model.trainable = False  # Freeze all layers initially
-
-x = base_model.output
-x = GlobalAveragePooling2D()(x)
-x = Dense(1024, activation='relu')(x)
-x = Dropout(0.3)(x)
-x = BatchNormalization()(x)
-predictions = Dense(len(classes), activation='softmax')(x)
-model = Model(inputs=base_model.input, outputs=predictions)
-
-# Set up callbacks
-callbacks = [
-    ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=5, min_lr=0.00001, verbose=1),
-    EarlyStopping(monitor='val_loss', patience=10, verbose=1),
-    ModelCheckpoint('best_model.keras', monitor='val_loss', save_best_only=True, verbose=1)
-]
-
-# Compile the model with the Adam optimizer
-model.compile(optimizer=Adam(learning_rate=0.001), loss='categorical_crossentropy', metrics=['accuracy'])
-
-# Diagnostic output
-print(f"Total training samples: {train_generator.samples}")
-print(f"Total validation samples: {validation_generator.samples}")
-print(f"Training batch size: {train_generator.batch_size}")
-print(f"Validation batch size: {validation_generator.batch_size}")
-
-# Calculate steps per epoch and validation steps
-training_steps_per_epoch = int(np.ceil(train_generator.samples / train_generator.batch_size))
-validation_steps_per_epoch = int(np.ceil(validation_generator.samples / validation_generator.batch_size))
-
-print(f"Calculated training steps per epoch: {training_steps_per_epoch}")
-print(f"Calculated validation steps per epoch: {validation_steps_per_epoch}")
-
-# Model training with updated steps per epoch
-history = model.fit(
-    train_generator,
-    steps_per_epoch=training_steps_per_epoch,
-    validation_data=validation_generator,
-    validation_steps=validation_steps_per_epoch,
-    epochs=15,
-    callbacks=callbacks
-)
-
-# Save the model
-model.save('pokemon_classifier_model.h5')
-
-# Fine-tuning: unfreeze some layers of the base model
-base_model.trainable = True
-for layer in base_model.layers[:165]:
-    layer.trainable = False
-
-# Recompile the model with Adam for fine-tuning
-model.compile(optimizer=Adam(learning_rate=0.0001), loss='categorical_crossentropy', metrics=['accuracy'])
-
-# Run fine-tuning with callbacks
-fine_tune_history = model.fit(
-    train_generator,
-    steps_per_epoch=train_generator.samples // train_generator.batch_size,
-    validation_data=validation_generator,
-    validation_steps=validation_generator.samples // validation_generator.batch_size,
-    epochs=10,
-    callbacks=callbacks)
-
-# Fetch images and their labels from the validation data generator
-class_indices = train_generator.class_indices
-class_names = list(class_indices.keys())
-
-for images, labels in validation_generator:
-    # Use only the first 5 images
-    for i in range(5):
-        img = images[i]
-        label = labels[i]
-        plt.imshow(img)
-
-        # Make a prediction
-        img_array = np.expand_dims(img, axis=0)  # the model expects a batch
-        predictions = model.predict(img_array)
-        predicted_class_index = np.argmax(predictions, axis=1)[0]
-        predicted_class_name = class_names[predicted_class_index]
-
-        # Determine the true class
-        true_class_index = np.argmax(label)
-        true_class_name = class_names[true_class_index]
-
-        plt.title(f'Predicted: {predicted_class_name}, True: {true_class_name}')
-        plt.show()
-
-    break  # Use only the first batch of images
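The deleted training script depends on names defined in the removed lines 1-45 (the Keras imports, base_dir, classes, and the generators behind train_generator and validation_datagen), whose content is not preserved in this view. The following is only a hypothetical sketch of what such a setup commonly looks like; the dataset path, class list, split ratio, and rescaling are illustrative assumptions, not the original values.

# Hypothetical reconstruction of the removed setup; not the original lines 1-45
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import Dense, Dropout, GlobalAveragePooling2D, BatchNormalization
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping, ModelCheckpoint
from tensorflow.keras.preprocessing.image import ImageDataGenerator

base_dir = 'pokemon-images-first-generation17000-files/pokemon'  # assumed dataset location
classes = ['Doduo', 'Geodude', 'Zubat']  # assumed class list, matching the new app.py

# A single directory can serve both subsets via validation_split plus subset='training'/'validation'
train_datagen = ImageDataGenerator(rescale=1. / 255, validation_split=0.2)
validation_datagen = ImageDataGenerator(rescale=1. / 255, validation_split=0.2)

train_generator = train_datagen.flow_from_directory(
    base_dir,
    target_size=(150, 150),
    batch_size=32,
    classes=classes,
    class_mode='categorical',
    subset='training')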
+import gradio as gr
+import tensorflow as tf
+from PIL import Image
 import numpy as np
+
+# Load the pre-trained Pokémon model
+model_path = "pokemon_classifier_model.keras"
+model = tf.keras.models.load_model(model_path)
+
+# Define the Pokémon classes
+classes = ['Doduo', 'Geodude', 'Zubat']  # Adjust these as per your model's classes
+
+# Define the image classification function
+def classify_image(image):
+    try:
+        # Ensure the image is in RGB and normalize it
+        if image.ndim == 2:  # Check if the image is grayscale
+            image = np.stack((image,) * 3, axis=-1)  # Convert grayscale to RGB by repeating the gray channel
+        elif image.shape[2] == 4:  # Check if the image has an alpha channel
+            image = image[:, :, :3]  # Drop the alpha channel
+        image = Image.fromarray(image.astype('uint8'), 'RGB')  # Convert to a PIL image for resizing
+        image = image.resize((150, 150))  # Resize to match the model's input size
+
+        image_array = np.array(image) / 255.0  # Convert to array and normalize
+        image_array = np.expand_dims(image_array, axis=0)  # Add batch dimension
+
+        # Predict using the model
+        prediction = model.predict(image_array)
+        predicted_class = classes[np.argmax(prediction)]
+        confidence = np.max(prediction)
+
+        return f"Predicted Pokémon: {predicted_class}, Confidence: {np.round(confidence * 100, 2)}%"
+    except Exception as e:
+        return str(e)  # Return the error message if something goes wrong
+
+# Create Gradio interface
+input_image = gr.Image()  # Gradio's Image component passes the upload to the function as a NumPy array
+output_label = gr.Label()
+
+interface = gr.Interface(fn=classify_image,
+                         inputs=input_image,
+                         outputs=output_label,
+                         examples=["pokemon-images-first-generation17000-files/pokemon/Bulbasaur/00000000.png", "pokemon-images-first-generation17000-files/pokemon/Dragonite/00000000.jpg", "pokemon-images-first-generation17000-files/pokemon/Golbat/00000000.png"],
+                         description="Upload an image of a Pokémon to classify!")
+
+interface.launch()
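Because gr.Image hands the uploaded picture to the function as a NumPy array by default, classify_image can also be sanity-checked outside the Space. A minimal sketch, reusing one of the example paths wired into the interface (run from the repository root so the relative path resolves):

import numpy as np
from PIL import Image

# Load a test image the same way Gradio's Image component would deliver it: as a NumPy array
test_path = "pokemon-images-first-generation17000-files/pokemon/Bulbasaur/00000000.png"
img = np.array(Image.open(test_path))

# classify_image handles grayscale/RGBA input, resizes to 150x150, and returns a text summary
print(classify_image(img))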