ggirishg commited on
Commit
c6028a1
·
verified ·
1 Parent(s): 1dbf1e1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -30
app.py CHANGED
@@ -8,7 +8,7 @@ from tensorflow.keras.models import load_model
8
  import tensorflow_hub as hub
9
  import time
10
  import tempfile
11
- import streamlit.components.v1 as components
12
  from io import BytesIO
13
 
14
  # Attempt to set GPU memory growth
@@ -65,27 +65,46 @@ def load_autism_model():
65
 
66
  model = load_autism_model()
67
 
68
- def extract_features(audio_bytes):
69
  sample_rate = 16000
70
- with tempfile.NamedTemporaryFile(delete=True) as temp_audio_file:
71
- temp_audio_file.write(audio_bytes)
72
- temp_audio_file.flush()
73
- array, fs = torchaudio.load(temp_audio_file.name)
74
 
75
  array = np.array(array)
76
  if array.shape[0] > 1:
77
  array = np.mean(array, axis=0, keepdims=True)
78
 
79
- array = array[:, :sample_rate * 10] # Truncate to 10 seconds
 
 
80
  embeddings = m(array)['embedding']
81
  embeddings.shape.assert_is_compatible_with([None, 1024])
82
  embeddings = np.squeeze(np.array(embeddings), axis=0)
83
 
84
  return embeddings
85
 
86
- st.markdown('<span style="color:black; font-size: 48px; font-weight: bold;">Neu</span> <span style="color:black; font-size: 48px; font-weight: bold;">RO:</span> <span style="color:black; font-size: 48px; font-weight: bold;">An Application for Code-Switched Autism Detection in Children</span>', unsafe_allow_html=True)
 
 
 
 
 
87
 
88
- option = st.radio("**Choose an option:**", ["Upload an audio file", "Record audio"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
  def run_prediction(features):
91
  try:
@@ -158,13 +177,21 @@ def run_prediction(features):
158
  unsafe_allow_html=True
159
  )
160
 
 
 
 
 
161
  if option == "Upload an audio file":
162
  uploaded_file = st.file_uploader("Upload an audio file (.wav)", type=["wav"])
163
  if uploaded_file is not None:
164
- start_time = time.time()
165
  with st.spinner('Extracting features...'):
166
- audio_bytes = uploaded_file.read()
167
- features = extract_features(audio_bytes)
 
 
 
 
168
  run_prediction(features)
169
  elapsed_time = round(time.time() - start_time, 2)
170
  st.write(f"Elapsed Time: {elapsed_time} seconds")
@@ -254,15 +281,24 @@ else: # Option is "Record audio"
254
  };
255
  recorder.onstop = () => {
256
  const audioBlob = new Blob(audioChunks, { type: 'audio/wav' });
257
- const audioUrl = URL.createObjectURL(audioBlob);
258
- const a = document.createElement('a');
259
- a.href = audioUrl;
260
- a.download = 'recorded_audio.wav';
261
- document.body.appendChild(a);
262
- a.click();
263
- // Instead of downloading, pass the Blob data to the backend
264
- window.parent.postMessage(audioBlob, "*");
265
-
 
 
 
 
 
 
 
 
 
266
  // Reset
267
  audioChunks = [];
268
  clearInterval(timerInterval);
@@ -292,12 +328,4 @@ else: # Option is "Record audio"
292
  '''
293
  st.components.v1.html(audio_recorder_html, height=600)
294
 
295
- if st.button("Click to Predict"):
296
- try:
297
- # Replace this with the actual method to receive audio bytes from frontend
298
- audio_bytes = st.session_state.get('recorded_audio_bytes')
299
- if audio_bytes:
300
- features = extract_features(audio_bytes)
301
- run_prediction(features)
302
- except Exception as e:
303
- st.error(f"An error occurred: {e}")
 
8
  import tensorflow_hub as hub
9
  import time
10
  import tempfile
11
+ import base64
12
  from io import BytesIO
13
 
14
  # Attempt to set GPU memory growth
 
65
 
66
  model = load_autism_model()
67
 
68
+ def extract_features(path):
69
  sample_rate = 16000
70
+ array, fs = torchaudio.load(path)
 
 
 
71
 
72
  array = np.array(array)
73
  if array.shape[0] > 1:
74
  array = np.mean(array, axis=0, keepdims=True)
75
 
76
+ # Truncate the audio to 10 seconds for reducing memory usage
77
+ array = array[:, :sample_rate * 10]
78
+
79
  embeddings = m(array)['embedding']
80
  embeddings.shape.assert_is_compatible_with([None, 1024])
81
  embeddings = np.squeeze(np.array(embeddings), axis=0)
82
 
83
  return embeddings
84
 
85
+ def save_temp_audio(base64_audio, filename="temp_audio.wav"):
86
+ audio_data = base64.b64decode(base64_audio)
87
+ temp_audio_path = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
88
+ with open(temp_audio_path.name, "wb") as f:
89
+ f.write(audio_data)
90
+ return temp_audio_path.name
91
 
92
+ def handle_audio_upload():
93
+ json_data = st.experimental_get_query_params().get("upload-audio")
94
+ if json_data:
95
+ base64_audio = json_data['audio_data']
96
+ temp_audio_path = save_temp_audio(base64_audio)
97
+
98
+ # Process the uploaded audio file
99
+ command = f'ffmpeg -i {temp_audio_path} -acodec pcm_s16le -ar 16000 -ac 1 ./recorded_audio2.wav'
100
+ result = subprocess.run(command, shell=True, capture_output=True, text=True)
101
+ if result.returncode != 0:
102
+ st.error(f"Error running ffmpeg: {result.stderr}")
103
+ else:
104
+ features = extract_features("./recorded_audio2.wav")
105
+ run_prediction(features)
106
+ os.remove("./recorded_audio2.wav")
107
+ os.remove(temp_audio_path)
108
 
109
  def run_prediction(features):
110
  try:
 
177
  unsafe_allow_html=True
178
  )
179
 
180
+ st.markdown('<span style="color:black; font-size: 48px; font-weight: bold;">Neu</span> <span style="color:black; font-size: 48px; font-weight: bold;">RO:</span> <span style="color:black; font-size: 48px; font-weight: bold;">An Application for Code-Switched Autism Detection in Children</span>', unsafe_allow_html=True)
181
+
182
+ option = st.radio("**Choose an option:**", ["Upload an audio file", "Record audio"])
183
+
184
  if option == "Upload an audio file":
185
  uploaded_file = st.file_uploader("Upload an audio file (.wav)", type=["wav"])
186
  if uploaded_file is not None:
187
+ start_time = time.time() # Record start time
188
  with st.spinner('Extracting features...'):
189
+ # Process the uploaded file
190
+ temp_audio_path = os.path.join(".", "temp_audio.wav")
191
+ with open(temp_audio_path, "wb") as f:
192
+ f.write(uploaded_file.getbuffer())
193
+ features = extract_features(temp_audio_path)
194
+ os.remove(temp_audio_path)
195
  run_prediction(features)
196
  elapsed_time = round(time.time() - start_time, 2)
197
  st.write(f"Elapsed Time: {elapsed_time} seconds")
 
281
  };
282
  recorder.onstop = () => {
283
  const audioBlob = new Blob(audioChunks, { type: 'audio/wav' });
284
+ const reader = new FileReader();
285
+ reader.readAsDataURL(audioBlob);
286
+ reader.onloadend = () => {
287
+ const base64String = reader.result.split(',')[1];
288
+ fetch('/upload-audio', {
289
+ method: 'POST',
290
+ headers: {
291
+ 'Content-Type': 'application/json',
292
+ },
293
+ body: JSON.stringify({ audio_data: base64String }),
294
+ }).then(response => {
295
+ if (response.ok) {
296
+ console.log('Audio uploaded successfully.');
297
+ } else {
298
+ console.error('Audio upload failed.');
299
+ }
300
+ });
301
+ };
302
  // Reset
303
  audioChunks = [];
304
  clearInterval(timerInterval);
 
328
  '''
329
  st.components.v1.html(audio_recorder_html, height=600)
330
 
331
+ handle_audio_upload()