DuyTa committed on
Commit
a9c93ba
·
1 Parent(s): 6777e61

Update realtime_Lora.py

Browse files
Files changed (1) hide show
  1. realtime_Lora.py +135 -0
realtime_Lora.py CHANGED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import os
3
+ import torch
4
+ from transformers import (
5
+ AutomaticSpeechRecognitionPipeline,
6
+ WhisperForConditionalGeneration,
7
+ WhisperTokenizer,
8
+ WhisperProcessor,
9
+ )
10
+ from peft import PeftModel, PeftConfig
11
+ import speech_recognition as sr
12
+ from datetime import datetime, timedelta
13
+ from queue import Queue
14
+ from tempfile import NamedTemporaryFile
15
+ from time import sleep
16
+ from sys import platform
17
+
18
+
19
+
20
+ def main():
21
+ # Set your default configuration values here
22
+ peft_model_id = "DuyTa/Vietnamese_ASR"
23
+ language = "Vietnamese"
24
+ task = "transcribe"
25
+ default_energy_threshold = 900
26
+ default_record_timeout = 0.6
27
+ default_phrase_timeout = 3
28
+
29
+ # The last time a recording was retrieved from the queue.
30
+ phrase_time = None
31
+ # Current raw audio bytes.
32
+ last_sample = bytes()
33
+ # Thread safe Queue for passing data from the threaded recording callback.
34
+ data_queue = Queue()
35
+ # We use SpeechRecognizer to record our audio because it has a nice feature where it can detect when speech ends.
36
+ recorder = sr.Recognizer()
37
+ recorder.energy_threshold = default_energy_threshold
38
+ # Definitely do this, dynamic energy compensation lowers the energy threshold dramatically to a point where the SpeechRecognizer never stops recording.
39
+ recorder.dynamic_energy_threshold = False
40
+
41
+ source = sr.Microphone(sample_rate=16000) # Use default microphone source for non-Linux platforms
42
+
43
+ # Load / Download model
44
+ peft_config = PeftConfig.from_pretrained(peft_model_id)
45
+ model = WhisperForConditionalGeneration.from_pretrained(
46
+ peft_config.base_model_name_or_path
47
+ )
48
+ model = PeftModel.from_pretrained(model, peft_model_id)
49
+
50
+ model.to("cuda:0")
51
+ pipe = AutomaticSpeechRecognitionPipeline(model=model, tokenizer=processor.tokenizer, feature_extractor=processor.feature_extractor, batch_size=8, torch_dtype=torch.float32, device="cuda:0")
52
+
53
+ processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path, language=language, task=task)
54
+
55
+
56
+ record_timeout = default_record_timeout
57
+ phrase_timeout = default_phrase_timeout
58
+
59
+ temp_file = NamedTemporaryFile().name
60
+ transcription = ['']
61
+
62
+ with source:
63
+ recorder.adjust_for_ambient_noise(source)
64
+
65
+ def record_callback(_, audio:sr.AudioData) -> None:
66
+ """
67
+ Threaded callback function to receive audio data when recordings finish.
68
+ audio: An AudioData containing the recorded bytes.
69
+ """
70
+ # Grab the raw bytes and push it into the thread safe queue.
71
+ data = audio.get_raw_data()
72
+ data_queue.put(data)
73
+
74
+ # Create a background thread that will pass us raw audio bytes.
75
+ # We could do this manually but SpeechRecognizer provides a nice helper.
76
+ recorder.listen_in_background(source, record_callback, phrase_time_limit=record_timeout)
77
+
78
+ print("Model loaded.\n")
79
+
80
+ while True:
81
+ try:
82
+ now = datetime.utcnow()
83
+ # Pull raw recorded audio from the queue.
84
+ if not data_queue.empty():
85
+ phrase_complete = False
86
+ # If enough time has passed between recordings, consider the phrase complete.
87
+ # Clear the current working audio buffer to start over with the new data.
88
+ if phrase_time and now - phrase_time > timedelta(seconds=phrase_timeout):
89
+ last_sample = bytes()
90
+ phrase_complete = True
91
+ # This is the last time we received new audio data from the queue.
92
+ phrase_time = now
93
+
94
+ # Concatenate our current audio data with the latest audio data.
95
+ while not data_queue.empty():
96
+ data = data_queue.get()
97
+ last_sample += data
98
+
99
+ # Use AudioData to convert the raw data to wav data.
100
+ audio_data = sr.AudioData(last_sample, source.SAMPLE_RATE, source.SAMPLE_WIDTH)
101
+ wav_data = io.BytesIO(audio_data.get_wav_data())
102
+
103
+ # Write wav data to the temporary file as bytes.
104
+ with open(temp_file, 'w+b') as f:
105
+ f.write(wav_data.read())
106
+
107
+ # Read the transcription.
108
+ text = pipe(temp_file, chunk_length_s=30, return_timestamps=False, generate_kwargs={"language": language, "task": task})["text"]
109
+
110
+
111
+ # If we detected a pause between recordings, add a new item to our transcription.
112
+ # Otherwise edit the existing one.
113
+ if phrase_complete:
114
+ transcription.append(text)
115
+ else:
116
+ transcription[-1] = text
117
+
118
+ # Clear the console to reprint the updated transcription.
119
+ os.system('cls' if os.name == 'nt' else 'clear')
120
+ for line in transcription:
121
+ print(line)
122
+ # Flush stdout.
123
+ print('', end='', flush=True)
124
+
125
+ # Infinite loops are bad for processors, must sleep.
126
+ sleep(0.25)
127
+ except KeyboardInterrupt:
128
+ break
129
+
130
+ print("\n\nTranscription:")
131
+ for line in transcription:
132
+ print(line)
133
+
134
# Script entry point: start the live transcription loop only when run
# directly, not when imported as a module.
if __name__ == "__main__":
    main()