Update realtime_Lora.py
Browse files- realtime_Lora.py +135 -0
realtime_Lora.py
CHANGED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
from transformers import (
|
5 |
+
AutomaticSpeechRecognitionPipeline,
|
6 |
+
WhisperForConditionalGeneration,
|
7 |
+
WhisperTokenizer,
|
8 |
+
WhisperProcessor,
|
9 |
+
)
|
10 |
+
from peft import PeftModel, PeftConfig
|
11 |
+
import speech_recognition as sr
|
12 |
+
from datetime import datetime, timedelta
|
13 |
+
from queue import Queue
|
14 |
+
from tempfile import NamedTemporaryFile
|
15 |
+
from time import sleep
|
16 |
+
from sys import platform
|
17 |
+
|
18 |
+
|
19 |
+
|
20 |
+
def main():
    """Stream microphone audio through a LoRA-adapted Whisper model and
    print a live, continuously-updated transcription to the console.

    Runs until interrupted with Ctrl+C, then prints the final transcript.
    Requires a CUDA device, a working microphone, and network access to
    download the model on first run.
    """
    # Default configuration values.
    peft_model_id = "DuyTa/Vietnamese_ASR"
    language = "Vietnamese"
    task = "transcribe"
    default_energy_threshold = 900
    default_record_timeout = 0.6
    default_phrase_timeout = 3

    # The last time a recording was retrieved from the queue.
    phrase_time = None
    # Raw audio bytes accumulated for the in-progress phrase.
    last_sample = bytes()
    # Thread-safe queue for passing data from the threaded recording callback.
    data_queue = Queue()
    # SpeechRecognizer records the audio; it can detect when speech ends.
    recorder = sr.Recognizer()
    recorder.energy_threshold = default_energy_threshold
    # Dynamic energy compensation can lower the threshold so far that the
    # recognizer never stops recording, so it stays disabled.
    recorder.dynamic_energy_threshold = False

    # 16 kHz matches Whisper's expected input sample rate.
    source = sr.Microphone(sample_rate=16000)

    # Load / download the base Whisper model and apply the LoRA adapter.
    peft_config = PeftConfig.from_pretrained(peft_model_id)
    model = WhisperForConditionalGeneration.from_pretrained(
        peft_config.base_model_name_or_path
    )
    model = PeftModel.from_pretrained(model, peft_model_id)
    model.to("cuda:0")

    # BUG FIX: the processor must exist before the pipeline is built.
    # The original code constructed the pipeline first and referenced
    # `processor` before its definition, raising a NameError at runtime.
    processor = WhisperProcessor.from_pretrained(
        peft_config.base_model_name_or_path, language=language, task=task
    )
    pipe = AutomaticSpeechRecognitionPipeline(
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        batch_size=8,
        torch_dtype=torch.float32,
        device="cuda:0",
    )

    record_timeout = default_record_timeout
    phrase_timeout = default_phrase_timeout

    # NOTE(review): only the generated name is used; the file object itself
    # is discarded, so the temp file is never cleaned up automatically.
    temp_file = NamedTemporaryFile().name
    transcription = ['']

    with source:
        recorder.adjust_for_ambient_noise(source)

    def record_callback(_, audio: sr.AudioData) -> None:
        """Threaded callback invoked when a recording finishes.

        audio: An AudioData containing the recorded bytes.
        """
        # Grab the raw bytes and push them into the thread-safe queue.
        data_queue.put(audio.get_raw_data())

    # Background thread that feeds us raw audio chunks of at most
    # `record_timeout` seconds each.
    recorder.listen_in_background(source, record_callback,
                                  phrase_time_limit=record_timeout)

    print("Model loaded.\n")

    while True:
        try:
            now = datetime.utcnow()
            # Pull raw recorded audio from the queue.
            if not data_queue.empty():
                phrase_complete = False
                # If enough time has passed between recordings, consider the
                # phrase complete and start over with a fresh audio buffer.
                if phrase_time and now - phrase_time > timedelta(seconds=phrase_timeout):
                    last_sample = bytes()
                    phrase_complete = True
                # This is the last time we received new audio data.
                phrase_time = now

                # Concatenate the current audio data with the latest chunks.
                while not data_queue.empty():
                    last_sample += data_queue.get()

                # Use AudioData to convert the raw data to WAV bytes.
                audio_data = sr.AudioData(last_sample, source.SAMPLE_RATE,
                                          source.SAMPLE_WIDTH)
                wav_data = io.BytesIO(audio_data.get_wav_data())

                # Write the WAV bytes to the temporary file.
                with open(temp_file, 'w+b') as f:
                    f.write(wav_data.read())

                # Run the ASR pipeline over the temp file.
                text = pipe(
                    temp_file,
                    chunk_length_s=30,
                    return_timestamps=False,
                    generate_kwargs={"language": language, "task": task},
                )["text"]

                # On a detected pause, start a new transcription line;
                # otherwise keep refining the current one.
                if phrase_complete:
                    transcription.append(text)
                else:
                    transcription[-1] = text

                # Clear the console and reprint the updated transcription.
                os.system('cls' if os.name == 'nt' else 'clear')
                for line in transcription:
                    print(line)
                # Flush stdout.
                print('', end='', flush=True)

            # Infinite loops are bad for processors, must sleep.
            sleep(0.25)
        except KeyboardInterrupt:
            break

    print("\n\nTranscription:")
    for line in transcription:
        print(line)


if __name__ == "__main__":
    main()
|