from flask import Flask, request, jsonify, render_template_string, send_from_directory
from werkzeug.utils import secure_filename
from werkzeug.security import generate_password_hash, check_password_hash
import pytesseract
from PIL import Image
import numpy as np
import faiss
import os
import pickle
from pdf2image import convert_from_bytes
import torch
import clip
import io
import json
import uuid
from datetime import datetime, timedelta
import jwt
import sqlite3
import tempfile

# Security configuration
# The signing key is defined exactly once and shared by Flask and the JWT
# helpers so the two values cannot drift apart. Set the SECRET_KEY environment
# variable in production; the literal below is only the development fallback.
SECRET_KEY = os.environ.get("SECRET_KEY", "your-secret-key-change-this-in-production")
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30

app = Flask(__name__)
app.config['SECRET_KEY'] = SECRET_KEY

# Set CLIP cache to writable directory
os.environ['CLIP_CACHE'] = '/app/clip_cache'
os.makedirs('/app/clip_cache', exist_ok=True)

# Directories
INDEX_PATH = "data/index.faiss"
LABELS_PATH = "data/labels.pkl"
DATABASE_PATH = "data/documents.db"
UPLOADS_DIR = "data/uploads"

os.makedirs("data", exist_ok=True)
os.makedirs("static", exist_ok=True)
os.makedirs(UPLOADS_DIR, exist_ok=True)


# Initialize database
def init_db():
    """Create the SQLite schema and seed the default admin account.

    Creates the ``users`` and ``documents`` tables in ``DATABASE_PATH`` if
    they do not exist yet, then inserts an ``admin`` user when no row with
    that username is present. Safe to call on every start-up.

    The connection is closed even when a statement raises, so a failed
    initialisation does not leave the database file handle open.
    """
    conn = sqlite3.connect(DATABASE_PATH)
    try:
        cursor = conn.cursor()

        # Users table
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS users (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                username TEXT UNIQUE NOT NULL,
                password_hash TEXT NOT NULL,
                is_active BOOLEAN DEFAULT TRUE
            )
        ''')

        # Documents table
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS documents (
                id TEXT PRIMARY KEY,
                filename TEXT NOT NULL,
                original_filename TEXT NOT NULL,
                category TEXT NOT NULL,
                similarity REAL NOT NULL,
                ocr_text TEXT,
                upload_date TEXT NOT NULL,
                file_path TEXT NOT NULL
            )
        ''')

        # Insert default admin user if not exists
        # NOTE(review): the seeded password is a well-known default; it should
        # be rotated (or supplied via configuration) before any deployment.
        cursor.execute('SELECT * FROM users WHERE username = ?', ('admin',))
        if not cursor.fetchone():
            admin_hash = generate_password_hash('admin123')
            cursor.execute(
                'INSERT INTO users (username, password_hash) VALUES (?, ?)',
                ('admin', admin_hash),
            )

        conn.commit()
    finally:
        conn.close()


init_db()
Initialize index and labels index = faiss.IndexFlatL2(512) labels = [] if os.path.exists(INDEX_PATH) and os.path.exists(LABELS_PATH): try: index = faiss.read_index(INDEX_PATH) with open(LABELS_PATH, "rb") as f: labels = pickle.load(f) print(f"✅ Loaded existing index with {len(labels)} labels") except Exception as e: print(f"⚠️ Failed to load existing index: {e}") if os.path.exists(INDEX_PATH): os.remove(INDEX_PATH) if os.path.exists(LABELS_PATH): os.remove(LABELS_PATH) # Initialize CLIP model with custom cache device = "cuda" if torch.cuda.is_available() else "cpu" try: clip_model, preprocess = clip.load("ViT-B/32", device=device, download_root='/app/clip_cache') print("✅ CLIP model loaded successfully") except Exception as e: print(f"❌ Failed to load CLIP model: {e}") # Fallback initialization clip_model = None preprocess = None # Helper functions def save_index(): try: faiss.write_index(index, INDEX_PATH) with open(LABELS_PATH, "wb") as f: pickle.dump(labels, f) except Exception as e: print(f"❌ Failed to save index: {e}") def authenticate_user(username: str, password: str): conn = sqlite3.connect(DATABASE_PATH) cursor = conn.cursor() cursor.execute('SELECT password_hash FROM users WHERE username = ? 
AND is_active = TRUE', (username,)) result = cursor.fetchone() conn.close() if result and check_password_hash(result[0], password): return {"username": username} return None def create_access_token(data: dict): expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) to_encode = data.copy() to_encode.update({"exp": expire}) return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) def verify_token(token: str): try: payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) username = payload.get("sub") return username if username else None except jwt.PyJWTError: return None def image_from_pdf(pdf_bytes): try: images = convert_from_bytes(pdf_bytes, dpi=200) return images[0] except Exception as e: print(f"❌ PDF conversion error: {e}") return None def extract_text(image): try: if image.mode != 'RGB': image = image.convert('RGB') custom_config = r'--oem 3 --psm 6' text = pytesseract.image_to_string(image, config=custom_config) return text.strip() if text.strip() else "❓ No text detected" except Exception as e: return f"❌ OCR error: {str(e)}" def get_clip_embedding(image): try: if clip_model is None: return None if image.mode != 'RGB': image = image.convert('RGB') image_input = preprocess(image).unsqueeze(0).to(device) with torch.no_grad(): image_features = clip_model.encode_image(image_input) image_features = image_features / image_features.norm(dim=-1, keepdim=True) return image_features.cpu().numpy()[0] except Exception as e: print(f"❌ CLIP embedding error: {e}") return None def save_uploaded_file(file_content: bytes, filename: str) -> str: file_id = str(uuid.uuid4()) file_extension = os.path.splitext(filename)[1] saved_filename = f"{file_id}{file_extension}" file_path = os.path.join(UPLOADS_DIR, saved_filename) with open(file_path, 'wb') as f: f.write(file_content) return saved_filename # Routes @app.route("/") def dashboard(): return render_template_string('''