Spaces: Build error
```python
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
import matplotlib.pyplot as plt
import networkx as nx
import io
import math
from PIL import Image
import torch
import os

print("Loading models...")
# Load models once at startup
model_name = "csebuetnlp/mT5_multilingual_XLSum"
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)  # slow tokenizer: requires sentencepiece
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# If a GPU is available, use it
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
model = model.to(device)

# Load the question generator once; pipeline() takes a device index (0 = first GPU, -1 = CPU)
question_generator = pipeline(
    "text2text-generation",
    model="valhalla/t5-small-e2e-qg",
    device=0 if device == "cuda" else -1
)
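# A hedged smoke test, assuming the models above loaded correctly; the sample
# sentence is made up. Uncomment to verify generation works end to end before
# wiring up the UI.
# demo = question_generator("Artificial intelligence lets machines learn from data.")
# print(demo[0]["generated_text"])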
def summarize_text(text, src_lang):
    # src_lang is currently unused: mT5_multilingual_XLSum is multilingual,
    # but the parameter is kept so the interface signature stays stable.
    inputs = tokenizer(text, return_tensors="pt", max_length=512, truncation=True).to(device)
    # Use more efficient generation parameters
    summary_ids = model.generate(
        **inputs,  # passes input_ids and attention_mask
        max_length=150,
        min_length=30,
        length_penalty=2.0,
        num_beams=4,
        early_stopping=True
    )
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return summary
def generate_questions(summary):
    # Sample the generator a few times for diversity; the e2e-qg checkpoints
    # emit several questions in one string, separated by "<sep>".
    questions = []
    for _ in range(3):
        result = question_generator(
            summary,
            max_length=64,
            num_beams=4,
            do_sample=True,
            top_k=30,
            top_p=0.95,
            temperature=0.7
        )
        generated = result[0]['generated_text']
        questions.extend(q.strip() for q in generated.split("<sep>") if q.strip())
    # Remove duplicates while preserving order
    return list(dict.fromkeys(questions))
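# Hedged usage sketch (hypothetical summary string); output varies between
# runs because generation is sampled.
# print(generate_questions("Artificial intelligence is a branch of computer science."))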
def generate_concept_map(summary, questions):
    # Use NetworkX and matplotlib for rendering
    G = nx.DiGraph()

    # Add the summary as the central node
    summary_short = summary[:50] + "..." if len(summary) > 50 else summary
    G.add_node("summary", label=summary_short)

    # Add question nodes and edges
    for i, question in enumerate(questions):
        q_short = question[:30] + "..." if len(question) > 30 else question
        node_id = f"Q{i}"
        G.add_node(node_id, label=q_short)
        G.add_edge("summary", node_id)

    # Create the plot directly in memory
    plt.figure(figsize=(10, 8))
    pos = nx.spring_layout(G, seed=42)  # fixed seed for a consistent layout
    nx.draw(G, pos, with_labels=False, node_color='skyblue',
            node_size=1500, arrows=True, connectionstyle='arc3,rad=0.1',
            edgecolors='black', linewidths=1)

    # Add labels with better font handling
    # FIX: removed the 'wrap' parameter, which is not supported in this version of NetworkX
    labels = nx.get_node_attributes(G, 'label')
    nx.draw_networkx_labels(G, pos, labels=labels, font_size=9,
                            font_family='sans-serif')

    # Save to an in-memory buffer
    buf = io.BytesIO()
    plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
    buf.seek(0)
    plt.close()
    return Image.open(buf)
def analyze_text(text, lang):
    if not text.strip():
        return "Please enter some text.", "No questions generated.", None

    # Process the text
    try:
        print("Generating summary...")
        summary = summarize_text(text, lang)
        print("Generating questions...")
        questions = generate_questions(summary)
        print("Creating concept map...")
        concept_map_image = generate_concept_map(summary, questions)

        # Format the questions as a list
        questions_text = "\n".join(f"- {q}" for q in questions)
        return summary, questions_text, concept_map_image
    except Exception as e:
        import traceback
        print(f"Error processing text: {e}")
        print(traceback.format_exc())
        return f"Error processing text: {e}", "", None
# Alternative, simpler concept map function in case the NetworkX version still has issues
def generate_simple_concept_map(summary, questions):
    """Fallback concept map generator with minimal dependencies."""
    plt.figure(figsize=(10, 8))

    # Create a simple radial layout
    n_questions = len(questions)

    # Draw the central node (summary)
    plt.scatter([0], [0], s=1000, color='skyblue', edgecolors='black')
    plt.text(0, 0, summary[:50] + "..." if len(summary) > 50 else summary,
             ha='center', va='center', fontsize=9)

    # Draw the question nodes in a circle around the summary
    radius = 5
    for i, question in enumerate(questions):
        angle = 2 * math.pi * i / max(n_questions, 1)
        x = radius * math.cos(angle)
        y = radius * math.sin(angle)

        # Draw the node
        plt.scatter([x], [y], s=800, color='lightgreen', edgecolors='black')
        # Draw an edge from the summary to the question
        plt.plot([0, x], [0, y], 'k-', alpha=0.6)
        # Add the question text
        plt.text(x, y, question[:30] + "..." if len(question) > 30 else question,
                 ha='center', va='center', fontsize=8)

    plt.axis('equal')
    plt.axis('off')

    # Save to an in-memory buffer
    buf = io.BytesIO()
    plt.savefig(buf, format='png', dpi=100, bbox_inches='tight')
    buf.seek(0)
    plt.close()
    return Image.open(buf)
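# Hedged example of exercising the fallback renderer in isolation
# (hypothetical inputs; writes a PNG next to the script). Uncomment to test.
# img = generate_simple_concept_map("AI is a branch of computer science.",
#                                   ["What is AI?", "What can AI systems do?"])
# img.save("concept_map.png")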
examples = [
    ["الذكاء الاصطناعي هو فرع من علوم الكمبيوتر يهدف إلى إنشاء آلات ذكية تعمل وتتفاعل مثل البشر. بعض الأنشطة التي صممت أجهزة الكمبيوتر الذكية للقيام بها تشمل: التعرف على الصوت، التعلم، التخطيط، وحل المشاكل.", "ar"],
    ["Artificial intelligence is a branch of computer science that aims to create intelligent machines that work and react like humans. Some of the activities computers with artificial intelligence are designed for include: Speech recognition, learning, planning, and problem-solving.", "en"]
]
| print("Creating Gradio interface...") | |
| # Modify the analyze_text function to use the fallback concept map if needed | |
| def analyze_text_with_fallback(text, lang): | |
| if not text.strip(): | |
| return "Please enter some text.", "No questions generated.", None | |
| try: | |
| print("Generating summary...") | |
| summary = summarize_text(text, lang) | |
| print("Generating questions...") | |
| questions = generate_questions(summary) | |
| print("Creating concept map...") | |
| try: | |
| # Try the main concept map generator first | |
| concept_map_image = generate_concept_map(summary, questions) | |
| except Exception as e: | |
| print(f"Main concept map failed: {e}, using fallback") | |
| # If it fails, use the fallback generator | |
| concept_map_image = generate_simple_concept_map(summary, questions) | |
| # Format questions as a list | |
| questions_text = "\n".join([f"- {q}" for q in questions]) | |
| return summary, questions_text, concept_map_image | |
| except Exception as e: | |
| import traceback | |
| print(f"Error processing text: {str(e)}") | |
| print(traceback.format_exc()) | |
| return f"Error processing text: {str(e)}", "", None | |
iface = gr.Interface(
    fn=analyze_text_with_fallback,  # use the function with the fallback concept map
    inputs=[
        gr.Textbox(lines=10, placeholder="Enter text here..."),
        gr.Dropdown(["ar", "en"], label="Language")
    ],
    outputs=[
        gr.Textbox(label="Summary"),
        gr.Textbox(label="Questions"),
        gr.Image(label="Concept Map")
    ],
    examples=examples,
    title="AI Study Assistant",
    description="Enter a text in Arabic or English and the model will summarize it, generate questions, and draw a concept map."
)

# share=True is only needed on Colab; on Spaces, a plain launch() is enough.
iface.launch()
```
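Since the Space reports a build error, it is also worth checking that a `requirements.txt` sits next to `app.py`. A minimal sketch based on the imports above (unpinned and untested; `sentencepiece` and `protobuf` are assumptions, since `use_fast=False` loads the slow mT5 tokenizer, which needs them in some transformers versions):

```text
gradio
transformers
torch
networkx
matplotlib
Pillow
sentencepiece
protobuf
```

If the build still fails after this, the Space's build logs usually name the missing package.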