# medspeechrec / app.py
import gradio as gr
import torch
from transformers import (
AutoModelForCTC,
Wav2Vec2Processor,
AutoProcessor,
WhisperProcessor,
WhisperForConditionalGeneration,
TextStreamer
)
from unsloth import FastLanguageModel
import numpy as np
import librosa
from scipy.signal import butter, sosfilt
# Initialize device
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
def apply_filter(audio_signal, sr, filter_type, cutoff_freq, slope=2, gain=0):
"""
Apply low-pass, high-pass, notch, or high-shelf filter to an audio signal.
"""
nyquist = sr / 2.0
if filter_type == "lowpass":
sos = butter(slope, cutoff_freq / nyquist, btype="low", output="sos")
elif filter_type == "highpass":
sos = butter(slope, cutoff_freq / nyquist, btype="high", output="sos")
elif filter_type == "notch":
sos = butter(slope, [cutoff_freq[0] / nyquist, cutoff_freq[1] / nyquist], btype="bandstop", output="sos")
elif filter_type == "highshelf":
gain_linear = 10 ** (gain / 20.0)
omega = 2 * np.pi * cutoff_freq / sr
alpha = np.sin(omega) / (2 * slope)
A = gain_linear
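        # High-shelf biquad coefficients in the standard RBJ Audio EQ Cookbook form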
b0 = A * ((A + 1) + (A - 1) * np.cos(omega) + 2 * np.sqrt(A) * alpha)
b1 = -2 * A * ((A - 1) + (A + 1) * np.cos(omega))
b2 = A * ((A + 1) + (A - 1) * np.cos(omega) - 2 * np.sqrt(A) * alpha)
a0 = (A + 1) - (A - 1) * np.cos(omega) + 2 * np.sqrt(A) * alpha
a1 = 2 * ((A - 1) - (A + 1) * np.cos(omega))
a2 = (A + 1) - (A - 1) * np.cos(omega) - 2 * np.sqrt(A) * alpha
b = np.array([b0, b1, b2]) / a0
a = np.array([a0, a1, a2]) / a0
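        # Pack into a single second-order section: [b0, b1, b2, a0 (=1 after normalization), a1, a2]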
sos = np.array([[b[0], b[1], b[2], 1, a[1], a[2]]])
else:
raise ValueError("Invalid filter type.")
return sosfilt(sos, audio_signal)
def process_audio_filters(audio_signal, sr):
"""
Apply a series of filters to clean up the audio
"""
# Apply high-pass filter to remove low frequency noise
audio_signal = apply_filter(audio_signal, sr, "highpass", 80)
    # Apply low-pass filter to remove high-frequency noise
    # (cutoff must stay strictly below the Nyquist frequency, which is 8 kHz at the 16 kHz load rate)
    audio_signal = apply_filter(audio_signal, sr, "lowpass", 7500)
# Apply notch filter to remove power line interference (50/60 Hz)
audio_signal = apply_filter(audio_signal, sr, "notch", [45, 65])
# Apply high-shelf filter to boost high frequencies for clarity
audio_signal = apply_filter(audio_signal, sr, "highshelf", 3000, slope=1, gain=3)
return audio_signal
class ModelManager:
def __init__(self):
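        # Cache loaded models so each checkpoint is downloaded and moved to the device only once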
self.asr_models = {}
self.llm_model = None
self.llm_tokenizer = None
def load_wav2vec2_base(self):
model = AutoModelForCTC.from_pretrained("kabir259/w2v2-base_kabir").to(DEVICE)
processor = Wav2Vec2Processor.from_pretrained("kabir259/w2v2-base_kabir")
return model, processor
def load_wav2vec2_bert(self):
model = AutoModelForCTC.from_pretrained("Kabir259/w2v2-BERT_kabir").to(DEVICE)
processor = AutoProcessor.from_pretrained("Kabir259/w2v2-BERT_kabir")
return model, processor
def load_whisper_small(self):
model = WhisperForConditionalGeneration.from_pretrained("Kabir259/whisper-small_kabir").to(DEVICE)
processor = WhisperProcessor.from_pretrained("Kabir259/whisper-small_kabir")
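        # Force transcription (rather than translation) when generating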
model.generation_config.task = "transcribe"
return model, processor
def load_qwen2(self):
model, tokenizer = FastLanguageModel.from_pretrained(
model_name="Kabir259/QWEN2-Medical",
max_seq_length=512,
dtype=torch.float16,
load_in_4bit=True,
)
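        # Enable Unsloth's optimized inference mode for generation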
FastLanguageModel.for_inference(model)
return model, tokenizer
def get_asr_model(self, model_name):
if model_name not in self.asr_models:
if model_name == "wav2vec2-base":
self.asr_models[model_name] = self.load_wav2vec2_base()
elif model_name == "wav2vec2-BERT":
self.asr_models[model_name] = self.load_wav2vec2_bert()
elif model_name == "whisper-small":
self.asr_models[model_name] = self.load_whisper_small()
return self.asr_models[model_name]
def get_llm_model(self):
if self.llm_model is None:
self.llm_model, self.llm_tokenizer = self.load_qwen2()
return self.llm_model, self.llm_tokenizer
def process_audio(audio_path, asr_model_name, model_manager):
model, processor = model_manager.get_asr_model(asr_model_name)
# Load and preprocess audio
audio, rate = librosa.load(audio_path, sr=16000)
# Apply audio filtering
filtered_audio = process_audio_filters(audio, rate)
    if asr_model_name in ["wav2vec2-base", "wav2vec2-BERT"]:
        # CTC models: unpack the processor output directly, since wav2vec2-base expects
        # input_values while wav2vec2-BERT expects input_features
        inputs = processor(filtered_audio, sampling_rate=16000, return_tensors="pt").to(DEVICE)
        with torch.no_grad():
            logits = model(**inputs).logits
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = processor.batch_decode(predicted_ids)[0]
else: # whisper model
input_features = processor(filtered_audio, sampling_rate=16000, return_tensors="pt").input_features.to(DEVICE)
with torch.no_grad():
predicted_ids = model.generate(input_features)
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
return transcription
def generate_llm_response(text, model_manager):
model, tokenizer = model_manager.get_llm_model()
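    # Alpaca-style instruction prompt used to query the fine-tuned medical LLM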
alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
### Instruction:
Provide medical advice for the following condition or symptom
### Input:
{0}
### Response:
"""
inputs = tokenizer(
[alpaca_prompt.format(text)],
return_tensors="pt"
).to(DEVICE)
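    # Stream tokens to the server console as they are generated (visible in the Space logs)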
text_streamer = TextStreamer(tokenizer)
with torch.no_grad():
output_ids = model.generate(
**inputs,
streamer=text_streamer,
max_new_tokens=64
)
    # Decode only the newly generated tokens so the Alpaca prompt is not echoed in the output
    prompt_length = inputs["input_ids"].shape[-1]
    response = tokenizer.decode(output_ids[0][prompt_length:], skip_special_tokens=True)
    return response
def process_pipeline(audio, asr_model_choice, model_manager):
# First step: ASR
transcription = process_audio(audio, asr_model_choice, model_manager)
# Second step: LLM
final_response = generate_llm_response(transcription, model_manager)
return transcription, final_response
# Initialize the model manager
model_manager = ModelManager()
# Create Gradio interface
with gr.Blocks() as demo:
gr.Markdown("# Medical Audio Consultation System")
with gr.Row():
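        # Gradio 3.x API; on Gradio 4.x use sources=["microphone"] instead of source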
audio_input = gr.Audio(source="microphone", type="filepath")
asr_model_choice = gr.Dropdown(
choices=["wav2vec2-base", "wav2vec2-BERT", "whisper-small"],
label="Select ASR Model"
)
with gr.Row():
transcription_output = gr.Textbox(label="Transcribed Text")
final_output = gr.Textbox(label="Medical Advice")
submit_btn = gr.Button("Process")
submit_btn.click(
fn=lambda audio, asr_choice: process_pipeline(audio, asr_choice, model_manager),
inputs=[audio_input, asr_model_choice],
outputs=[transcription_output, final_output]
)
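# share=True creates a temporary public link when run locally; it is ignored on Hugging Face Spaces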
if __name__ == "__main__":
demo.launch(share=True)