import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
import matplotlib.pyplot as plt
import networkx as nx
import io
import math
from PIL import Image
import torch
import os

print("Installation complete. Loading models...")

# Load models once at startup
model_name = "csebuetnlp/mT5_multilingual_XLSum"
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# If you have a GPU, use it
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
model = model.to(device)
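
# Note (optional, not part of the original script): if GPU memory is tight, the summarizer
# could be loaded in half precision instead, e.g.
#   AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.float16)
# This is only a hedged suggestion for constrained environments such as free Colab GPUs.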

# Load question generator once
question_generator = pipeline(
    "text2text-generation",
    model="valhalla/t5-small-e2e-qg",
    device=0 if device == "cuda" else -1  # pipeline expects a device index: 0 = first GPU, -1 = CPU
)

def summarize_text(text, src_lang):
    inputs = tokenizer(text, return_tensors="pt", max_length=512, truncation=True).to(device)

    # Use more efficient generation parameters
    summary_ids = model.generate(
        inputs["input_ids"],
        max_length=150,
        min_length=30,
        length_penalty=2.0,
        num_beams=4,
        early_stopping=True
    )

    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return summary
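
# The tokenizer call above truncates input at 512 tokens, so very long documents lose content
# silently. Below is a minimal, hedged sketch of a chunk-then-summarize workaround; the
# 400-token chunk size, the helper name, and the simple concatenation strategy are assumptions,
# not part of the original app.
def summarize_long_text(text, src_lang, chunk_tokens=400):
    token_ids = tokenizer(text, truncation=False)["input_ids"]
    partial_summaries = []
    for start in range(0, len(token_ids), chunk_tokens):
        # Decode each chunk back to text and reuse the existing summarizer on it
        chunk = tokenizer.decode(token_ids[start:start + chunk_tokens], skip_special_tokens=True)
        partial_summaries.append(summarize_text(chunk, src_lang))
    # Join partial summaries; for a more polished result they could be summarized once more.
    return " ".join(partial_summaries)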

def generate_questions(summary):
    # Generate questions one at a time; sampling (do_sample/top_k/top_p) lets repeated calls produce different questions
    questions = []
    for _ in range(3):  # Generate 3 questions
        result = question_generator(
            summary,
            max_length=64,
            num_beams=4,
            do_sample=True,
            top_k=30,
            top_p=0.95,
            temperature=0.7
        )
        questions.append(result[0]['generated_text'])

    # Remove duplicates
    questions = list(set(questions))
    return questions
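
# Some end-to-end question-generation checkpoints return several questions in a single pass,
# joined by a "<sep>" marker. If that holds for this model (an assumption, not verified here),
# one generation call can be split instead of sampling three times. This is a sketch, not the
# behavior the app relies on.
def generate_questions_single_pass(summary):
    result = question_generator(summary, max_length=128, num_beams=4)
    raw = result[0]['generated_text']
    # Split on the assumed separator and drop empty fragments
    return [q.strip() for q in raw.split("<sep>") if q.strip()]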

def generate_concept_map(summary, questions):
    # Use NetworkX and matplotlib for rendering
    G = nx.DiGraph()

    # Add summary as central node
    summary_short = summary[:50] + "..." if len(summary) > 50 else summary
    G.add_node("summary", label=summary_short)

    # Add question nodes and edges
    for i, question in enumerate(questions):
        q_short = question[:30] + "..." if len(question) > 30 else question
        node_id = f"Q{i}"
        G.add_node(node_id, label=q_short)
        G.add_edge("summary", node_id)

    # Create the plot directly in memory
    plt.figure(figsize=(10, 8))
    pos = nx.spring_layout(G, seed=42)  # Fixed seed for consistent layout
    nx.draw(G, pos, with_labels=False, node_color='skyblue',
            node_size=1500, arrows=True, connectionstyle='arc3,rad=0.1',
            edgecolors='black', linewidths=1)

    # Add labels with better font handling
    # FIX: Removed 'wrap' parameter which is not supported in this version of NetworkX
    labels = nx.get_node_attributes(G, 'label')
    nx.draw_networkx_labels(G, pos, labels=labels, font_size=9,
                         font_family='sans-serif')

    # Save to memory buffer
    buf = io.BytesIO()
    plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
    buf.seek(0)
    plt.close()

    return Image.open(buf)
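
# Since the 'wrap' parameter was removed above, long labels render on a single line. One way
# to keep nodes readable is to pre-wrap labels with the standard-library textwrap module before
# adding them to the graph; the 20-character width here is an arbitrary choice, and this helper
# is only a hedged sketch, not wired into generate_concept_map.
def wrap_label(text, width=20):
    import textwrap
    return "\n".join(textwrap.wrap(text, width=width))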

def analyze_text(text, lang):
    if not text.strip():
        return "Please enter some text.", "No questions generated.", None

    # Process the text
    try:
        print("Generating summary...")
        summary = summarize_text(text, lang)

        print("Generating questions...")
        questions = generate_questions(summary)

        print("Creating concept map...")
        concept_map_image = generate_concept_map(summary, questions)

        # Format questions as a list
        questions_text = "\n".join([f"- {q}" for q in questions])

        return summary, questions_text, concept_map_image
    except Exception as e:
        import traceback
        print(f"Error processing text: {str(e)}")
        print(traceback.format_exc())
        return f"Error processing text: {str(e)}", "", None

# Alternative simpler concept map function in case the above still has issues
def generate_simple_concept_map(summary, questions):
    """Fallback concept map generator with minimal dependencies"""
    plt.figure(figsize=(10, 8))

    # Create a simple radial layout
    n_questions = len(questions)

    # Draw the central node (summary)
    plt.scatter([0], [0], s=1000, color='skyblue', edgecolors='black')
    plt.text(0, 0, summary[:50] + "..." if len(summary) > 50 else summary,
             ha='center', va='center', fontsize=9)

    # Draw the question nodes in a circle around the summary
    radius = 5
    for i, question in enumerate(questions):
        # Place each question node evenly on a circle around the summary
        angle = 2 * math.pi * i / max(n_questions, 1)
        x = radius * math.cos(angle)
        y = radius * math.sin(angle)

        # Draw node
        plt.scatter([x], [y], s=800, color='lightgreen', edgecolors='black')

        # Draw edge from summary to question
        plt.plot([0, x], [0, y], 'k-', alpha=0.6)

        # Add question text
        plt.text(x, y, question[:30] + "..." if len(question) > 30 else question,
                ha='center', va='center', fontsize=8)

    plt.axis('equal')
    plt.axis('off')

    # Save to memory buffer
    buf = io.BytesIO()
    plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
    buf.seek(0)
    plt.close()

    return Image.open(buf)

examples = [
    ["الذكاء الاصطناعي هو فرع من علوم الكمبيوتر يهدف إلى إنشاء آلات ذكية تعمل وتتفاعل مثل البشر. بعض الأنشطة التي صممت أجهزة الكمبيوتر الذكية للقيام بها تشمل: التعرف على الصوت، التعلم، التخطيط، وحل المشاكل.", "ar"],
    ["Artificial intelligence is a branch of computer science that aims to create intelligent machines that work and react like humans. Some of the activities computers with artificial intelligence are designed for include: Speech recognition, learning, planning, and problem-solving.", "en"]
]

print("Creating Gradio interface...")

# Variant of analyze_text that falls back to the simple concept map if the NetworkX version fails
def analyze_text_with_fallback(text, lang):
    if not text.strip():
        return "Please enter some text.", "No questions generated.", None

    try:
        print("Generating summary...")
        summary = summarize_text(text, lang)

        print("Generating questions...")
        questions = generate_questions(summary)

        print("Creating concept map...")
        try:
            # Try the main concept map generator first
            concept_map_image = generate_concept_map(summary, questions)
        except Exception as e:
            print(f"Main concept map failed: {e}, using fallback")
            # If it fails, use the fallback generator
            concept_map_image = generate_simple_concept_map(summary, questions)

        # Format questions as a list
        questions_text = "\n".join([f"- {q}" for q in questions])

        return summary, questions_text, concept_map_image
    except Exception as e:
        import traceback
        print(f"Error processing text: {str(e)}")
        print(traceback.format_exc())
        return f"Error processing text: {str(e)}", "", None

iface = gr.Interface(
    fn=analyze_text_with_fallback,  # Use the function with fallback
    inputs=[gr.Textbox(lines=10, placeholder="Enter text here..."), gr.Dropdown(["ar", "en"], label="Language")],
    outputs=[gr.Textbox(label="Summary"), gr.Textbox(label="Questions"), gr.Image(label="Concept Map")],
    examples=examples,
    title="AI Study Assistant",
    description="Enter a text in Arabic or English and the model will summarize it and generate questions and a concept map."
)
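
# Optional, hedged suggestion (not part of the original script): model inference here can take
# several seconds, so enabling Gradio's request queue before launching may help concurrent users
# avoid timeouts, e.g.:
# iface.queue()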

# For Colab, we need to use a public URL
iface.launch(share=True)