Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -8,7 +8,7 @@ from tensorflow.keras.models import load_model
|
|
8 |
import tensorflow_hub as hub
|
9 |
import time
|
10 |
import tempfile
|
11 |
-
import
|
12 |
from io import BytesIO
|
13 |
|
14 |
# Attempt to set GPU memory growth
|
@@ -65,27 +65,46 @@ def load_autism_model():
|
|
65 |
|
66 |
model = load_autism_model()
|
67 |
|
68 |
-
def extract_features(
|
69 |
sample_rate = 16000
|
70 |
-
|
71 |
-
temp_audio_file.write(audio_bytes)
|
72 |
-
temp_audio_file.flush()
|
73 |
-
array, fs = torchaudio.load(temp_audio_file.name)
|
74 |
|
75 |
array = np.array(array)
|
76 |
if array.shape[0] > 1:
|
77 |
array = np.mean(array, axis=0, keepdims=True)
|
78 |
|
79 |
-
|
|
|
|
|
80 |
embeddings = m(array)['embedding']
|
81 |
embeddings.shape.assert_is_compatible_with([None, 1024])
|
82 |
embeddings = np.squeeze(np.array(embeddings), axis=0)
|
83 |
|
84 |
return embeddings
|
85 |
|
86 |
-
|
|
|
|
|
|
|
|
|
|
|
87 |
|
88 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
|
90 |
def run_prediction(features):
|
91 |
try:
|
@@ -158,13 +177,21 @@ def run_prediction(features):
|
|
158 |
unsafe_allow_html=True
|
159 |
)
|
160 |
|
|
|
|
|
|
|
|
|
161 |
if option == "Upload an audio file":
|
162 |
uploaded_file = st.file_uploader("Upload an audio file (.wav)", type=["wav"])
|
163 |
if uploaded_file is not None:
|
164 |
-
start_time = time.time()
|
165 |
with st.spinner('Extracting features...'):
|
166 |
-
|
167 |
-
|
|
|
|
|
|
|
|
|
168 |
run_prediction(features)
|
169 |
elapsed_time = round(time.time() - start_time, 2)
|
170 |
st.write(f"Elapsed Time: {elapsed_time} seconds")
|
@@ -254,15 +281,24 @@ else: # Option is "Record audio"
|
|
254 |
};
|
255 |
recorder.onstop = () => {
|
256 |
const audioBlob = new Blob(audioChunks, { type: 'audio/wav' });
|
257 |
-
const
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
266 |
// Reset
|
267 |
audioChunks = [];
|
268 |
clearInterval(timerInterval);
|
@@ -292,12 +328,4 @@ else: # Option is "Record audio"
|
|
292 |
'''
|
293 |
st.components.v1.html(audio_recorder_html, height=600)
|
294 |
|
295 |
-
|
296 |
-
try:
|
297 |
-
# Replace this with the actual method to receive audio bytes from frontend
|
298 |
-
audio_bytes = st.session_state.get('recorded_audio_bytes')
|
299 |
-
if audio_bytes:
|
300 |
-
features = extract_features(audio_bytes)
|
301 |
-
run_prediction(features)
|
302 |
-
except Exception as e:
|
303 |
-
st.error(f"An error occurred: {e}")
|
|
|
8 |
import tensorflow_hub as hub
|
9 |
import time
|
10 |
import tempfile
|
11 |
+
import base64
|
12 |
from io import BytesIO
|
13 |
|
14 |
# Attempt to set GPU memory growth
|
|
|
65 |
|
66 |
model = load_autism_model()
|
67 |
|
68 |
+
def extract_features(path):
    """Load an audio file and return its embedding from the hub model ``m``.

    The waveform is read with torchaudio, resampled to ``sample_rate`` when
    the file uses a different rate, averaged down to a single channel,
    truncated to the first 10 seconds, and passed to the module-level
    model ``m``.

    Args:
        path: Location of the audio file, handed to ``torchaudio.load``.

    Returns:
        The ``'embedding'`` output of ``m`` converted to a NumPy array with
        axis 0 squeezed away. The embedding shape is asserted to be
        compatible with ``[None, 1024]`` before squeezing.
    """
    sample_rate = 16000
    waveform, fs = torchaudio.load(path)

    # The truncation below counts samples at `sample_rate`, so the signal
    # must actually be at that rate. Previously `fs` was read and ignored,
    # which made "10 seconds" wrong for any file not recorded at 16 kHz.
    if fs != sample_rate:
        waveform = torchaudio.functional.resample(waveform, fs, sample_rate)

    array = np.array(waveform)
    if array.shape[0] > 1:
        # Collapse multi-channel audio to one channel, keeping the 2-D shape.
        array = np.mean(array, axis=0, keepdims=True)

    # Truncate the audio to 10 seconds for reducing memory usage
    array = array[:, :sample_rate * 10]

    embeddings = m(array)['embedding']
    embeddings.shape.assert_is_compatible_with([None, 1024])
    embeddings = np.squeeze(np.array(embeddings), axis=0)

    return embeddings
|
84 |
|
85 |
+
def save_temp_audio(base64_audio, filename="temp_audio.wav"):
    """Decode base64-encoded audio and write it to a new temporary file.

    Args:
        base64_audio: Base64-encoded audio payload.
        filename: Unused. Kept only so existing callers that pass it keep
            working; the actual file name is generated by ``tempfile``.

    Returns:
        Path of the temporary ``.wav`` file. The file is created with
        ``delete=False``, so the caller is responsible for removing it.
    """
    audio_data = base64.b64decode(base64_audio)
    # Write through the temp-file handle and let the context manager close
    # it; the earlier version left this handle open and re-opened the file
    # by name, leaking a descriptor on every call.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
        tmp.write(audio_data)
    return tmp.name
|
91 |
|
92 |
+
def handle_audio_upload():
    """Process a recording delivered via the ``upload-audio`` query parameter.

    Reads the ``upload-audio`` query parameter, extracts its ``audio_data``
    field (base64 audio), stores it in a temporary file, converts it with
    ffmpeg to 16 kHz mono 16-bit PCM, then runs ``extract_features`` and
    ``run_prediction`` on the converted file. Problems are reported with
    ``st.error``. Both intermediate files are removed afterwards, whether or
    not processing succeeded.

    Returns:
        None. Does nothing when the query parameter is absent or empty.
    """
    import json

    raw = st.experimental_get_query_params().get("upload-audio")
    if not raw:
        return

    # Streamlit returns each query parameter as a list of strings, so the
    # value must be unwrapped and JSON-decoded before 'audio_data' can be
    # looked up (indexing the list with a string key raised TypeError).
    payload = raw[0] if isinstance(raw, (list, tuple)) else raw
    if isinstance(payload, str):
        try:
            payload = json.loads(payload)
        except ValueError:
            st.error("Uploaded audio payload is not valid JSON.")
            return
    base64_audio = payload.get("audio_data") if isinstance(payload, dict) else None
    if not base64_audio:
        st.error("Uploaded audio payload has no 'audio_data' field.")
        return

    try:
        temp_audio_path = save_temp_audio(base64_audio)
    except ValueError as exc:  # binascii.Error (bad base64) is a ValueError
        st.error(f"Could not decode uploaded audio: {exc}")
        return

    converted_path = "./recorded_audio2.wav"
    # Argument list instead of a shell string: no quoting problems and no
    # shell involved. "-y" lets ffmpeg overwrite a stale output file left
    # behind by an earlier interrupted run instead of failing.
    command = [
        "ffmpeg", "-y",
        "-i", temp_audio_path,
        "-acodec", "pcm_s16le",
        "-ar", "16000",
        "-ac", "1",
        converted_path,
    ]
    try:
        # Process the uploaded audio file
        try:
            result = subprocess.run(command, capture_output=True, text=True)
        except FileNotFoundError as exc:
            # Without a shell, a missing ffmpeg binary raises instead of
            # yielding a non-zero return code; report it the same way.
            st.error(f"Error running ffmpeg: {exc}")
            return

        if result.returncode != 0:
            st.error(f"Error running ffmpeg: {result.stderr}")
        else:
            features = extract_features(converted_path)
            run_prediction(features)
    finally:
        # Always clean up, even when conversion or prediction fails.
        for leftover in (converted_path, temp_audio_path):
            if os.path.exists(leftover):
                os.remove(leftover)
|
108 |
|
109 |
def run_prediction(features):
|
110 |
try:
|
|
|
177 |
unsafe_allow_html=True
|
178 |
)
|
179 |
|
180 |
+
# Page title.
st.markdown('<span style="color:black; font-size: 48px; font-weight: bold;">Neu</span> <span style="color:black; font-size: 48px; font-weight: bold;">RO:</span> <span style="color:black; font-size: 48px; font-weight: bold;">An Application for Code-Switched Autism Detection in Children</span>', unsafe_allow_html=True)

# Input mode selector; the "Record audio" branch is handled further below.
option = st.radio("**Choose an option:**", ["Upload an audio file", "Record audio"])

# NOTE(review): block nesting was reconstructed from a flattened diff view —
# confirm that run_prediction belongs after (not inside) the spinner.
if option == "Upload an audio file":
    uploaded_file = st.file_uploader("Upload an audio file (.wav)", type=["wav"])
    if uploaded_file is not None:
        start_time = time.time()  # Record start time
        with st.spinner('Extracting features...'):
            # Process the uploaded file. Each upload gets its own temporary
            # file: a fixed "./temp_audio.wav" would be shared by every
            # session, so concurrent uploads could overwrite or delete each
            # other's data.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
                tmp.write(uploaded_file.getbuffer())
            temp_audio_path = tmp.name
            try:
                features = extract_features(temp_audio_path)
            finally:
                # Remove the temp file even if feature extraction fails.
                os.remove(temp_audio_path)
        run_prediction(features)
        elapsed_time = round(time.time() - start_time, 2)
        st.write(f"Elapsed Time: {elapsed_time} seconds")
|
|
|
281 |
};
|
282 |
recorder.onstop = () => {
|
283 |
const audioBlob = new Blob(audioChunks, { type: 'audio/wav' });
|
284 |
+
const reader = new FileReader();
|
285 |
+
reader.readAsDataURL(audioBlob);
|
286 |
+
reader.onloadend = () => {
|
287 |
+
const base64String = reader.result.split(',')[1];
|
288 |
+
fetch('/upload-audio', {
|
289 |
+
method: 'POST',
|
290 |
+
headers: {
|
291 |
+
'Content-Type': 'application/json',
|
292 |
+
},
|
293 |
+
body: JSON.stringify({ audio_data: base64String }),
|
294 |
+
}).then(response => {
|
295 |
+
if (response.ok) {
|
296 |
+
console.log('Audio uploaded successfully.');
|
297 |
+
} else {
|
298 |
+
console.error('Audio upload failed.');
|
299 |
+
}
|
300 |
+
});
|
301 |
+
};
|
302 |
// Reset
|
303 |
audioChunks = [];
|
304 |
clearInterval(timerInterval);
|
|
|
328 |
'''
|
329 |
st.components.v1.html(audio_recorder_html, height=600)
|
330 |
|
331 |
+
handle_audio_upload()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|