# Captured from a Hugging Face Space whose status read "Runtime error".
import base64
import functools
import os
import tempfile
from io import BytesIO

import gradio as gr
import gtts
import requests
from bs4 import BeautifulSoup
from transformers import T5Tokenizer, T5ForConditionalGeneration, BartTokenizer, BartForConditionalGeneration
# Function to fetch text from a URL
def fetch_text_from_url(url, timeout=10):
    """Download *url* and return the text of its ``<p>`` elements.

    Args:
        url: Address of the page to fetch.
        timeout: Seconds to wait on the server (passed to ``requests.get``)
            before giving up. Without a timeout a stalled host blocks forever.

    Returns:
        ``(text, None)`` on success, where *text* is the non-empty paragraph
        texts (each stripped) joined by single spaces; or ``(None, message)``
        when the request fails, the server answers with an HTTP error status,
        or the page contains no paragraph text.
    """
    try:
        response = requests.get(url, timeout=timeout)
        # Otherwise a 404/500 error page would be parsed and summarized as if
        # it were the article.
        response.raise_for_status()
    except requests.RequestException as e:
        return None, f"Error fetching URL: {e}"

    soup = BeautifulSoup(response.content, 'html.parser')
    # get_text().strip() (not get_text(strip=True)) so inline tags inside a
    # paragraph keep their surrounding spaces.
    paragraphs = (para.get_text().strip() for para in soup.find_all('p'))
    text = ' '.join(para for para in paragraphs if para)
    if not text:
        return None, "No article text (<p> paragraphs) found at the URL."
    return text, None
# Function to summarize text using T5
# Default location of the fine-tuned T5 checkpoint. This is an absolute Windows
# path from the original author's machine and will not exist elsewhere (e.g. on
# a hosted Space); set the T5_MODEL_PATH environment variable to any value
# accepted by ``from_pretrained`` to override it.
_T5_DEFAULT_MODEL_PATH = "C:\\Users\\zurin\\Desktop\\text summarization\\fine_tuned_t52"


@functools.lru_cache(maxsize=None)
def _load_t5(model_name):
    """Load the T5 tokenizer and model from *model_name* once and reuse them."""
    tokenizer = T5Tokenizer.from_pretrained(model_name)
    model = T5ForConditionalGeneration.from_pretrained(model_name)
    return tokenizer, model


def summarize_t5(text, size):
    """Summarize *text* with the fine-tuned T5 model.

    Args:
        text: Text to summarize. It is prefixed with ``"summarize: "`` and
            truncated to 512 tokens before generation.
        size: ``"Short"`` (30-50 tokens) or ``"Medium"`` (50-100 tokens); any
            other value selects the ``"Long"`` tier (100-200 tokens).

    Returns:
        The decoded summary string.
    """
    model_name = os.environ.get("T5_MODEL_PATH", _T5_DEFAULT_MODEL_PATH)
    # Loading from disk is slow; the cached helper does it once per path.
    tokenizer, model = _load_t5(model_name)

    input_text = f"summarize: {text}"
    inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)

    # (min_length, max_length) of the generated summary, in tokens.
    length_tiers = {"Short": (30, 50), "Medium": (50, 100)}
    min_len, max_len = length_tiers.get(size, (100, 200))  # anything else: Long

    summary_ids = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=max_len,
        min_length=min_len,
        # Beam scores are divided by length ** length_penalty; 1.0 ranks beams
        # by their average per-token log-likelihood.
        length_penalty=1.0,
        num_beams=4,
        early_stopping=True,
    )
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
# Function to summarize text using BART
# Default location of the fine-tuned BART checkpoint. This is an absolute
# Windows path from the original author's machine and will not exist elsewhere
# (e.g. on a hosted Space); set the BART_MODEL_PATH environment variable to any
# value accepted by ``from_pretrained`` to override it.
_BART_DEFAULT_MODEL_PATH = "C:\\Users\\zurin\\Desktop\\text summarization\\fine_tuned_bart"


@functools.lru_cache(maxsize=None)
def _load_bart(model_name):
    """Load the BART tokenizer and model from *model_name* once and reuse them."""
    tokenizer = BartTokenizer.from_pretrained(model_name)
    model = BartForConditionalGeneration.from_pretrained(model_name)
    return tokenizer, model


def summarize_bart(text, size):
    """Summarize *text* with the fine-tuned BART model.

    Args:
        text: Text to summarize; truncated to 1024 tokens before generation.
        size: ``"Short"`` (30-50 tokens) or ``"Medium"`` (50-100 tokens); any
            other value selects the ``"Long"`` tier (100-200 tokens).

    Returns:
        The decoded summary string.
    """
    model_name = os.environ.get("BART_MODEL_PATH", _BART_DEFAULT_MODEL_PATH)
    # Loading from disk is slow; the cached helper does it once per path.
    tokenizer, model = _load_bart(model_name)

    inputs = tokenizer(text, return_tensors="pt", max_length=1024, truncation=True)

    # (min_length, max_length) of the generated summary, in tokens.
    length_tiers = {"Short": (30, 50), "Medium": (50, 100)}
    min_len, max_len = length_tiers.get(size, (100, 200))  # anything else: Long

    summary_ids = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=max_len,
        min_length=min_len,
        # Beam scores (negative log-likelihoods) are divided by
        # length ** length_penalty, so values > 0 favour longer beams; 0.8
        # favours them slightly less than 1.0 would.
        length_penalty=0.8,
        num_beams=6,
        no_repeat_ngram_size=2,  # forbid repeating any bigram, to curb repetition
        early_stopping=True,
    )
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
# Function to convert text to speech and save as a file
def text_to_speech(text):
    """Synthesize *text* with gTTS and save it to a new temporary MP3 file.

    Each call gets its own file. A single fixed file name would let concurrent
    requests overwrite one another's audio.

    Args:
        text: The text to speak.

    Returns:
        Path of the written MP3 file, or ``None`` when *text* is empty or only
        whitespace (there is nothing to speak). The caller owns the file and is
        responsible for any cleanup.

    Raises:
        Whatever gTTS raises if synthesis fails; the temp file is removed first.
    """
    if not text or not text.strip():
        return None
    tts = gtts.gTTS(text)
    fd, audio_file_path = tempfile.mkstemp(prefix="summary_audio_", suffix=".mp3")
    os.close(fd)  # release our handle; tts.save() opens the path itself
    try:
        tts.save(audio_file_path)
    except Exception:
        os.remove(audio_file_path)  # don't leave an empty temp file behind
        raise
    return audio_file_path
# Main function to handle summarization
def summarize_news(input_type, text_input, url_input, model_choice, size_choice):
    """Summarize pasted text or a fetched article, and voice the summary.

    Args:
        input_type: ``"Text"`` to summarize *text_input*; any other value
            summarizes the article fetched from *url_input*.
        text_input: Article text pasted by the user.
        url_input: URL of an article to download.
        model_choice: ``"T5"`` uses ``summarize_t5``; any other value uses
            ``summarize_bart``.
        size_choice: Summary size tier, passed through to the summarizer.

    Returns:
        A 3-tuple ``(message, audio, error)``. On success *message* is the
        summary and *audio* is whatever ``text_to_speech`` returned for it.
        On a validation or fetch problem *message* explains the problem and
        *audio* is ``None``. *error* is always ``None``.
    """
    if input_type == "Text":
        # Whitespace-only input would otherwise be fed to the model.
        if not text_input or not text_input.strip():
            return "Please provide text to summarize.", None, None
        input_text = text_input
    else:  # URL
        if not url_input or not url_input.strip():
            return "Please provide a URL to summarize.", None, None
        # Pasted URLs often carry stray spaces or newlines.
        input_text, error = fetch_text_from_url(url_input.strip())
        if error:
            return error, None, None
        # A successful fetch can still yield no usable article text.
        if not input_text or not input_text.strip():
            return "No article text could be extracted from the URL.", None, None

    if model_choice == "T5":
        summary = summarize_t5(input_text, size_choice)
    else:  # BART
        summary = summarize_bart(input_text, size_choice)

    audio_file = text_to_speech(summary)
    return summary, audio_file, None
# Custom CSS for the design.
# It is injected by rendering this <style> element through gr.HTML below, so
# the string keeps its surrounding <style>...</style> tags.
custom_css = """
<style>
/* Background for the entire app */
body {
    background: linear-gradient(135deg, #E6E6FA 0%, #D8BFD8 100%) !important;
    font-family: 'Arial', sans-serif;
    min-height: 100vh;
    margin: 0;
    display: flex;
    justify-content: center;
    align-items: center;
}
/* White container for all elements */
.container {
    background-color: #FFFFFF !important;
    border-radius: 15px !important;
    padding: 30px !important;
    margin: 20px auto !important;
    max-width: 800px !important;
    box-shadow: 0 4px 10px rgba(0, 0, 0, 0.1) !important;
    width: 100%;
}
/* Title styling */
.title {
    font-size: 36px;
    color: #000000 !important;
    text-align: center;
    font-weight: bold;
    margin-bottom: 10px;
}
/* Subtitle styling */
.subtitle {
    font-size: 18px;
    color: #000000 !important;
    text-align: center;
    margin-bottom: 20px;
}
/* Labels for inputs */
label {
    color: #000000 !important;
}
/* Text inputs and textareas. Radio buttons, checkboxes and range sliders are
   excluded so these !important overrides do not repaint Gradio's own controls
   (e.g. the checked state of the input-type radio). */
input:not([type="radio"]):not([type="checkbox"]):not([type="range"]),
textarea {
    background-color: #F5F5F5 !important;
    color: #000000 !important;
    border: 1px solid #D3D3D3 !important;
    border-radius: 5px !important;
}
/* Dropdown styling */
select {
    background-color: #F5F5F5 !important;
    color: #000000 !important;
    border: 1px solid #D3D3D3 !important;
    border-radius: 5px !important;
    padding: 5px !important;
}
/* "Get Summary" button. Scoped to its own class rather than every <button>,
   so Gradio's built-in buttons (e.g. the audio player controls) keep their
   default look and layout. */
.summarize-button {
    background-color: #9370DB !important;
    color: white !important;
    border-radius: 10px !important;
    padding: 8px 20px !important;
    border: none !important;
    display: block !important;
    margin: 20px auto !important;
    cursor: pointer !important;
}
.summarize-button:hover {
    background-color: #4B0082 !important;
}
/* Footer styling */
.footer {
    text-align: center;
    color: #000000 !important;
    font-size: 14px;
    margin-top: 30px;
}
.footer-heart {
    color: #FF0000 !important;
}
/* Output text and error messages */
.output-text, .error-text {
    color: #000000 !important;
}
</style>
"""

# Gradio app
with gr.Blocks() as app:
    # Inject custom CSS by rendering its <style> element into the page.
    gr.HTML(custom_css)

    # Main container
    with gr.Column(elem_classes=["container"]):
        # Title and subtitle
        gr.HTML('<p class="title">BBC News Summarizer</p>')
        gr.HTML('<p class="subtitle">Summarize news articles with T5 or BART in your preferred length!</p>')

        # Input section
        input_type = gr.Radio(choices=["Text", "URL"], label="Choose input type:", value="Text")
        with gr.Row():
            text_input = gr.Textbox(label="Enter news text here:", lines=5, visible=True, placeholder="Paste your news text here...")
            url_input = gr.Textbox(label="Enter news URL here:", visible=False, placeholder="Enter a news article URL...")

        def update_input_visibility(input_type):
            """Show the text box for "Text" input and the URL box for "URL" input."""
            return (
                gr.update(visible=(input_type == "Text")),
                gr.update(visible=(input_type == "URL")),
            )

        input_type.change(
            fn=update_input_visibility,
            inputs=input_type,
            outputs=[text_input, url_input],
        )

        # Model selection
        model_choice = gr.Dropdown(choices=["T5", "BART"], label="Choose summarization model:", value="T5")
        # Summary size selection
        size_choice = gr.Dropdown(choices=["Short", "Medium", "Long"], label="Choose summary size:", value="Short")

        # Summarize button; the class ties it to the scoped button CSS above.
        summarize_button = gr.Button("Get Summary", elem_classes=["summarize-button"])

        # Outputs. The error box stays hidden: summarize_news reports problems
        # in the summary box and always sends None here.
        summary_output = gr.Textbox(label="Summary:", elem_classes=["output-text"])
        audio_output = gr.Audio(label="Listen to the Summary:")
        error_output = gr.Textbox(label="Error:", elem_classes=["error-text"], visible=False)

        # Footer
        gr.HTML('<p class="footer">Powered by xAI\'s Grok | Made with <span class="footer-heart">❤️</span> for news enthusiasts</p>')

    # Button click event
    summarize_button.click(
        fn=summarize_news,
        inputs=[input_type, text_input, url_input, model_choice, size_choice],
        outputs=[summary_output, audio_output, error_output],
    )

# Launch the app only when run as a script (e.g. `python app.py`), not on import.
if __name__ == "__main__":
    app.launch()