zurin14 committed · verified
Commit ab96885 · Parent(s): ceda26e

Update app.py

Files changed (1): app.py (+28, -11)
app.py CHANGED
@@ -18,45 +18,62 @@ def fetch_text_from_url(url):
     except Exception as e:
         return None, f"Error fetching URL: {e}"
 
+# Function to summarize text using T5
 # Function to summarize text using T5
 def summarize_t5(text, size):
-    model_name = "t5-small"
+    model_name = "C:\\Users\\zurin\\Desktop\\text summarization\\fine_tuned_t52"
     tokenizer = T5Tokenizer.from_pretrained(model_name)
     model = T5ForConditionalGeneration.from_pretrained(model_name)
 
     input_text = f"summarize: {text}"
     inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
 
+    # Define length parameters
     if size == "Short":
-        max_len = 50
+        min_len, max_len = 30, 50
     elif size == "Medium":
-        max_len = 100
+        min_len, max_len = 50, 100
     else: # Long
-        max_len = 200
+        min_len, max_len = 100, 200
 
-    summary_ids = model.generate(inputs["input_ids"], max_length=max_len, min_length=10, length_penalty=2.0, num_beams=4)
+    summary_ids = model.generate(
+        inputs["input_ids"],
+        max_length=max_len,
+        min_length=min_len,  # Use the specified min_length instead of fixed 10
+        length_penalty=1.0,  # Reduced from 2.0 to allow more length variation
+        num_beams=4,
+        early_stopping=True
+    )
     summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
     return summary
 
 # Function to summarize text using BART
 def summarize_bart(text, size):
-    model_name = "facebook/bart-large-cnn"
+    model_name = "C:\\Users\\zurin\\Desktop\\text summarization\\fine_tuned_bart"
     tokenizer = BartTokenizer.from_pretrained(model_name)
     model = BartForConditionalGeneration.from_pretrained(model_name)
 
     inputs = tokenizer(text, return_tensors="pt", max_length=1024, truncation=True)
 
+    # Define length parameters
     if size == "Short":
-        max_len = 50
+        min_len, max_len = 30, 50
     elif size == "Medium":
-        max_len = 100
+        min_len, max_len = 50, 100
     else: # Long
-        max_len = 200
+        min_len, max_len = 100, 200
 
-    summary_ids = model.generate(inputs["input_ids"], max_length=max_len, min_length=10, length_penalty=2.0, num_beams=4)
+    summary_ids = model.generate(
+        inputs["input_ids"],
+        max_length=max_len,
+        min_length=min_len,
+        length_penalty=0.8,  # Reduced from 1.0 to encourage length variation
+        num_beams=6,
+        no_repeat_ngram_size=2,  # Added to prevent repetition
+        early_stopping=True
+    )
     summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
     return summary
-
 # Function to convert text to speech and save as a file
 def text_to_speech(text):
     tts = gtts.gTTS(text)
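
For reference, both summarizers take the raw text plus a size label ("Short", "Medium", or "Long") and return a plain string. Below is a minimal usage sketch, not part of the commit: it assumes app.py is importable from the current directory and that the hard-coded model directories in summarize_t5 / summarize_bart (or the original "t5-small" / "facebook/bart-large-cnn" checkpoints substituted in their place) are available locally. The sample article text is purely illustrative.

# Minimal usage sketch (assumption: app.py is importable and the model paths
# used by summarize_t5 / summarize_bart point at checkpoints that exist locally).
from app import summarize_t5, summarize_bart

sample = (
    "Solar panels convert sunlight into electricity using photovoltaic cells. "
    "Falling hardware costs have made rooftop installations increasingly common, "
    "though output still depends heavily on local weather and panel orientation."
)

print("T5 (Short):", summarize_t5(sample, "Short"))        # capped at ~50 tokens
print("BART (Medium):", summarize_bart(sample, "Medium"))  # 50-100 tokens

Note that the generate settings now differ between the two paths: BART uses six beams with no_repeat_ngram_size=2, while T5 keeps four beams, so the same size label can still produce summaries of noticeably different length and style.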