File size: 2,321 Bytes
74885f3
 
 
 
306eb78
74885f3
 
 
306eb78
74885f3
306eb78
 
 
 
 
 
 
 
 
 
 
 
74885f3
 
 
 
 
 
 
306eb78
 
74885f3
 
 
306eb78
74885f3
306eb78
 
74885f3
306eb78
 
 
 
 
 
 
 
74885f3
 
 
 
 
 
 
 
 
 
306eb78
74885f3
 
306eb78
74885f3
 
 
 
306eb78
 
74885f3
306eb78
74885f3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import whisper

# Spell-checking tokenizer and seq2seq model, both loaded once at import
# time from the same pretrained checkpoint.
_SPELLCHECKER_CHECKPOINT = "Bhuvana/t5-base-spellchecker"
tokenizer = AutoTokenizer.from_pretrained(_SPELLCHECKER_CHECKPOINT)
model = AutoModelForSeq2SeqLM.from_pretrained(_SPELLCHECKER_CHECKPOINT)

# Function to correct spelling errors in a given input text
def correct(inputs, max_length=50):
    '''Correct spelling errors in a piece of text using the spell checker model.

    Decoding uses nucleus sampling (do_sample=True, top_p=0.99), so repeated
    calls with the same text may return different corrections.

    Args:
        inputs (str): The input text to be spell-checked.
        max_length (int): Upper bound passed to ``model.generate``; a
            correction longer than this many tokens is cut off. Defaults
            to 50.

    Returns:
        str: The corrected version of the input text. Blank input (empty or
        whitespace-only) is returned unchanged without running the model.
    '''
    # Nothing to correct: skip the model rather than sample output for an
    # empty prompt.
    if not inputs.strip():
        return inputs

    # Encode the text and place the ids on the same device as the model
    input_ids = tokenizer.encode(inputs, return_tensors='pt').to(model.device)

    # Generate a single corrected sequence
    sample_output = model.generate(
        input_ids,
        do_sample=True,
        max_length=max_length,
        top_p=0.99,
        num_return_sequences=1,
    )

    # Decode the generated ids back to text, dropping special tokens
    return tokenizer.decode(sample_output[0], skip_special_tokens=True)

# Whisper speech-to-text model, loaded once at import time and reused by
# transcribe().
_WHISPER_CHECKPOINT = "base"
whisper_model = whisper.load_model(_WHISPER_CHECKPOINT)

# Function to transcribe audio file
def transcribe(audio_file):
    '''Transcribe an audio file and spell-check the transcription.

    The audio is padded or trimmed to a single 30-second window before
    decoding, so anything past the first 30 seconds is not transcribed.
    The raw transcription is also printed to stdout.

    Args:
        audio_file (str): The path to the audio file.

    Returns:
        str: The transcribed text from the audio file, with spelling errors
        corrected by ``correct``.
    '''
    # Load audio and pad/trim it to fit 30 seconds
    audio = whisper.pad_or_trim(whisper.load_audio(audio_file))

    # Make the log-Mel spectrogram and move it to the Whisper model's device
    # (not the spell checker's, which may live on a different device).
    mel = whisper.log_mel_spectrogram(audio).to(whisper_model.device)

    # No separate detect_language() call: its result was unused, and decoding
    # with no language set lets Whisper pick the language itself.
    options = whisper.DecodingOptions(fp16=False)
    result = whisper.decode(whisper_model, mel, options)
    result_text = result.text

    # Print the transcribed text
    print('result_text:' + result_text)

    # Correct any spelling errors in the transcribed text
    return correct(result_text)