Maram-almasary committed on
Commit
3f0629a
·
verified ·
1 Parent(s): 3fb74d1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +181 -24
app.py CHANGED
@@ -1,57 +1,214 @@
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
3
- import graphviz
 
 
4
  from PIL import Image
5
- import sentencepiece
 
6
 
7
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
8
 
 
9
  model_name = "csebuetnlp/mT5_multilingual_XLSum"
10
- tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) # المحول البطيء
11
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
12
 
 
 
 
 
13
 
14
- question_generator = pipeline("text2text-generation", model="valhalla/t5-small-e2e-qg")
 
 
 
 
 
15
 
16
  def summarize_text(text, src_lang):
17
- inputs = tokenizer(text, return_tensors="pt", max_length=512, truncation=True)
18
- summary_ids = model.generate(inputs["input_ids"], max_length=150, min_length=30, length_penalty=2.0, num_beams=4, early_stopping=True)
 
 
 
 
 
 
 
 
 
 
19
  summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
20
  return summary
21
 
22
  def generate_questions(summary):
23
- questions = question_generator(summary, max_length=64, num_return_sequences=5)
24
- return [q['generated_text'] for q in questions]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  def generate_concept_map(summary, questions):
27
- dot = graphviz.Digraph(comment='Concept Map')
28
- dot.node('A', summary)
 
 
 
 
 
 
29
  for i, question in enumerate(questions):
30
- dot.node(f'Q{i}', question)
31
- dot.edge('A', f'Q{i}')
32
- dot.render('concept_map', format='png')
33
- return Image.open('concept_map.png')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  def analyze_text(text, lang):
36
- summary = summarize_text(text, lang)
37
- questions = generate_questions(summary)
38
- concept_map_image = generate_concept_map(summary, questions)
39
- return summary, questions, concept_map_image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
  examples = [
42
  ["الذكاء الاصطناعي هو فرع من علوم الكمبيوتر يهدف إلى إنشاء آلات ذكية تعمل وتتفاعل مثل البشر. بعض الأنشطة التي صممت أجهزة الكمبيوتر الذكية للقيام بها تشمل: التعرف على الصوت، التعلم، التخطيط، وحل المشاكل.", "ar"],
43
  ["Artificial intelligence is a branch of computer science that aims to create intelligent machines that work and react like humans. Some of the activities computers with artificial intelligence are designed for include: Speech recognition, learning, planning, and problem-solving.", "en"]
44
  ]
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
  iface = gr.Interface(
48
- fn=analyze_text,
49
- inputs=[gr.Textbox(lines=10, placeholder="Enter text here........"), gr.Dropdown(["ar", "en"], label="Language")],
50
  outputs=[gr.Textbox(label="Summary"), gr.Textbox(label="Questions"), gr.Image(label="Concept Map")],
51
  examples=examples,
52
  title="AI Study Assistant",
53
- description="Enter a text in Arabic or English and the model will summarize it and generate various questions about it in addition to generating a concept map, or you can choose one of the examples."
54
  )
55
 
56
- if __name__ == "__main__":
57
- iface.launch()
 
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
3
+ import matplotlib.pyplot as plt
4
+ import networkx as nx
5
+ import io
6
  from PIL import Image
7
+ import torch
8
+ import os
9
 
10
print("Installation complete. Loading models...")

# Load the summarization model once at startup so requests don't re-download it.
model_name = "csebuetnlp/mT5_multilingual_XLSum"
# use_fast=False: the slow (sentencepiece-based) tokenizer is used for mT5.
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Run on GPU when available.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
model = model.to(device)

# Load the question generator once.
# FIX: transformers' pipeline() classically expects an int device index
# (0 = first GPU, -1 = CPU); passing the string "cuda" only works on newer
# versions, so use the index form for broad compatibility.
question_generator = pipeline(
    "text2text-generation",
    model="valhalla/t5-small-e2e-qg",
    device=0 if device == "cuda" else -1,
)
28
 
29
def summarize_text(text, src_lang):
    """Summarize *text* with the mT5 XLSum model.

    Args:
        text: The input document to summarize.
        src_lang: Language code ("ar"/"en"); not used by the multilingual
            model but kept for interface compatibility.

    Returns:
        The decoded summary string.
    """
    encoded = tokenizer(
        text, return_tensors="pt", max_length=512, truncation=True
    ).to(device)

    # Beam search with a length penalty yields a fluent, length-bounded summary.
    generated = model.generate(
        encoded["input_ids"],
        max_length=150,
        min_length=30,
        length_penalty=2.0,
        num_beams=4,
        early_stopping=True,
    )

    return tokenizer.decode(generated[0], skip_special_tokens=True)
44
 
45
def generate_questions(summary):
    """Generate up to three distinct questions about *summary*.

    Samples the question-generation pipeline three times and removes
    duplicate outputs.

    Returns:
        list[str]: unique generated questions, in generation order.
    """
    questions = []
    for _ in range(3):  # Generate 3 questions
        result = question_generator(
            summary,
            max_length=64,
            num_beams=4,
            do_sample=True,
            top_k=30,
            top_p=0.95,
            temperature=0.7
        )
        questions.append(result[0]['generated_text'])

    # FIX: list(set(...)) returned questions in arbitrary hash order between
    # runs; dict.fromkeys deduplicates while preserving generation order.
    questions = list(dict.fromkeys(questions))
    return questions
63
 
64
def generate_concept_map(summary, questions):
    """Render a concept map (summary -> questions) as a PIL image.

    The summary is the central node; each question is attached to it by a
    directed edge. Rendering uses NetworkX + matplotlib and the PNG is
    kept entirely in memory.
    """
    def shorten(s, limit):
        # Truncate long node labels so the plot stays readable.
        return s if len(s) <= limit else s[:limit] + "..."

    graph = nx.DiGraph()
    graph.add_node("summary", label=shorten(summary, 50))

    for idx, q in enumerate(questions):
        node_id = f"Q{idx}"
        graph.add_node(node_id, label=shorten(q, 30))
        graph.add_edge("summary", node_id)

    plt.figure(figsize=(10, 8))
    layout = nx.spring_layout(graph, seed=42)  # fixed seed => consistent layout
    nx.draw(
        graph, layout, with_labels=False, node_color='skyblue',
        node_size=1500, arrows=True, connectionstyle='arc3,rad=0.1',
        edgecolors='black', linewidths=1,
    )

    # Labels drawn separately for font control ('wrap' is unsupported in
    # this NetworkX version, so it is deliberately not passed).
    nx.draw_networkx_labels(
        graph, layout,
        labels=nx.get_node_attributes(graph, 'label'),
        font_size=9, font_family='sans-serif',
    )

    # Serialize the figure to an in-memory PNG buffer.
    buf = io.BytesIO()
    plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
    buf.seek(0)
    plt.close()

    return Image.open(buf)
99
 
100
def analyze_text(text, lang):
    """Run the full pipeline: summarize, generate questions, draw the map.

    Returns:
        (summary, questions_text, image) suitable for the Gradio outputs;
        on failure the error message replaces the summary.
    """
    if not text.strip():
        return "Please enter some text.", "No questions generated.", None

    try:
        print("Generating summary...")
        summary = summarize_text(text, lang)

        print("Generating questions...")
        questions = generate_questions(summary)

        print("Creating concept map...")
        concept_map_image = generate_concept_map(summary, questions)

        # Present the questions as a bulleted list.
        bullet_lines = [f"- {q}" for q in questions]
        questions_text = "\n".join(bullet_lines)

        return summary, questions_text, concept_map_image
    except Exception as e:
        import traceback
        print(f"Error processing text: {str(e)}")
        print(traceback.format_exc())
        return f"Error processing text: {str(e)}", "", None
124
+
125
+ # Alternative simpler concept map function in case the above still has issues
126
+ def generate_simple_concept_map(summary, questions):
127
+ """Fallback concept map generator with minimal dependencies"""
128
+ plt.figure(figsize=(10, 8))
129
+
130
+ # Create a simple radial layout
131
+ n_questions = len(questions)
132
+
133
+ # Draw the central node (summary)
134
+ plt.scatter([0], [0], s=1000, color='skyblue', edgecolors='black')
135
+ plt.text(0, 0, summary[:50] + "..." if len(summary) > 50 else summary,
136
+ ha='center', va='center', fontsize=9)
137
+
138
+ # Draw the question nodes in a circle around the summary
139
+ radius = 5
140
+ for i, question in enumerate(questions):
141
+ angle = 2 * 3.14159 * i / max(n_questions, 1)
142
+ x = radius * 0.8 * -1 * (max(n_questions, 1) - 1) * ((i / max(n_questions - 1, 1)) - 0.5)
143
+ y = radius * 0.6 * (i % 2 * 2 - 1)
144
+
145
+ # Draw node
146
+ plt.scatter([x], [y], s=800, color='lightgreen', edgecolors='black')
147
+
148
+ # Draw edge from summary to question
149
+ plt.plot([0, x], [0, y], 'k-', alpha=0.6)
150
+
151
+ # Add question text
152
+ plt.text(x, y, question[:30] + "..." if len(question) > 30 else question,
153
+ ha='center', va='center', fontsize=8)
154
+
155
+ plt.axis('equal')
156
+ plt.axis('off')
157
+
158
+ # Save to memory buffer
159
+ buf = io.BytesIO()
160
+ plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
161
+ buf.seek(0)
162
+ plt.close()
163
+
164
+ return Image.open(buf)
165
 
166
# Example inputs shown in the Gradio UI: one Arabic and one English text,
# each paired with its language code.
examples = [
    ["الذكاء الاصطناعي هو فرع من علوم الكمبيوتر يهدف إلى إنشاء آلات ذكية تعمل وتتفاعل مثل البشر. بعض الأنشطة التي صممت أجهزة الكمبيوتر الذكية للقيام بها تشمل: التعرف على الصوت، التعلم، التخطيط، وحل المشاكل.", "ar"],
    ["Artificial intelligence is a branch of computer science that aims to create intelligent machines that work and react like humans. Some of the activities computers with artificial intelligence are designed for include: Speech recognition, learning, planning, and problem-solving.", "en"]
]
170
 
171
print("Creating Gradio interface...")

# Same pipeline as analyze_text, but with a fallback concept-map renderer.
def analyze_text_with_fallback(text, lang):
    """Summarize, generate questions, and draw a concept map, falling back
    to the simple matplotlib renderer if the NetworkX one raises."""
    if not text.strip():
        return "Please enter some text.", "No questions generated.", None

    try:
        print("Generating summary...")
        summary = summarize_text(text, lang)

        print("Generating questions...")
        questions = generate_questions(summary)

        print("Creating concept map...")
        try:
            # Prefer the NetworkX-based renderer.
            concept_map_image = generate_concept_map(summary, questions)
        except Exception as e:
            print(f"Main concept map failed: {e}, using fallback")
            # Fall back to the minimal-dependency renderer.
            concept_map_image = generate_simple_concept_map(summary, questions)

        # Present the questions as a bulleted list.
        formatted_questions = "\n".join(f"- {q}" for q in questions)

        return summary, formatted_questions, concept_map_image
    except Exception as e:
        import traceback
        print(f"Error processing text: {str(e)}")
        print(traceback.format_exc())
        return f"Error processing text: {str(e)}", "", None
203
 
204
# Wire the pipeline into a Gradio UI.
iface = gr.Interface(
    fn=analyze_text_with_fallback,  # use the function with the fallback renderer
    inputs=[
        gr.Textbox(lines=10, placeholder="Enter text here..."),
        gr.Dropdown(["ar", "en"], label="Language"),
    ],
    outputs=[
        gr.Textbox(label="Summary"),
        gr.Textbox(label="Questions"),
        gr.Image(label="Concept Map"),
    ],
    examples=examples,
    title="AI Study Assistant",
    description="Enter a text in Arabic or English and the model will summarize it and generate questions and a concept map."
)

# FIX: restore the __main__ guard (present in the previous revision) so that
# importing this module does not start the server; share=True keeps the
# public URL needed when running in Colab.
if __name__ == "__main__":
    iface.launch(share=True)