Spaces:
Running
Running
from flask import Flask, render_template, request, jsonify | |
from werkzeug.utils import secure_filename | |
import os | |
import json | |
from textwrap import dedent | |
from crewai import Agent, Crew, Process, Task | |
from crewai_tools import SerperDevTool, PDFSearchTool | |
from crewai import LLM | |
from typing import List, Dict, Union, Optional | |
import tempfile | |
import logging | |
from datetime import datetime | |
import hashlib | |
# Process-wide logging: INFO and above, one timestamped record per line.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Upload constraints: PDF files only, written to the system temp directory.
UPLOAD_FOLDER = tempfile.gettempdir()
ALLOWED_EXTENSIONS = {'pdf'}
MAX_FILE_SIZE = 16 * 1024 * 1024  # 16MB

# Flask application; both upload settings are mirrored into its config.
app = Flask(__name__)
app.config.update(
    UPLOAD_FOLDER=UPLOAD_FOLDER,
    MAX_CONTENT_LENGTH=MAX_FILE_SIZE,
)
# Load environment variables | |
def get_env_variable(var_name: str) -> str:
    """Return the value of a required environment variable.

    Args:
        var_name: Name of the variable to look up in ``os.environ``.

    Returns:
        The variable's value (an empty string counts as set).

    Raises:
        ValueError: If the variable is not present in the environment.
    """
    if var_name in os.environ:
        return os.environ[var_name]
    raise ValueError(f"Environment variable {var_name} is not set")
try:
    # The Gemini key is exposed under both GOOGLE_API_KEY and GEMINI_API_KEY
    # (targets are assigned left to right).
    os.environ["GOOGLE_API_KEY"] = os.environ["GEMINI_API_KEY"] = get_env_variable("GEMINI_API_KEY")
    # Writing SERPER_API_KEY back to itself only verifies that it is set.
    os.environ["SERPER_API_KEY"] = get_env_variable("SERPER_API_KEY")
except ValueError as exc:
    # Log the missing variable, then abort start-up by re-raising.
    logger.error(f"Configuration error: {exc}")
    raise
# Gemini model shared by both agents of the flashcard crew.
llm = LLM(
    model="gemini/gemini-1.5-flash",
    max_tokens=8000,
    temperature=0.7,
    timeout=120,
)

# Web-search tool, created once at import time and handed to the researcher agent.
search_tool = SerperDevTool()
def create_pdf_tool(pdf_path: Optional[str] = None) -> PDFSearchTool:
    """Build a PDFSearchTool that uses Google for both the LLM and the embedder.

    Args:
        pdf_path: Path of the PDF to bind the tool to. When falsy, the tool is
            created without a ``pdf`` argument.

    Returns:
        The configured PDFSearchTool instance.
    """
    llm_settings = {
        'provider': 'google',
        'config': {'model': 'gemini-1.5-flash'},
    }
    embedder_settings = {
        'provider': 'google',
        'config': {
            'model': 'models/embedding-001',
            'task_type': 'retrieval_document',
        },
    }
    tool_kwargs = {
        'config': {'llm': llm_settings, 'embedder': embedder_settings},
    }
    # Only pass ``pdf`` when a path was actually supplied.
    if pdf_path:
        tool_kwargs['pdf'] = pdf_path
    return PDFSearchTool(**tool_kwargs)
def allowed_file(filename: str) -> bool:
    """Tell whether ``filename`` ends in one of the ALLOWED_EXTENSIONS.

    The comparison is case-insensitive and uses the text after the last dot;
    names without any dot are rejected.
    """
    _, dot, extension = filename.rpartition('.')
    return dot == '.' and extension.lower() in ALLOWED_EXTENSIONS
def generate_safe_filename(filename: str) -> str:
    """Build a storage filename from a timestamp and a hash of the original name.

    The client-supplied name is never used verbatim: only a short MD5 digest
    of it and a sanitised copy of its extension appear in the result.

    Args:
        filename: Original (untrusted) filename of the upload.

    Returns:
        ``upload_<YYYYmmdd_HHMMSS>_<10 hex chars>.<ext>``, or the same string
        without the ``.<ext>`` suffix when the name has no usable extension.
    """
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    # MD5 only derives a short, filesystem-safe tag here; it is not a security measure.
    file_hash = hashlib.md5(filename.encode()).hexdigest()[:10]
    stem = f"upload_{timestamp}_{file_hash}"

    # rpartition copes with dot-less names, where rsplit('.', 1)[1] raises IndexError.
    _, dot, raw_ext = filename.rpartition('.')
    # Keep ASCII letters and digits only, so path separators or dots hidden in
    # the "extension" can never reach the generated name.
    ext = ''.join(ch for ch in raw_ext.lower() if ch.isascii() and ch.isalnum()) if dot else ''

    # NOTE(review): the timestamp has one-second resolution and the hash depends
    # only on the name, so two uploads of the same name within the same second
    # produce the same result.
    return f"{stem}.{ext}" if ext else stem
class FlashcardGenerator:
    """Two-agent crew that researches a topic and turns the findings into flashcards.

    A researcher agent (web search, plus PDF search when a PDF path is given)
    feeds a writer agent. The writer's raw output is parsed into a list of
    dicts that each contain at least the keys ``'question'`` and ``'answer'``.
    """

    def __init__(self, topic: str, pdf_path: Optional[str] = None):
        """Store the inputs and build both agents.

        Args:
            topic: Subject the flashcards should cover.
            pdf_path: Optional path of a PDF the researcher should also analyse.
        """
        self.topic = topic
        self.pdf_path = pdf_path
        self.researcher = self._create_researcher()
        self.writer = self._create_writer()

    def _create_researcher(self) -> "Agent":
        """Create the researcher agent; a PDF tool is added only when a PDF was given."""
        tools = [search_tool]
        if self.pdf_path:
            tools.append(create_pdf_tool(self.pdf_path))
        return Agent(
            role='Chercheur de Sujets',
            goal=dedent(
                f"Trouver les informations les plus pertinentes et précises sur {self.topic}\n"
                "en utilisant l'API SerpApi et en analysant les PDFs fournis."
            ),
            backstory=dedent(
                "Un chercheur expert spécialisé dans la collecte d'informations sur divers sujets.\n"
                "Capable d'utiliser l'API SerpApi pour des recherches précises et d'analyser des documents PDF."
            ),
            tools=tools,
            llm=llm,
            verbose=True,
            allow_delegation=False,
        )

    def _create_writer(self) -> "Agent":
        """Create the writer agent (no tools; it works from the research output)."""
        return Agent(
            role='Rédacteur de Flashcards',
            goal=dedent(
                "Créer des flashcards claires et concises en format question-réponse\n"
                "basées sur les informations fournies par le Chercheur."
            ),
            backstory=dedent(
                "Un expert en pédagogie et en création de matériel d'apprentissage.\n"
                "Capable de transformer des informations complexes en flashcards simples et mémorisables."
            ),
            llm=llm,
            verbose=True,
            allow_delegation=False,
        )

    def create_research_task(self) -> "Task":
        """Create the research task, mentioning the PDF path when one was given."""
        description = f"""Effectuer une recherche approfondie sur le sujet '{self.topic}'."""
        if self.pdf_path:
            description += f" Analyser également le contenu du PDF fourni: {self.pdf_path}"
        return Task(
            description=dedent(description),
            expected_output="Une liste d'informations pertinentes sur le sujet.",
            agent=self.researcher,
        )

    def create_flashcard_task(self, research_task: "Task") -> "Task":
        """Create the flashcard-writing task, using the research task as its context.

        Args:
            research_task: Task whose output the writer should build on.
        """
        # NOTE(review): the prompt never names the 'question' / 'answer' keys
        # that extract_json_from_result requires; the model has to pick them itself.
        return Task(
            description=dedent(
                "Transformer les informations fournies par le Chercheur\n"
                "en une série de flashcards au format JSON. je veux une vingtaine de flashcard très robuste et difficile . Chaque flashcard doit avoir une question\n"
                "d'un côté et une réponse concise de l'autre. Les réponses doivent être claires et informatives."
            ),
            expected_output="Une liste de flashcards au format JSON.",
            agent=self.writer,
            context=[research_task],
        )

    def generate(self) -> List[Dict[str, str]]:
        """Run research then writing sequentially and return the parsed flashcards.

        Raises:
            ValueError: If the writer's output contains no valid flashcard JSON.
        """
        research_task = self.create_research_task()
        flashcard_task = self.create_flashcard_task(research_task)
        crew = Crew(
            agents=[self.researcher, self.writer],
            tasks=[research_task, flashcard_task],
            process=Process.sequential,
            verbose=True,
        )
        result = crew.kickoff()
        # The flashcard task is the last one, so its raw output is what gets parsed.
        return self.extract_json_from_result(result.tasks_output[-1].raw)

    # Declared static: the parser needs no instance state, and without the
    # decorator the ``self.extract_json_from_result(...)`` call in generate()
    # would pass the instance as ``result_text`` and fail with TypeError.
    @staticmethod
    def extract_json_from_result(result_text: str) -> List[Dict[str, str]]:
        """Extract and validate the flashcard list embedded in ``result_text``.

        Everything from the first ``[`` to the last ``]`` is parsed as JSON, so
        surrounding prose or code fences in the model output are tolerated.

        Args:
            result_text: Raw text produced by the flashcard task.

        Returns:
            The parsed list; every item is a dict holding at least
            ``'question'`` and ``'answer'``.

        Raises:
            ValueError: If no bracketed span is found, the span is not valid
                JSON, or an item lacks the required keys.
        """
        try:
            json_start = result_text.find('[')
            json_end = result_text.rfind(']') + 1
            if json_start == -1 or json_end == 0:
                raise ValueError("JSON non trouvé dans le résultat")
            flashcards = json.loads(result_text[json_start:json_end])
            for card in flashcards:
                if not isinstance(card, dict) or 'question' not in card or 'answer' not in card:
                    raise ValueError("Format de flashcard invalide")
            return flashcards
        except ValueError as e:  # json.JSONDecodeError is a ValueError subclass
            logger.error(f"Error extracting JSON: {str(e)}")
            raise ValueError(f"Erreur lors de l'extraction du JSON : {str(e)}") from e
# NOTE(review): no route registration for this view is visible in the file, so
# it is bound to the site root here; confirm '/' is the intended URL.
@app.route('/')
def index():
    """Render the main page from the ``index.html`` template."""
    return render_template('index.html')
# NOTE(review): no route registration for this view is visible in the file. The
# POST method follows from the form/file access below, but the '/generate' path
# is an assumption; confirm it against the front end.
@app.route('/generate', methods=['POST'])
def generate_flashcards():
    """Handle a flashcard generation request.

    Reads the required ``topic`` field and an optional PDF ``file`` from the
    submitted form, runs FlashcardGenerator, and returns the result as JSON.

    Returns:
        ``{'success': True, 'flashcards': [...]}`` on success, or
        ``{'error': <message>}`` with status 400 for invalid input and 500
        when saving the upload or generating the flashcards fails.
    """
    # Request parsing and validation stay outside the try block so that HTTP
    # errors raised while reading the request (e.g. a body larger than
    # MAX_CONTENT_LENGTH) reach the app's error handlers instead of being
    # converted into a generic 500 below.
    topic = (request.form.get('topic') or '').strip()
    if not topic:
        return jsonify({'error': 'Veuillez entrer un sujet.'}), 400

    upload = request.files.get('file')
    has_upload = bool(upload and upload.filename)
    if has_upload:
        if not allowed_file(upload.filename):
            return jsonify({'error': 'Format de fichier non supporté. Veuillez utiliser un PDF.'}), 400
        # Only effective when the file part reports its own length; the
        # request-wide limit is the MAX_CONTENT_LENGTH set on app.config.
        if upload.content_length and upload.content_length > MAX_FILE_SIZE:
            return jsonify({'error': 'Le fichier est trop volumineux. Maximum 16MB.'}), 400

    pdf_path = None
    try:
        if has_upload:
            # Store under a generated name; pdf_path is set before saving so a
            # partially written file is still removed in the finally clause.
            safe_filename = generate_safe_filename(upload.filename)
            pdf_path = os.path.join(app.config['UPLOAD_FOLDER'], safe_filename)
            upload.save(pdf_path)
            logger.info(f"File saved: {pdf_path}")

        generator = FlashcardGenerator(topic, pdf_path)
        flashcards = generator.generate()
        return jsonify({
            'success': True,
            'flashcards': flashcards
        })
    except Exception as e:
        # Top-level boundary: log with traceback and report the failure as JSON.
        logger.exception(f"Error generating flashcards: {str(e)}")
        return jsonify({'error': str(e)}), 500
    finally:
        # Remove the temporary PDF; a cleanup failure must not replace the response.
        if pdf_path and os.path.exists(pdf_path):
            try:
                os.remove(pdf_path)
                logger.info(f"File cleaned up: {pdf_path}")
            except OSError as cleanup_error:
                logger.warning(f"Could not remove {pdf_path}: {cleanup_error}")
# Registered for status 413 so that oversized requests get this JSON body; the
# file shows no other registration for this handler.
@app.errorhandler(413)
def request_entity_too_large(error):
    """Return a JSON error with status 413 when the upload size limit is exceeded.

    Args:
        error: The exception passed in by Flask; it is not inspected.
    """
    return jsonify({'error': 'Le fichier est trop volumineux. Maximum 16MB.'}), 413
if __name__ == '__main__':
    # Debug mode enables the interactive debugger, which can execute arbitrary
    # code, so it is off unless FLASK_DEBUG is explicitly set to a truthy value.
    debug_enabled = os.environ.get('FLASK_DEBUG', '').strip().lower() in {'1', 'true', 'yes', 'on'}
    app.run(debug=debug_enabled)