Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -217,11 +217,13 @@ class InferenceModel(object):
|
|
217 |
def predict_tokens(self, batch, seed=0):
|
218 |
"""从预处理的数据集批次中预测 tokens。"""
|
219 |
print(f"[{current_time()}] 运行:从预处理数据集中预测音符序列")
|
220 |
-
prediction, _ = self._predict_fn(
|
|
|
221 |
return self.vocabulary.decode_tf(prediction).numpy()
|
222 |
|
223 |
def __call__(self, audio):
|
224 |
"""从音频样本推断出音符序列。
|
|
|
225 |
参数:
|
226 |
audio:16kHz 的单个音频样本的 1 维 numpy 数组。
|
227 |
返回:
|
|
|
217 |
def predict_tokens(self, batch, seed=0):
|
218 |
"""从预处理的数据集批次中预测 tokens。"""
|
219 |
print(f"[{current_time()}] 运行:从预处理数据集中预测音符序列")
|
220 |
+
prediction, _ = self._predict_fn(
|
221 |
+
self._train_state.params, batch, jax.random.PRNGKey(seed))
|
222 |
return self.vocabulary.decode_tf(prediction).numpy()
|
223 |
|
224 |
def __call__(self, audio):
|
225 |
"""从音频样本推断出音符序列。
|
226 |
+
|
227 |
参数:
|
228 |
audio:16kHz 的单个音频样本的 1 维 numpy 数组。
|
229 |
返回:
|