Labbeti commited on
Commit
f96230b
·
1 Parent(s): 1348542

Add: Options beam size, min pred size and max pred size to UI.

Browse files
Files changed (1) hide show
  1. app.py +26 -3
app.py CHANGED
@@ -15,6 +15,25 @@ def load_conette(*args, **kwargs) -> CoNeTTEModel:
15
 
16
  def main() -> None:
17
  st.header("CoNeTTE model test")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  audios = st.file_uploader(
19
  "Upload an audio file",
20
  type=["wav", "flac", "mp3", "ogg", "avi"],
@@ -22,13 +41,17 @@ def main() -> None:
22
  )
23
 
24
  if audios is not None and len(audios) > 0:
25
- model = load_conette(model_kwds=dict(device="cpu"))
26
-
27
  for audio in audios:
28
  with NamedTemporaryFile() as temp:
29
  temp.write(audio.getvalue())
30
  fpath = temp.name
31
- outputs = model(fpath)
 
 
 
 
 
 
32
  cand = outputs["cands"][0]
33
 
34
  st.write(f"Output for {audio.name}:")
 
15
 
16
  def main() -> None:
17
  st.header("CoNeTTE model test")
18
+ model = load_conette(model_kwds=dict(device="cpu"))
19
+
20
+ task = st.selectbox("Task embedding input", model.tasks, 0)
21
+ beam_size: int = st.select_slider( # type: ignore
22
+ "Beam size",
23
+ list(range(1, 50)),
24
+ model.config.beam_size,
25
+ )
26
+ min_pred_size: int = st.select_slider( # type: ignore
27
+ "Minimal number of words",
28
+ list(range(1, 50)),
29
+ model.config.min_pred_size,
30
+ )
31
+ max_pred_size: int = st.select_slider( # type: ignore
32
+ "Maximal number of words",
33
+ list(range(1, 50)),
34
+ model.config.max_pred_size,
35
+ )
36
+
37
  audios = st.file_uploader(
38
  "Upload an audio file",
39
  type=["wav", "flac", "mp3", "ogg", "avi"],
 
41
  )
42
 
43
  if audios is not None and len(audios) > 0:
 
 
44
  for audio in audios:
45
  with NamedTemporaryFile() as temp:
46
  temp.write(audio.getvalue())
47
  fpath = temp.name
48
+ outputs = model(
49
+ fpath,
50
+ task=task,
51
+ beam_size=beam_size,
52
+ min_pred_size=min_pred_size,
53
+ max_pred_size=max_pred_size,
54
+ )
55
  cand = outputs["cands"][0]
56
 
57
  st.write(f"Output for {audio.name}:")