AlexanderBenady committed on
Commit
072c3d7
1 Parent(s): 34c450a

Create app.py

Files changed (1)
app.py +216 -0
app.py ADDED
import logging
import os
import warnings

import torch
import gradio as gr
from pydub import AudioSegment
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification,
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    AutoTokenizer,
    MarianMTModel,
    MarianTokenizer,
    pipeline,
)

# Suppress specific warnings related to transformers and audio processing.
# (The message strings are kept verbatim, typos included, so the filters
# match the warning text actually emitted.)
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", message="Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.")
warnings.filterwarnings("ignore", message="Due to a bug fix in https://github.com/huggingface/transformers/pull/28687 transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English.This might be a breaking change for your use case. If you want to instead always translate your audio to English, make sure to pass `language='en'.")

# Set up logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

# Set the computation device and data type for the models based on CUDA availability
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Preload the summarization, translation, and classification models and tokenizers
summarizer_tokenizer = AutoTokenizer.from_pretrained("cranonieu2021/pegasus-on-lectures")
summarizer_model = AutoModelForSeq2SeqLM.from_pretrained("cranonieu2021/pegasus-on-lectures", torch_dtype=torch_dtype).to(device)
translator_tokenizer = MarianTokenizer.from_pretrained("sfarjebespalaia/enestranslatorforsummaries")
translator_model = MarianMTModel.from_pretrained("sfarjebespalaia/enestranslatorforsummaries", torch_dtype=torch_dtype).to(device)
classifier_tokenizer = AutoTokenizer.from_pretrained("gserafico/roberta-base-finetuned-classifier-roberta1")
classifier_model = AutoModelForSequenceClassification.from_pretrained("gserafico/roberta-base-finetuned-classifier-roberta1", torch_dtype=torch_dtype).to(device)
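
# A possible refinement (sketch only, not enabled here): the Whisper ASR pipeline
# could also be preloaded once at startup, so transcribe_audio() below would not
# re-initialize it on every call:
#
#     asr_pipe = pipeline(
#         "automatic-speech-recognition",
#         model="openai/whisper-large-v3",
#         torch_dtype=torch_dtype,
#         device=device,
#     )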

def transcribe_audio(audio_file_path):
    """
    Transcribes audio from a file to text using Whisper.

    Parameters:
        audio_file_path (str): Path to the audio file.

    Returns:
        str: Transcribed text.
    """
    try:
        model_id = "openai/whisper-large-v3"
        model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, use_safetensors=True)
        model.to(device)
        processor = AutoProcessor.from_pretrained(model_id)
        pipe = pipeline(
            "automatic-speech-recognition",
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            torch_dtype=torch_dtype,  # keep input features in the same dtype as the half-precision weights on GPU
            device=device,
        )
        result = pipe(audio_file_path)
        logging.info("Audio transcription completed successfully.")
        return result["text"]
    except Exception as e:
        logging.error(f"Error transcribing audio: {e}")
        raise
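
# Standalone usage sketch (hypothetical path, shown for illustration only):
#     print(transcribe_audio("sample_lecture.wav")[:200])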

def load_and_process_input(file_info):
    """
    Loads and processes an input file based on its extension.

    Parameters:
        file_info (str): Path to the file.

    Returns:
        str: Text read from a .txt file, or the transcription of an audio file.
    """
    file_path = file_info  # Gradio supplies the temp-file path as a string
    original_filename = os.path.basename(file_path)
    extension = os.path.splitext(original_filename)[-1].lower()
    try:
        if extension == ".txt":
            with open(file_path, "r", encoding="utf-8") as file:
                text = file.read()
        elif extension in [".mp3", ".wav"]:
            if extension == ".mp3":
                file_path = convert_mp3_to_wav(file_path)
            text = transcribe_audio(file_path)
        else:
            raise ValueError("Unsupported file type provided.")
    except Exception as e:
        logging.error(f"Error processing input file: {e}")
        raise
    return text

def convert_mp3_to_wav(file_path):
    """
    Converts an MP3 audio file to WAV format.

    Parameters:
        file_path (str): Path to the MP3 file.

    Returns:
        str: Path to the WAV file created.
    """
    try:
        # splitext is safer than str.replace, which would also rewrite ".mp3"
        # occurring anywhere else in the path
        wav_file_path = os.path.splitext(file_path)[0] + ".wav"
        audio = AudioSegment.from_file(file_path, format="mp3")
        audio.export(wav_file_path, format="wav")
        logging.info("MP3 file converted to WAV.")
        return wav_file_path
    except Exception as e:
        logging.error(f"Error converting MP3 to WAV: {e}")
        raise

def process_text(text, summarization=False, translation=False, classification=False):
    """
    Processes text for summarization, translation, and classification based on the options selected.

    Parameters:
        text (str): Text to process.
        summarization (bool): Whether to perform summarization.
        translation (bool): Whether to perform translation.
        classification (bool): Whether to perform classification.

    Returns:
        dict: Results of the processing tasks.
    """
    results = {}
    intermediate_text = text  # Start with the original text

    # Summary generation
    if summarization:
        # Tokenizer output defaults to CPU; move it to the model's device.
        inputs = summarizer_tokenizer(intermediate_text, max_length=1024, return_tensors="pt", truncation=True).to(device)
        summary_ids = summarizer_model.generate(inputs.input_ids, max_length=150, min_length=40, length_penalty=2.0, num_beams=4, early_stopping=True)
        summary_text = summarizer_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
        results['summarized_text'] = summary_text
        intermediate_text = summary_text  # Feed the summary into the later stages

    # Text translation
    if translation:
        # prepare_seq2seq_batch is deprecated; calling the tokenizer directly is equivalent.
        tokenized_text = translator_tokenizer([intermediate_text], return_tensors="pt", truncation=True).to(device)
        translated = translator_model.generate(**tokenized_text)
        translated_text = ' '.join(translator_tokenizer.decode(t, skip_special_tokens=True) for t in translated)
        results['translated_text'] = translated_text.strip()

    # Text classification
    if classification:
        inputs = classifier_tokenizer(intermediate_text, return_tensors="pt", truncation=True, padding=True).to(device)
        with torch.no_grad():
            outputs = classifier_model(**inputs)
        predicted_class_idx = torch.argmax(outputs.logits, dim=1).item()
        labels = {
            0: 'Social Sciences',
            1: 'Arts',
            2: 'Natural Sciences',
            3: 'Business and Law',
            4: 'Engineering and Technology'
        }
        results['classification_result'] = labels[predicted_class_idx]

    return results
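
# Shape of the returned dictionary (illustrative values only):
#     process_text(lecture_text, summarization=True, classification=True)
#     -> {'summarized_text': '<40-150 token summary>',
#         'classification_result': 'Natural Sciences'}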

def display_results(results):
    """
    Prints the results of the text processing tasks (a command-line helper;
    the Gradio interface displays results through its own components).

    Parameters:
        results (dict): Dictionary containing the results of text processing.
    """
    if 'summarized_text' in results:
        print("Summarized Text:")
        print(results['summarized_text'])
    if 'translated_text' in results:
        print("Translated Text:")
        print(results['translated_text'])
    if 'classification_result' in results:
        print("Classification Result:")
        print(f"This text is classified under: {results['classification_result']}")

def wrap_process_file(file_obj, tasks):
    """
    Processes the uploaded file and returns results for the selected tasks.

    Parameters:
        file_obj (str): Path of the uploaded temp file, as provided by Gradio.
        tasks (list): List of tasks to be performed on the file.

    Returns:
        tuple: (summary, translation, classification); empty strings for tasks not run.
    """
    if file_obj is None:
        # One value per output component (the original returned four values for three outputs).
        return "Please upload a file to proceed.", "", ""

    text = load_and_process_input(file_obj)
    results = process_text(text, 'Summarization' in tasks, 'Translation' in tasks, 'Classification' in tasks)

    return (results.get('summarized_text', ''),
            results.get('translated_text', ''),
            results.get('classification_result', ''))

def create_gradio_interface():
    """
    Creates a Gradio interface for file processing and result display.

    Returns:
        gr.Blocks: Gradio interface configured for the application.
    """
    with gr.Blocks(theme="huggingface") as demo:
        gr.Markdown("# LectorSync 1.0")
        gr.Markdown("## Upload your file and select the tasks:")
        with gr.Row():
            file_input = gr.File(label="Upload your text, mp3, or wav file")
            task_choice = gr.CheckboxGroup(["Summarization", "Translation", "Classification"], label="Select Tasks")
            submit_button = gr.Button("Process")
        output_summary = gr.Text(label="Summarized Text")
        output_translation = gr.Text(label="Translated Text")
        output_classification = gr.Text(label="Classification Result")

        submit_button.click(
            fn=wrap_process_file,
            inputs=[file_input, task_choice],
            outputs=[output_summary, output_translation, output_classification]
        )

    return demo


if __name__ == "__main__":
    demo = create_gradio_interface()
    demo.launch()
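
Running python app.py launches the Gradio UI. For a quick check of the text pipeline without the UI, a minimal sketch (assuming a local notes.txt, a hypothetical file name used only for illustration; note that importing app also triggers the model preloads):

from app import load_and_process_input, process_text, display_results

text = load_and_process_input("notes.txt")  # hypothetical input file
results = process_text(text, summarization=True, classification=True)
display_results(results)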