Spaces:
Running
Running
# Copyright 2019 The TensorFlow Authors All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# ============================================================================== | |
"""Installation test for YAMNet.""" | |
import numpy as np | |
import tensorflow as tf | |
import params | |
import yamnet | |
class YAMNetTest(tf.test.TestCase): | |
_yamnet_graph = None | |
_yamnet = None | |
_yamnet_classes = None | |
def setUpClass(cls): | |
super(YAMNetTest, cls).setUpClass() | |
cls._yamnet_graph = tf.Graph() | |
with cls._yamnet_graph.as_default(): | |
cls._yamnet = yamnet.yamnet_frames_model(params) | |
cls._yamnet.load_weights('yamnet.h5') | |
cls._yamnet_classes = yamnet.class_names('yamnet_class_map.csv') | |
def clip_test(self, waveform, expected_class_name, top_n=10): | |
"""Run the model on the waveform, check that expected class is in top-n.""" | |
with YAMNetTest._yamnet_graph.as_default(): | |
prediction = np.mean(YAMNetTest._yamnet.predict( | |
np.reshape(waveform, [1, -1]), steps=1)[0], axis=0) | |
top_n_class_names = YAMNetTest._yamnet_classes[ | |
np.argsort(prediction)[-top_n:]] | |
self.assertIn(expected_class_name, top_n_class_names) | |
def testZeros(self): | |
self.clip_test( | |
waveform=np.zeros((1, int(3 * params.SAMPLE_RATE))), | |
expected_class_name='Silence') | |
def testRandom(self): | |
np.random.seed(51773) # Ensure repeatability. | |
self.clip_test( | |
waveform=np.random.uniform(-1.0, +1.0, | |
(1, int(3 * params.SAMPLE_RATE))), | |
expected_class_name='White noise') | |
def testSine(self): | |
self.clip_test( | |
waveform=np.reshape( | |
np.sin(2 * np.pi * 440 * np.linspace( | |
0, 3, int(3 *params.SAMPLE_RATE))), | |
[1, -1]), | |
expected_class_name='Sine wave') | |
if __name__ == '__main__': | |
tf.test.main() | |