mouadenna commited on
Commit
6bc0e93
1 Parent(s): aa44aa5

Update pages/4_Writing.py

Browse files
Files changed (1) hide show
  1. pages/4_Writing.py +86 -36
pages/4_Writing.py CHANGED
@@ -3,6 +3,8 @@ from streamlit_drawable_canvas import st_canvas
3
  import cv2
4
  from tensorflow.keras.models import load_model
5
  import numpy as np
 
 
6
 
7
 
8
  arabic_chars = ['alef','beh','teh','theh','jeem','hah','khah','dal','thal','reh','zain','seen','sheen',
@@ -34,41 +36,89 @@ def add_logo():
34
  unsafe_allow_html=True,
35
  )
36
  add_logo()
 
 
 
 
37
  def predict_image(image_path, model_path):
38
- model = load_model(model_path)
39
-
40
- img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
41
- img = cv2.resize(img, (32, 32))
42
- img = img.reshape(1, 32, 32, 1)
43
- img = img.astype('float32') / 255.0
44
-
45
- pred = model.predict(img)
46
- predicted_label = arabic_chars[np.argmax(pred)]
47
-
48
- return predicted_label
49
-
50
- canvas_result = st_canvas(
51
- fill_color="rgba(255, 255, 255, 0.3)", # Filled color (white)
52
- stroke_width=30, # Stroke width
53
- stroke_color="#FFFFFF", # Stroke color (white)
54
- background_color="#000000", # Canvas background color (black)
55
- update_streamlit=True,
56
- height=400,
57
- width=400,
58
- drawing_mode="freedraw",
59
- key="canvas",
60
- )
61
-
62
- if st.button("Predict"):
63
- if canvas_result.image_data is not None:
64
- image = canvas_result.image_data
65
- image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
66
- image = cv2.resize(image, (32, 32))
67
- cv2.imwrite("temp_image.png", image)
68
-
69
- model_path = "saved_model.h5" # Replace with the path to your trained model
70
- predicted_label = predict_image("temp_image.png", model_path)
71
-
72
- st.write(f"Predicted Character: {predicted_label}")
 
 
 
 
73
  else:
74
- st.write("Please draw something on the canvas.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import cv2
4
  from tensorflow.keras.models import load_model
5
  import numpy as np
6
+ import random
7
+ import os
8
 
9
 
10
  arabic_chars = ['alef','beh','teh','theh','jeem','hah','khah','dal','thal','reh','zain','seen','sheen',
 
36
  unsafe_allow_html=True,
37
  )
38
  add_logo()
39
+
40
+
41
+
42
+
43
  def predict_image(image_path, model_path):
44
+ try:
45
+ model = load_model(model_path)
46
+ except Exception as e:
47
+ st.error(f"Error loading model: {e}")
48
+ return None
49
+
50
+ try:
51
+ img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
52
+ img = cv2.resize(img, (32, 32))
53
+ img = img.reshape(1, 32, 32, 1)
54
+ img = img.astype('float32') / 255.0
55
+
56
+ pred = model.predict(img)
57
+ predicted_label = arabic_chars[np.argmax(pred)]
58
+ return predicted_label
59
+ except Exception as e:
60
+ st.error(f"Error processing image: {e}")
61
+ return None
62
+
63
+ def get_random_image(folder_path):
64
+ try:
65
+ char = random.choice(arabic_chars)
66
+ image_path = os.path.join(folder_path, f"{char}.png")
67
+ return image_path, char
68
+ except Exception as e:
69
+ st.error(f"Error loading random image: {e}")
70
+ return None, None
71
+
72
+ # Streamlit app
73
+ st.title("Arabic Character Recognition")
74
+
75
+ # Load and display a random image
76
+ folder_path = "arabic letters"
77
+ if 'image_path' not in st.session_state:
78
+ st.session_state.image_path, st.session_state.correct_char = get_random_image(folder_path)
79
+ col1,col2,col3=st.columns([1,1,1])
80
+ with col1:
81
+ if st.session_state.image_path and st.session_state.correct_char:
82
+ st.image(st.session_state.image_path, caption=f"Draw this character: {st.session_state.correct_char}",width=350,)
83
  else:
84
+ st.error("Error loading the random image.")
85
+
86
+ with col2:
87
+ canvas_result = st_canvas(
88
+ fill_color="rgba(255, 255, 255, 0.3)", # Filled color (white)
89
+ stroke_width=19, # Stroke width
90
+ stroke_color="#FFFFFF", # Stroke color (white)
91
+ background_color="#000000", # Canvas background color (black)
92
+ update_streamlit=True,
93
+ height=400,
94
+ width=400,
95
+ drawing_mode="freedraw",
96
+ key="canvas",
97
+ )
98
+ with col3:
99
+ if st.button("Check"):
100
+ if canvas_result.image_data is not None:
101
+ image = canvas_result.image_data
102
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
103
+ image = cv2.resize(image, (32, 32))
104
+
105
+ # Save the image temporarily
106
+ temp_image_path = "temp_image.png"
107
+ cv2.imwrite(temp_image_path, image)
108
+
109
+ # Load and predict using the model
110
+ model_path = "saved_model.h5" # Replace with the path to your trained model
111
+ if os.path.exists(model_path):
112
+ predicted_label = predict_image(temp_image_path, model_path)
113
+ if predicted_label:
114
+ #st.write(f"Predicted Character: {predicted_label}")
115
+ if predicted_label == st.session_state.correct_char:
116
+ st.success("You are correct!")
117
+ st.session_state.image_path, st.session_state.correct_char = get_random_image(folder_path)
118
+ canvas_result.clear_background()
119
+ else:
120
+ st.error("The prediction does not match the displayed character. Try again.")
121
+ else:
122
+ st.error("Model file not found. Please check the model path.")
123
+ else:
124
+ st.write("Please draw something on the canvas.")