Dhanush4149 commited on
Commit
0330532
·
verified ·
1 Parent(s): d769310

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +98 -64
app.py CHANGED
@@ -1,70 +1,104 @@
1
  import streamlit as st
2
  from transformers import pipeline
3
- import os
4
- # Set Hugging Face cache directory
5
- os.environ['TRANSFORMERS_CACHE'] = os.getenv('HF_HOME', os.path.expanduser('~/.cache/huggingface/hub'))
6
 
7
- # Function to load all three models
8
- @st.cache_resource
9
- def load_models():
10
- bart_summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
11
- t5_summarizer = pipeline("summarization", model="t5-large")
12
- pegasus_summarizer = pipeline("summarization", model="google/pegasus-cnn_dailymail")
13
- return bart_summarizer, t5_summarizer, pegasus_summarizer
14
-
15
- # Streamlit app layout
16
- st.title("Text Summarization with Pre-trained Models: BART, T5, Pegasus")
17
-
18
- # Load models
19
- with st.spinner("Loading models..."):
20
- bart_model, t5_model, pegasus_model = load_models()
21
-
22
- # Input text
23
- text_input = st.text_area("Enter text to summarize:")
 
 
 
24
 
25
- # Compression rate slider
26
- compression_rate = st.slider(
27
- "Summary Compression Rate",
28
- min_value=0.1,
29
- max_value=0.5,
30
- value=0.3,
31
- step=0.05,
32
- help="Adjust how much of the original text to keep in the summary"
33
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
- if text_input:
36
- word_count = len(text_input.split())
37
- st.write(f"**Input Word Count:** {word_count}")
38
 
39
- if st.button("Generate Summaries"):
40
- with st.spinner("Generating summaries..."):
41
- # Calculate dynamic max length based on compression rate
42
- max_length = max(50, int(word_count * compression_rate))
43
-
44
- # Generate summaries
45
- bart_summary = bart_model(
46
- text_input, max_length=max_length, min_length=30, num_beams=4, early_stopping=True
47
- )[0]['summary_text']
48
-
49
- t5_summary = t5_model(
50
- text_input, max_length=max_length, min_length=30, num_beams=4, early_stopping=True
51
- )[0]['summary_text']
52
-
53
- pegasus_summary = pegasus_model(
54
- text_input, max_length=max_length, min_length=30, num_beams=4, early_stopping=True
55
- )[0]['summary_text']
56
-
57
- # Display summaries
58
- st.subheader("BART Summary")
59
- st.write(bart_summary)
60
- st.write(f"**Word Count:** {len(bart_summary.split())}")
61
-
62
- st.subheader("T5 Summary")
63
- st.write(t5_summary)
64
- st.write(f"**Word Count:** {len(t5_summary.split())}")
65
-
66
- st.subheader("Pegasus Summary")
67
- st.write(pegasus_summary)
68
- st.write(f"**Word Count:** {len(pegasus_summary.split())}")
69
- else:
70
- st.warning("Please enter text to summarize.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
  from transformers import pipeline
3
+ import traceback
 
 
4
 
5
def load_pipelines():
    """
    Load the three summarization pipelines with error handling.

    Returns:
        dict: Mapping of model display name ('BART', 'T5', 'Pegasus') to its
            Hugging Face summarization pipeline, or an empty dict when any
            model fails to load (the error is surfaced via st.error).
    """
    # NOTE(review): this is called on every button press; consider decorating
    # with @st.cache_resource (as the previous revision did) so the large
    # models are downloaded/instantiated only once per server process.
    try:
        bart_pipeline = pipeline("summarization", model="facebook/bart-large-cnn")
        t5_pipeline = pipeline("summarization", model="t5-large")
        # BUGFIX: the model id was previously passed as the *task* argument
        # (pipeline("pegasus-cnn_dailymail")), which raises "unknown task".
        # The task must be "summarization" and the model given via model=.
        pegasus_pipeline = pipeline("summarization", model="google/pegasus-cnn_dailymail")
        return {
            'BART': bart_pipeline,
            'T5': t5_pipeline,
            'Pegasus': pegasus_pipeline
        }
    except Exception as e:
        # Surface both the short message and the full traceback in the UI so
        # deployment failures (e.g. missing weights, no network) are debuggable.
        st.error(f"Error loading models: {str(e)}")
        st.error(traceback.format_exc())
        return {}
25
 
26
def generate_summary(pipeline, text, model_name):
    """
    Run one summarization pipeline over *text*, reporting failures in the UI.

    Args:
        pipeline: A callable Hugging Face summarization pipeline.
            (Name shadows the imported ``transformers.pipeline``; kept for
            backward compatibility with keyword callers.)
        text (str): The text to summarize.
        model_name (str): Display name used in error reporting.

    Returns:
        str: The generated summary, or the error message when the call fails.
    """
    prompt = "Summarize the below paragraph"
    generation_kwargs = {
        "max_length": 150,
        "min_length": 50,
        "length_penalty": 2.0,
        "num_beams": 4,
        "early_stopping": True,
    }
    try:
        outputs = pipeline(f"{prompt}\n{text}", **generation_kwargs)
        return outputs[0]['summary_text']
    except Exception as exc:
        # Show the failure in the app and also return it so the caller
        # renders something in place of the summary.
        error_msg = f"Error in {model_name} summarization: {str(exc)}"
        st.error(error_msg)
        return error_msg
51
 
52
def main():
    """
    Streamlit entry point: collect text, then show BART / T5 / Pegasus
    summaries side by side in three columns with per-model progress bars.
    """
    st.title("Text Summarization with Pre-trained Models")

    # User input.
    text_input = st.text_area("Enter text to summarize:")

    if st.button("Generate Summary"):
        # Guard clauses: require text and successfully loaded models.
        if not text_input:
            st.error("Please enter text to summarize.")
            return

        pipelines = load_pipelines()
        if not pipelines:
            st.error("Failed to load models. Please check your internet connection or try again later.")
            return

        # One column per model; iterate instead of repeating the section
        # three times. Order matches the column order: BART, T5, Pegasus.
        for column, model_name in zip(st.columns(3), ("BART", "T5", "Pegasus")):
            with column:
                with st.spinner(f'Generating {model_name} Summary...'):
                    progress_bar = st.progress(0)
                    progress_bar.progress(50)
                    summary = generate_summary(pipelines[model_name], text_input, model_name)
                    progress_bar.progress(100)
                    st.subheader(f"{model_name} Summary")
                    st.write(summary)

if __name__ == "__main__":
    main()