Hmjz100 commited on
Commit
1701e7a
·
verified ·
1 Parent(s): c974b7f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -1
app.py CHANGED
@@ -217,11 +217,13 @@ class InferenceModel(object):
217
  def predict_tokens(self, batch, seed=0):
218
  """从预处理的数据集批次中预测 tokens。"""
219
  print(f"[{current_time()}] 运行:从预处理数据集中预测音符序列")
220
- prediction, _ = self._predict_fn(self._train_state.params, batch, jax.random.PRNGKey(seed))
 
221
  return self.vocabulary.decode_tf(prediction).numpy()
222
 
223
  def __call__(self, audio):
224
  """从音频样本推断出音符序列。
 
225
  参数:
226
  audio:16kHz 的单个音频样本的 1 维 numpy 数组。
227
  返回:
 
217
  def predict_tokens(self, batch, seed=0):
218
  """从预处理的数据集批次中预测 tokens。"""
219
  print(f"[{current_time()}] 运行:从预处理数据集中预测音符序列")
220
+ prediction, _ = self._predict_fn(
221
+ self._train_state.params, batch, jax.random.PRNGKey(seed))
222
  return self.vocabulary.decode_tf(prediction).numpy()
223
 
224
  def __call__(self, audio):
225
  """从音频样本推断出音符序列。
226
+
227
  参数:
228
  audio:16kHz 的单个音频样本的 1 维 numpy 数组。
229
  返回: