from collections import deque
import streamlit as st
import torch
from streamlit_player import st_player
from transformers import AutoModelForCTC, Wav2Vec2Processor
from streaming import ffmpeg_stream
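# NOTE: `streaming` appears to be a local helper module (cf. the acknowledged Space below).
# ffmpeg_stream is assumed to yield float32 audio chunks resampled to the target sampling
# rate, each padded with `pad_duration_ms` of neighboring audio; that padding is trimmed
# from the logits in stream_text below.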

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
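# react-player options: autoplay the video and report playback progress every 200 ms via onProgress events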
player_options = {
    "events": ["onProgress"],
    "progress_interval": 200,
    "volume": 1.0,
    "playing": True,
    "loop": False,
    "controls": False,
    "muted": False,
    "config": {"youtube": {"playerVars": {"start": 1}}},
}

st.title("YouTube Video Spanish ASR")
st.write("Acknowledgement: This demo is based on Anton Lozhkov's cool Space: https://huggingface.co/spaces/anton-l/youtube-subs-wav2vec")
# disable rapid fading in and out on `st.code` updates
st.markdown("<style>.element-container{opacity:1 !important}</style>", unsafe_allow_html=True)

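# cache the loaded model across Streamlit reruns; torch Parameters are skipped
# when building the cache key (hashing them is slow and unsupported)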
@st.cache(hash_funcs={torch.nn.parameter.Parameter: lambda _: None})
def load_model(model_path="facebook/wav2vec2-large-robust-ft-swbd-300h"):
    processor = Wav2Vec2Processor.from_pretrained(model_path)
    model = AutoModelForCTC.from_pretrained(model_path).to(device)
    return processor, model

model_path = st.radio(
    "Select a model", (
        "jonatasgrosman/wav2vec2-xls-r-1b-spanish", 
        "jonatasgrosman/wav2vec2-large-xlsr-53-spanish", 
        "patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm", 
        "facebook/wav2vec2-large-xlsr-53-spanish", 
        "glob-asr/xls-r-es-test-lm"
    )
)

processor, model = load_model(model_path)

def stream_text(url, chunk_duration_ms, pad_duration_ms):
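    """Transcribe the audio stream chunk by chunk, yielding finished text.

    The last (possibly partial) word of each chunk is held back and prepended
    to the next chunk, so words split across chunk boundaries stay intact.
    """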
    sampling_rate = processor.feature_extractor.sampling_rate

    # calculate the length of logits to cut from the sides of the output to account for input padding
    output_pad_len = model._get_feat_extract_output_lengths(int(sampling_rate * pad_duration_ms / 1000))

    # define the audio chunk generator
    stream = ffmpeg_stream(url, sampling_rate, chunk_duration_ms=chunk_duration_ms, pad_duration_ms=pad_duration_ms)

    leftover_text = ""  # last word (or fragment) carried over from the previous chunk
    for i, chunk in enumerate(stream):
        input_values = processor(chunk, sampling_rate=sampling_rate, return_tensors="pt").input_values

        with torch.no_grad():
            logits = model(input_values.to(device)).logits[0]
            if i > 0:
                logits = logits[output_pad_len : len(logits) - output_pad_len]
            else:  # don't count padding at the start of the clip
                logits = logits[: len(logits) - output_pad_len]

            predicted_ids = torch.argmax(logits, dim=-1).cpu().tolist()  # greedy CTC decoding
            if processor.decode(predicted_ids).strip():
                leftover_ids = processor.tokenizer.encode(leftover_text)
                # concat the last word (or its part) from the last frame with the current text
                text = processor.decode(leftover_ids + predicted_ids)
                # don't return the last word in case it's just partially recognized
                if " " in text:
                    text, leftover_text = text.rsplit(" ", 1)
                else:
                    leftover_text = text
                    text = ""
                if text:
                    yield text
            else:
                yield leftover_text
                leftover_text = ""
    yield leftover_text  # flush the last held-back word

def main():
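    """Render the Streamlit page: a video player plus a rolling window of subtitles."""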
    state = st.session_state
    st.header("Streaming video ASR from a YouTube link")

    with st.form(key="inputs_form"):
        initial_url = "https://youtu.be/ghOqTkGzX7I?t=60"
        state.youtube_url = st.text_input("YouTube URL", initial_url)

        state.chunk_duration_ms = st.slider("Audio chunk duration (ms)", 2000, 10000, 3000, 100)
        state.pad_duration_ms = st.slider("Padding duration (ms)", 100, 5000, 1000, 100)
        submit_button = st.form_submit_button(label="Submit")

    if submit_button or "asr_stream" not in state:
        # a hack to update the video player on value changes
        state.youtube_url = (
            state.youtube_url.split("&hash=")[0]
            + f"&hash={state.chunk_duration_ms}-{state.pad_duration_ms}"
        )
        state.asr_stream = stream_text(
            state.youtube_url, state.chunk_duration_ms, state.pad_duration_ms
        )
        state.chunks_taken = 0
        state.lines = deque([], maxlen=5)  # limit to the last n lines of subs

    player = st_player(state.youtube_url, **player_options, key="youtube_player")
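    # player.data carries the latest onProgress payload (played fraction, playedSeconds)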

    if "asr_stream" in state and player.data and player.data["played"] < 1.0:
        # check how many seconds have been played; if ahead of what's been processed, emit the next text chunk
        processed_seconds = state.chunks_taken * (state.chunk_duration_ms / 1000)
        if processed_seconds < player.data["playedSeconds"]:
            text = next(state.asr_stream, "")  # default avoids StopIteration once the stream is exhausted
            state.lines.append(text)
            state.chunks_taken += 1
    if "lines" in state:
        # print the lines of subs
        st.code("\n".join(state.lines))


if __name__ == "__main__":
    main()