import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
import matplotlib.pyplot as plt
import networkx as nx
import io
import math
from PIL import Image
import torch
import os

print("Loading models...")
# Load models once at startup
model_name = "csebuetnlp/mT5_multilingual_XLSum"
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# If you have a GPU, use it
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
model = model.to(device)
# Load the question generator once; the pipeline takes an integer device index
# (0 = first GPU, -1 = CPU)
question_generator = pipeline(
    "text2text-generation",
    model="valhalla/t5-small-e2e-qg",
    device=0 if device == "cuda" else -1
)
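
# Note: valhalla/t5-small-e2e-qg is an end-to-end question-generation checkpoint, i.e. it
# produces questions directly from a passage without pre-selected answer spans.
# A minimal sketch of calling the pipeline directly (sample text is illustrative only):
#   question_generator("Gradio makes it easy to build ML demos.", max_length=64)
#   # -> a list with one dict whose 'generated_text' field holds the generated question(s)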
def summarize_text(text, src_lang):
    # src_lang is accepted for interface compatibility; mT5_multilingual_XLSum is
    # multilingual, so no explicit language flag is passed to the model.
    inputs = tokenizer(text, return_tensors="pt", max_length=512, truncation=True).to(device)
    # Use more efficient generation parameters
    summary_ids = model.generate(
        inputs["input_ids"],
        max_length=150,
        min_length=30,
        length_penalty=2.0,
        num_beams=4,
        early_stopping=True
    )
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return summary
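
# Example (illustrative input; the exact output depends on the model):
#   summarize_text("Artificial intelligence is a branch of computer science ...", "en")
#   # -> a short abstractive summary string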
def generate_questions(summary):
    # Sample a few generations with beam search; this e2e-qg checkpoint typically
    # returns several questions joined by a "<sep>" token, so split them apart
    # (the split is a no-op if no separator is present).
    questions = []
    for _ in range(3):  # Run generation 3 times
        result = question_generator(
            summary,
            max_length=64,
            num_beams=4,
            do_sample=True,
            top_k=30,
            top_p=0.95,
            temperature=0.7
        )
        generated = result[0]['generated_text']
        questions.extend(q.strip() for q in generated.split("<sep>") if q.strip())
    # Remove duplicates while keeping order
    questions = list(dict.fromkeys(questions))
    return questions
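
# Example (illustrative):
#   generate_questions("AI is a branch of computer science that builds intelligent machines.")
#   # -> a list of question strings (the count varies with sampling and deduplication)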
def generate_concept_map(summary, questions):
    # Use NetworkX and matplotlib for rendering
    G = nx.DiGraph()

    # Add summary as central node
    summary_short = summary[:50] + "..." if len(summary) > 50 else summary
    G.add_node("summary", label=summary_short)

    # Add question nodes and edges
    for i, question in enumerate(questions):
        q_short = question[:30] + "..." if len(question) > 30 else question
        node_id = f"Q{i}"
        G.add_node(node_id, label=q_short)
        G.add_edge("summary", node_id)

    # Create the plot directly in memory
    plt.figure(figsize=(10, 8))
    pos = nx.spring_layout(G, seed=42)  # Fixed seed for consistent layout
    nx.draw(G, pos, with_labels=False, node_color='skyblue',
            node_size=1500, arrows=True, connectionstyle='arc3,rad=0.1',
            edgecolors='black', linewidths=1)

    # Add labels separately; draw_networkx_labels does not support a 'wrap' parameter
    # in all NetworkX versions, so labels are pre-truncated above instead.
    labels = nx.get_node_attributes(G, 'label')
    nx.draw_networkx_labels(G, pos, labels=labels, font_size=9,
                            font_family='sans-serif')

    # Save to memory buffer
    buf = io.BytesIO()
    plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
    buf.seek(0)
    plt.close()
    return Image.open(buf)
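
# The function returns a PIL.Image, which gr.Image can display directly; for local
# inspection it could also be saved, e.g. generate_concept_map(s, qs).save("concept_map.png").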
def analyze_text(text, lang):
    if not text.strip():
        return "Please enter some text.", "No questions generated.", None

    # Process the text
    try:
        print("Generating summary...")
        summary = summarize_text(text, lang)
        print("Generating questions...")
        questions = generate_questions(summary)
        print("Creating concept map...")
        concept_map_image = generate_concept_map(summary, questions)

        # Format questions as a list
        questions_text = "\n".join([f"- {q}" for q in questions])
        return summary, questions_text, concept_map_image
    except Exception as e:
        import traceback
        print(f"Error processing text: {str(e)}")
        print(traceback.format_exc())
        return f"Error processing text: {str(e)}", "", None
# Alternative simpler concept map function in case the above still has issues
def generate_simple_concept_map(summary, questions):
    """Fallback concept map generator with minimal dependencies"""
    plt.figure(figsize=(10, 8))

    # Create a simple radial layout
    n_questions = len(questions)

    # Draw the central node (summary)
    plt.scatter([0], [0], s=1000, color='skyblue', edgecolors='black')
    plt.text(0, 0, summary[:50] + "..." if len(summary) > 50 else summary,
             ha='center', va='center', fontsize=9)

    # Draw the question nodes in a circle around the summary
    radius = 5
    for i, question in enumerate(questions):
        angle = 2 * math.pi * i / max(n_questions, 1)
        x = radius * math.cos(angle)
        y = radius * math.sin(angle)

        # Draw node
        plt.scatter([x], [y], s=800, color='lightgreen', edgecolors='black')

        # Draw edge from summary to question
        plt.plot([0, x], [0, y], 'k-', alpha=0.6)

        # Add question text
        plt.text(x, y, question[:30] + "..." if len(question) > 30 else question,
                 ha='center', va='center', fontsize=8)

    plt.axis('equal')
    plt.axis('off')

    # Save to memory buffer
    buf = io.BytesIO()
    plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
    buf.seek(0)
    plt.close()
    return Image.open(buf)
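
# Quick local smoke test for the fallback renderer (optional, illustrative values only):
#   img = generate_simple_concept_map("Short summary", ["Q one?", "Q two?", "Q three?"])
#   img.save("fallback_map.png")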
examples = [
    ["الذكاء الاصطناعي هو فرع من علوم الكمبيوتر يهدف إلى إنشاء آلات ذكية تعمل وتتفاعل مثل البشر. بعض الأنشطة التي صممت أجهزة الكمبيوتر الذكية للقيام بها تشمل: التعرف على الصوت، التعلم، التخطيط، وحل المشاكل.", "ar"],
    ["Artificial intelligence is a branch of computer science that aims to create intelligent machines that work and react like humans. Some of the activities computers with artificial intelligence are designed for include: Speech recognition, learning, planning, and problem-solving.", "en"]
]
print("Creating Gradio interface...") | |
# Modify the analyze_text function to use the fallback concept map if needed | |
def analyze_text_with_fallback(text, lang): | |
if not text.strip(): | |
return "Please enter some text.", "No questions generated.", None | |
try: | |
print("Generating summary...") | |
summary = summarize_text(text, lang) | |
print("Generating questions...") | |
questions = generate_questions(summary) | |
print("Creating concept map...") | |
try: | |
# Try the main concept map generator first | |
concept_map_image = generate_concept_map(summary, questions) | |
except Exception as e: | |
print(f"Main concept map failed: {e}, using fallback") | |
# If it fails, use the fallback generator | |
concept_map_image = generate_simple_concept_map(summary, questions) | |
# Format questions as a list | |
questions_text = "\n".join([f"- {q}" for q in questions]) | |
return summary, questions_text, concept_map_image | |
except Exception as e: | |
import traceback | |
print(f"Error processing text: {str(e)}") | |
print(traceback.format_exc()) | |
return f"Error processing text: {str(e)}", "", None | |
iface = gr.Interface(
    fn=analyze_text_with_fallback,  # Use the function with fallback
    inputs=[gr.Textbox(lines=10, placeholder="Enter text here..."), gr.Dropdown(["ar", "en"], label="Language")],
    outputs=[gr.Textbox(label="Summary"), gr.Textbox(label="Questions"), gr.Image(label="Concept Map")],
    examples=examples,
    title="AI Study Assistant",
    description="Enter a text in Arabic or English and the model will summarize it and generate questions and a concept map."
)

# For Colab, we need to use a public URL
iface.launch(share=True)
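
# Note: share=True asks Gradio to create a temporary public link, which is useful when
# running in Colab; when this file runs as a Hugging Face Space, a plain iface.launch()
# is typically sufficient because the Space already serves a public URL.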