import gradio as gr
import torch
from transformers import (
    AutoModelForCTC,
    Wav2Vec2Processor,
    AutoProcessor,
    WhisperProcessor,
    WhisperForConditionalGeneration,
    TextStreamer,
)
from unsloth import FastLanguageModel
import numpy as np
import librosa
from scipy.signal import butter, sosfilt

# Initialize device
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

def apply_filter(audio_signal, sr, filter_type, cutoff_freq, slope=2, gain=0):
    """
    Apply a low-pass, high-pass, notch, or high-shelf filter to an audio signal.

    For the notch filter, cutoff_freq must be a two-element [low, high] list of band
    edges in Hz; the other filter types take a single cutoff frequency in Hz.
    """
    nyquist = sr / 2.0
    if filter_type == "lowpass":
        sos = butter(slope, cutoff_freq / nyquist, btype="low", output="sos")
    elif filter_type == "highpass":
        sos = butter(slope, cutoff_freq / nyquist, btype="high", output="sos")
    elif filter_type == "notch":
        sos = butter(slope, [cutoff_freq[0] / nyquist, cutoff_freq[1] / nyquist], btype="bandstop", output="sos")
    elif filter_type == "highshelf":
        # Biquad high-shelf coefficients, normalized by a0 and packed into a
        # single second-order section for sosfilt
        gain_linear = 10 ** (gain / 20.0)
        omega = 2 * np.pi * cutoff_freq / sr
        alpha = np.sin(omega) / (2 * slope)
        A = gain_linear
        b0 = A * ((A + 1) + (A - 1) * np.cos(omega) + 2 * np.sqrt(A) * alpha)
        b1 = -2 * A * ((A - 1) + (A + 1) * np.cos(omega))
        b2 = A * ((A + 1) + (A - 1) * np.cos(omega) - 2 * np.sqrt(A) * alpha)
        a0 = (A + 1) - (A - 1) * np.cos(omega) + 2 * np.sqrt(A) * alpha
        a1 = 2 * ((A - 1) - (A + 1) * np.cos(omega))
        a2 = (A + 1) - (A - 1) * np.cos(omega) - 2 * np.sqrt(A) * alpha
        b = np.array([b0, b1, b2]) / a0
        a = np.array([a0, a1, a2]) / a0
        sos = np.array([[b[0], b[1], b[2], 1, a[1], a[2]]])
    else:
        raise ValueError("Invalid filter type.")
    return sosfilt(sos, audio_signal)
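
# Minimal usage sketch (illustrative values only): a 4th-order band-stop around mains hum.
# Note that the notch type takes a [low, high] pair, unlike the single-cutoff filter types.
#     hum_free = apply_filter(audio, 16000, "notch", [45, 65], slope=4)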

def process_audio_filters(audio_signal, sr):
    """
    Apply a series of filters to clean up the audio.
    """
    # Apply high-pass filter to remove low-frequency noise
    audio_signal = apply_filter(audio_signal, sr, "highpass", 80)
    # Apply low-pass filter to remove high-frequency noise
    audio_signal = apply_filter(audio_signal, sr, "lowpass", 8000)
    # Apply notch (band-stop) filter to remove power-line interference (50/60 Hz)
    audio_signal = apply_filter(audio_signal, sr, "notch", [45, 65])
    # Apply high-shelf filter to boost high frequencies for clarity
    audio_signal = apply_filter(audio_signal, sr, "highshelf", 3000, slope=1, gain=3)
    return audio_signal
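
# A minimal end-to-end sketch of the cleanup chain (hypothetical file path), matching the
# 16 kHz mono audio used elsewhere in this app:
#     y, sr = librosa.load("sample.wav", sr=16000)
#     y_clean = process_audio_filters(y, sr)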

class ModelManager:
    """Lazily loads and caches the ASR models and the LLM so each is created only once."""

    def __init__(self):
        self.asr_models = {}
        self.llm_model = None
        self.llm_tokenizer = None

    def load_wav2vec2_base(self):
        model = AutoModelForCTC.from_pretrained("kabir259/w2v2-base_kabir").to(DEVICE)
        processor = Wav2Vec2Processor.from_pretrained("kabir259/w2v2-base_kabir")
        return model, processor

    def load_wav2vec2_bert(self):
        model = AutoModelForCTC.from_pretrained("Kabir259/w2v2-BERT_kabir").to(DEVICE)
        processor = AutoProcessor.from_pretrained("Kabir259/w2v2-BERT_kabir")
        return model, processor

    def load_whisper_small(self):
        model = WhisperForConditionalGeneration.from_pretrained("Kabir259/whisper-small_kabir").to(DEVICE)
        processor = WhisperProcessor.from_pretrained("Kabir259/whisper-small_kabir")
        model.generation_config.task = "transcribe"
        return model, processor

    def load_qwen2(self):
        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name="Kabir259/QWEN2-Medical",
            max_seq_length=512,
            dtype=torch.float16,
            load_in_4bit=True,
        )
        FastLanguageModel.for_inference(model)
        return model, tokenizer

    def get_asr_model(self, model_name):
        # Load the requested ASR model on first use, then serve it from the cache
        if model_name not in self.asr_models:
            if model_name == "wav2vec2-base":
                self.asr_models[model_name] = self.load_wav2vec2_base()
            elif model_name == "wav2vec2-BERT":
                self.asr_models[model_name] = self.load_wav2vec2_bert()
            elif model_name == "whisper-small":
                self.asr_models[model_name] = self.load_whisper_small()
        return self.asr_models[model_name]

    def get_llm_model(self):
        if self.llm_model is None:
            self.llm_model, self.llm_tokenizer = self.load_qwen2()
        return self.llm_model, self.llm_tokenizer

def process_audio(audio_path, asr_model_name, model_manager):
    model, processor = model_manager.get_asr_model(asr_model_name)
    # Load audio at 16 kHz, the sampling rate all three ASR models expect
    audio, rate = librosa.load(audio_path, sr=16000)
    # Clean up the recording before transcription
    filtered_audio = process_audio_filters(audio, rate)
    if asr_model_name in ["wav2vec2-base", "wav2vec2-BERT"]:
        # CTC decoding: argmax over the logits at every frame. Unpacking the processor
        # output with ** keeps this working for both models, since the wav2vec2-base
        # processor returns input_values while the w2v2-BERT processor returns input_features.
        inputs = processor(filtered_audio, sampling_rate=16000, return_tensors="pt").to(DEVICE)
        with torch.no_grad():
            logits = model(**inputs).logits
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = processor.batch_decode(predicted_ids)[0]
    else:  # whisper model: encoder-decoder, so decode with generate()
        input_features = processor(filtered_audio, sampling_rate=16000, return_tensors="pt").input_features.to(DEVICE)
        with torch.no_grad():
            predicted_ids = model.generate(input_features)
        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
    return transcription

def generate_llm_response(text, model_manager):
    model, tokenizer = model_manager.get_llm_model()
    alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
Provide medical advice for the following condition or symptom

### Input:
{0}

### Response:
"""
    inputs = tokenizer(
        [alpaca_prompt.format(text)],
        return_tensors="pt",
    ).to(DEVICE)
    # Stream tokens to the console while generating
    text_streamer = TextStreamer(tokenizer)
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            streamer=text_streamer,
            max_new_tokens=64,
        )
    # generate() returns the prompt followed by the new tokens; keep only the
    # generated answer so the UI shows just the advice
    response = tokenizer.decode(output_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
    return response
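
# Usage sketch (illustrative input string, not taken from the app):
#     advice = generate_llm_response("persistent dry cough for two weeks", model_manager)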

def process_pipeline(audio, asr_model_choice, model_manager):
    # First step: ASR
    transcription = process_audio(audio, asr_model_choice, model_manager)
    # Second step: LLM
    final_response = generate_llm_response(transcription, model_manager)
    return transcription, final_response

# Initialize the model manager
model_manager = ModelManager()

# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Medical Audio Consultation System")
    with gr.Row():
        audio_input = gr.Audio(source="microphone", type="filepath")
        asr_model_choice = gr.Dropdown(
            choices=["wav2vec2-base", "wav2vec2-BERT", "whisper-small"],
            label="Select ASR Model",
        )
    with gr.Row():
        transcription_output = gr.Textbox(label="Transcribed Text")
        final_output = gr.Textbox(label="Medical Advice")
    submit_btn = gr.Button("Process")
    submit_btn.click(
        fn=lambda audio, asr_choice: process_pipeline(audio, asr_choice, model_manager),
        inputs=[audio_input, asr_model_choice],
        outputs=[transcription_output, final_output],
    )

if __name__ == "__main__":
    demo.launch(share=True)