"""AI Study Assistant: summarize Arabic/English text, generate study questions,
and draw a concept map, served through a Gradio interface."""

import io
import math
import os
import re
import traceback

import gradio as gr
import matplotlib.pyplot as plt
import networkx as nx
import torch
from PIL import Image
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline

print("Installation complete. Loading models...")

# Load models once at startup so every request reuses them.
model_name = "csebuetnlp/mT5_multilingual_XLSum"
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Use the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
model = model.to(device)

# Load the question generator once. The pipeline takes a device index:
# 0 = first GPU, -1 = CPU.
question_generator = pipeline(
    "text2text-generation",
    model="valhalla/t5-small-e2e-qg",
    device=0 if device == "cuda" else -1,
)

# The end-to-end QG model is trained on inputs carrying this task prefix and
# emits all of its questions in one string separated by "<sep>".
QG_PREFIX = "generate questions: "
QG_SEPARATOR = "<sep>"


def _truncate(text, limit):
    """Return *text* cut to *limit* characters, with "..." appended if cut."""
    return text[:limit] + "..." if len(text) > limit else text


def _normalize_whitespace(text):
    """Collapse newlines and runs of whitespace into single spaces."""
    return re.sub(r"\s+", " ", text).strip()


def summarize_text(text, src_lang):
    """Summarize *text* with the multilingual mT5 XLSum model.

    Args:
        text: Input document (Arabic, English, or another XLSum language).
        src_lang: Language code selected in the UI. Currently unused because
            the model is multilingual; kept for interface compatibility.

    Returns:
        The decoded summary string.
    """
    # The model card normalizes whitespace before tokenizing; raw textbox
    # input often contains newlines.
    cleaned = _normalize_whitespace(text)
    inputs = tokenizer(
        cleaned, return_tensors="pt", max_length=512, truncation=True
    ).to(device)
    summary_ids = model.generate(
        **inputs,  # pass the attention mask along with input_ids
        max_length=150,
        min_length=30,
        length_penalty=2.0,
        num_beams=4,
        no_repeat_ngram_size=2,  # avoids the repeated phrases mT5 tends to emit
        early_stopping=True,
    )
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)


def _split_generated_questions(generated):
    """Split one e2e-QG output string into individual question strings."""
    parts = generated.split(QG_SEPARATOR)
    # The model ends its output with a trailing separator, so the last
    # segment is either empty or a fragment cut off by max_length.
    if len(parts) > 1:
        parts = parts[:-1]
    return [part.strip() for part in parts if part.strip()]


def generate_questions(summary, max_questions=3, num_passes=3):
    """Generate up to *max_questions* unique questions about *summary*.

    Sampling passes are repeated (up to *num_passes*) to collect varied
    questions. Duplicates are removed while keeping first-seen order.

    Returns:
        A list of question strings (possibly empty).
    """
    questions = []
    seen = set()
    for _ in range(num_passes):
        result = question_generator(
            QG_PREFIX + summary,
            max_length=64,
            num_beams=4,
            do_sample=True,
            top_k=30,
            top_p=0.95,
            temperature=0.7,
        )
        for question in _split_generated_questions(result[0]["generated_text"]):
            if question not in seen:
                seen.add(question)
                questions.append(question)
        if len(questions) >= max_questions:
            break
    return questions[:max_questions]


def _figure_to_image(fig):
    """Render *fig* to PNG in memory and return it as a fully-loaded PIL image."""
    buf = io.BytesIO()
    fig.savefig(buf, format="png", dpi=100, bbox_inches="tight")
    buf.seek(0)
    image = Image.open(buf)
    image.load()  # decode now rather than lazily after this function returns
    return image


# NOTE: matplotlib applies no Arabic glyph shaping or right-to-left
# reordering, so Arabic labels may render as disconnected letters in both
# concept-map functions below.


def generate_concept_map(summary, questions):
    """Draw the summary as a hub node linked to one node per question.

    Uses NetworkX for the layout and matplotlib for rendering.

    Returns:
        A PIL image of the rendered graph.
    """
    graph = nx.DiGraph()
    graph.add_node("summary", label=_truncate(summary, 50))
    for i, question in enumerate(questions):
        node_id = f"Q{i}"
        graph.add_node(node_id, label=_truncate(question, 30))
        graph.add_edge("summary", node_id)

    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        pos = nx.spring_layout(graph, seed=42)  # fixed seed -> stable layout
        nx.draw(
            graph,
            pos,
            ax=ax,
            with_labels=False,
            node_color="skyblue",
            node_size=1500,
            arrows=True,
            connectionstyle="arc3,rad=0.1",
            edgecolors="black",
            linewidths=1,
        )
        # Labels are drawn separately so the truncated 'label' attribute is
        # shown instead of node ids. No 'wrap' argument is passed because
        # some NetworkX versions do not accept it.
        nx.draw_networkx_labels(
            graph,
            pos,
            labels=nx.get_node_attributes(graph, "label"),
            font_size=9,
            font_family="sans-serif",
            ax=ax,
        )
        return _figure_to_image(fig)
    finally:
        plt.close(fig)  # release the figure even if drawing fails


def generate_simple_concept_map(summary, questions):
    """Fallback concept map drawn with plain matplotlib (no NetworkX).

    Places the summary at the origin and the questions evenly on a circle
    around it.

    Returns:
        A PIL image of the rendered map.
    """
    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        ax.scatter([0], [0], s=1000, color="skyblue", edgecolors="black", zorder=2)
        ax.text(0, 0, _truncate(summary, 50), ha="center", va="center",
                fontsize=9, zorder=3)

        radius = 5
        count = max(len(questions), 1)
        for i, question in enumerate(questions):
            angle = 2 * math.pi * i / count
            x = radius * math.cos(angle)
            y = radius * math.sin(angle)
            ax.plot([0, x], [0, y], "k-", alpha=0.6, zorder=1)
            ax.scatter([x], [y], s=800, color="lightgreen", edgecolors="black",
                       zorder=2)
            ax.text(x, y, _truncate(question, 30), ha="center", va="center",
                    fontsize=8, zorder=3)

        ax.axis("equal")
        ax.axis("off")
        return _figure_to_image(fig)
    finally:
        plt.close(fig)


def _analyze(text, lang, use_fallback):
    """Run summary -> questions -> concept map and format the UI outputs.

    Returns:
        (summary, questions_text, image). On error the first element holds
        the error message and the image is None.
    """
    if not text or not text.strip():
        return "Please enter some text.", "No questions generated.", None

    try:
        print("Generating summary...")
        summary = summarize_text(text, lang)

        print("Generating questions...")
        questions = generate_questions(summary)

        print("Creating concept map...")
        if use_fallback:
            try:
                concept_map_image = generate_concept_map(summary, questions)
            except Exception as e:
                print(f"Main concept map failed: {e}, using fallback")
                concept_map_image = generate_simple_concept_map(summary, questions)
        else:
            concept_map_image = generate_concept_map(summary, questions)

        questions_text = "\n".join(f"- {q}" for q in questions)
        return (
            summary,
            questions_text or "No questions generated.",
            concept_map_image,
        )
    except Exception as e:
        # Top-level UI boundary: report the error instead of crashing the app.
        print(f"Error processing text: {e}")
        print(traceback.format_exc())
        return f"Error processing text: {e}", "", None


def analyze_text(text, lang):
    """Analyze *text* using only the NetworkX concept map (no fallback)."""
    return _analyze(text, lang, use_fallback=False)


def analyze_text_with_fallback(text, lang):
    """Analyze *text*; fall back to the simple concept map if NetworkX fails."""
    return _analyze(text, lang, use_fallback=True)


examples = [
    ["الذكاء الاصطناعي هو فرع من علوم الكمبيوتر يهدف إلى إنشاء آلات ذكية تعمل وتتفاعل مثل البشر. بعض الأنشطة التي صممت أجهزة الكمبيوتر الذكية للقيام بها تشمل: التعرف على الصوت، التعلم، التخطيط، وحل المشاكل.", "ar"],
    ["Artificial intelligence is a branch of computer science that aims to create intelligent machines that work and react like humans. Some of the activities computers with artificial intelligence are designed for include: Speech recognition, learning, planning, and problem-solving.", "en"],
]

print("Creating Gradio interface...")

iface = gr.Interface(
    fn=analyze_text_with_fallback,
    inputs=[
        gr.Textbox(lines=10, placeholder="Enter text here..."),
        gr.Dropdown(["ar", "en"], label="Language"),
    ],
    outputs=[
        gr.Textbox(label="Summary"),
        gr.Textbox(label="Questions"),
        gr.Image(label="Concept Map"),
    ],
    examples=examples,
    title="AI Study Assistant",
    description="Enter a text in Arabic or English and the model will summarize it and generate questions and a concept map.",
)

if __name__ == "__main__":
    # share=True exposes a public URL, which is needed when running in Colab.
    iface.launch(share=True)