"""Flask service that classifies uploaded documents with CLIP + FAISS and stores OCR text."""

# Import order kept as in the original file: faiss/torch can be sensitive to
# which one initialises its native (OpenMP) runtime first.
from flask import Flask, request, jsonify, render_template_string, send_from_directory
from werkzeug.utils import secure_filename
from werkzeug.security import generate_password_hash, check_password_hash
import pytesseract
from PIL import Image
import numpy as np
import faiss
import os
import pickle
from pdf2image import convert_from_bytes
import torch
import clip
import io
import json
import uuid
from datetime import datetime, timedelta
import jwt
import sqlite3
import tempfile

# Security configuration.
# NOTE(security): the fallback values below are development placeholders only;
# set SECRET_KEY and ADMIN_PASSWORD in the environment for any real deployment.
SECRET_KEY = os.environ.get("SECRET_KEY", "your-secret-key-change-this-in-production")
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
DEFAULT_ADMIN_PASSWORD = os.environ.get("ADMIN_PASSWORD", "admin123")

app = Flask(__name__)
# Flask sessions and the JWT helpers share one key so the two cannot drift apart.
app.config['SECRET_KEY'] = SECRET_KEY

# Set CLIP cache to writable directory.
# NOTE(review): clip.load() is also given download_root explicitly; confirm
# whether anything actually reads the CLIP_CACHE environment variable.
os.environ['CLIP_CACHE'] = '/app/clip_cache'
os.makedirs('/app/clip_cache', exist_ok=True)

# Directories
INDEX_PATH = "data/index.faiss"
LABELS_PATH = "data/labels.pkl"
DATABASE_PATH = "data/documents.db"
UPLOADS_DIR = "data/uploads"

os.makedirs("data", exist_ok=True)
os.makedirs("static", exist_ok=True)
os.makedirs(UPLOADS_DIR, exist_ok=True)


def init_db():
    """Create the ``users`` and ``documents`` tables if missing and seed an admin.

    The ``admin`` account is inserted only when no user with that name exists,
    using ``DEFAULT_ADMIN_PASSWORD`` (env ``ADMIN_PASSWORD``, default
    ``admin123``). The connection is always closed, even if a statement fails.
    """
    conn = sqlite3.connect(DATABASE_PATH)
    try:
        cursor = conn.cursor()

        # Users table
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS users (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                username TEXT UNIQUE NOT NULL,
                password_hash TEXT NOT NULL,
                is_active BOOLEAN DEFAULT TRUE
            )
        ''')

        # Documents table
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS documents (
                id TEXT PRIMARY KEY,
                filename TEXT NOT NULL,
                original_filename TEXT NOT NULL,
                category TEXT NOT NULL,
                similarity REAL NOT NULL,
                ocr_text TEXT,
                upload_date TEXT NOT NULL,
                file_path TEXT NOT NULL
            )
        ''')

        # Insert default admin user if not exists
        cursor.execute('SELECT 1 FROM users WHERE username = ?', ('admin',))
        if cursor.fetchone() is None:
            admin_hash = generate_password_hash(DEFAULT_ADMIN_PASSWORD)
            cursor.execute(
                'INSERT INTO users (username, password_hash) VALUES (?, ?)',
                ('admin', admin_hash),
            )

        conn.commit()
    finally:
        conn.close()


init_db()
# Initialize index and labels.
# `labels[i]` is the category name of the i-th vector stored in `index`; the two
# are persisted in separate files and must always have the same length.
# NOTE(review): the dev server handles requests in threads; index.add() and
# labels.append() are not guarded by a lock, so concurrent uploads could race.
index = faiss.IndexFlatL2(512)  # 512 = embedding width of CLIP ViT-B/32
labels = []
if os.path.exists(INDEX_PATH) and os.path.exists(LABELS_PATH):
    try:
        loaded_index = faiss.read_index(INDEX_PATH)
        # NOTE(security): pickle.load can execute arbitrary code from a crafted
        # file. LABELS_PATH is only written by save_index(), but anyone able to
        # write into data/ could run code here; consider storing JSON instead.
        with open(LABELS_PATH, "rb") as f:
            loaded_labels = pickle.load(f)
        if loaded_index.d != index.d or loaded_index.ntotal != len(loaded_labels):
            raise ValueError(
                f"index/labels mismatch: {loaded_index.ntotal} vectors of dim "
                f"{loaded_index.d} vs {len(loaded_labels)} labels"
            )
        # Commit both together so a half-failed load can never leave a
        # populated index paired with an empty label list.
        index, labels = loaded_index, loaded_labels
        print(f"✅ Loaded existing index with {len(labels)} labels")
    except Exception as e:
        print(f"⚠️ Failed to load existing index: {e}")
        if os.path.exists(INDEX_PATH):
            os.remove(INDEX_PATH)
        if os.path.exists(LABELS_PATH):
            os.remove(LABELS_PATH)

# Initialize CLIP model with custom cache
device = "cuda" if torch.cuda.is_available() else "cpu"
try:
    clip_model, preprocess = clip.load("ViT-B/32", device=device, download_root='/app/clip_cache')
    print("✅ CLIP model loaded successfully")
except Exception as e:
    print(f"❌ Failed to load CLIP model: {e}")
    # Fallback: get_clip_embedding() returns None, so the endpoints answer 400.
    clip_model = None
    preprocess = None

# Minimum score (as computed in classify_document) before a document is saved.
CONFIDENCE_THRESHOLD = 0.35


# Helper functions
def save_index():
    """Persist the FAISS index and its parallel label list (best effort; errors are printed)."""
    try:
        faiss.write_index(index, INDEX_PATH)
        with open(LABELS_PATH, "wb") as f:
            pickle.dump(labels, f)
    except Exception as e:
        print(f"❌ Failed to save index: {e}")


def authenticate_user(username: str, password: str):
    """Return ``{"username": ...}`` if the credentials match an active user, else ``None``.

    Missing/empty username or password is rejected up front instead of
    crashing inside ``check_password_hash``.
    """
    if not username or not password:
        return None
    conn = sqlite3.connect(DATABASE_PATH)
    try:
        cursor = conn.cursor()
        cursor.execute(
            'SELECT password_hash FROM users WHERE username = ? AND is_active = TRUE',
            (username,),
        )
        result = cursor.fetchone()
    finally:
        conn.close()
    if result and check_password_hash(result[0], password):
        return {"username": username}
    return None


def create_access_token(data: dict):
    """Return a signed JWT carrying ``data`` plus an ``exp`` claim ACCESS_TOKEN_EXPIRE_MINUTES ahead."""
    expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    to_encode = data.copy()
    to_encode.update({"exp": expire})
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)


def verify_token(token: str):
    """Return the ``sub`` claim of a valid token, or ``None`` if invalid/expired/missing."""
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        username = payload.get("sub")
        return username if username else None
    except jwt.PyJWTError:
        return None


def _require_user():
    """Validate the ``Authorization: Bearer <token>`` header of the current request.

    Returns:
        ``(username, None)`` on success, or ``(None, (response, 401))`` where the
        second element is ready to be returned from a Flask view.
    """
    auth_header = request.headers.get('Authorization')
    if not auth_header or not auth_header.startswith('Bearer '):
        return None, (jsonify({"error": "Missing or invalid token"}), 401)
    token = auth_header[len('Bearer '):].strip()
    username = verify_token(token)
    if not username:
        return None, (jsonify({"error": "Invalid token"}), 401)
    return username, None


def image_from_pdf(pdf_bytes):
    """Rasterise the first page of a PDF at 200 dpi; return ``None`` on failure."""
    try:
        # Only the first page is used, so don't render the rest of the document.
        images = convert_from_bytes(pdf_bytes, dpi=200, first_page=1, last_page=1)
        return images[0]
    except Exception as e:
        print(f"❌ PDF conversion error: {e}")
        return None


def _load_upload_image(file_content, file):
    """Decode uploaded bytes into a PIL image (first page for PDFs); ``None`` if undecodable."""
    content_type = file.content_type or ""
    filename = (file.filename or "").lower()
    if content_type.startswith('application/pdf') or filename.endswith('.pdf'):
        return image_from_pdf(file_content)
    try:
        image = Image.open(io.BytesIO(file_content))
        image.load()  # Image.open is lazy; force decoding so corrupt data fails here
        return image
    except (OSError, ValueError) as e:
        # PIL's UnidentifiedImageError is an OSError subclass.
        print(f"❌ Image decode error: {e}")
        return None


def extract_text(image):
    """Run Tesseract OCR on ``image``; returns a marker string instead of raising on failure."""
    try:
        if image.mode != 'RGB':
            image = image.convert('RGB')
        custom_config = r'--oem 3 --psm 6'
        text = pytesseract.image_to_string(image, config=custom_config)
        return text.strip() if text.strip() else "❓ No text detected"
    except Exception as e:
        return f"❌ OCR error: {str(e)}"


def get_clip_embedding(image):
    """Return the L2-normalised CLIP image embedding as a float32 vector, or ``None``."""
    try:
        if clip_model is None:
            return None
        if image.mode != 'RGB':
            image = image.convert('RGB')
        image_input = preprocess(image).unsqueeze(0).to(device)
        with torch.no_grad():
            image_features = clip_model.encode_image(image_input)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        # On CUDA, CLIP's weights (and so these features) are float16;
        # FAISS works on float32 vectors.
        return image_features.cpu().numpy()[0].astype(np.float32)
    except Exception as e:
        print(f"❌ CLIP embedding error: {e}")
        return None


def save_uploaded_file(file_content: bytes, filename: str) -> str:
    """Write the upload under a random UUID name (keeping its extension) and return that name."""
    file_id = str(uuid.uuid4())
    file_extension = os.path.splitext(filename or "")[1]
    saved_filename = f"{file_id}{file_extension}"
    file_path = os.path.join(UPLOADS_DIR, saved_filename)
    with open(file_path, 'wb') as f:
        f.write(file_content)
    return saved_filename


# Routes
@app.route("/")
def dashboard():
    """Serve the dashboard page."""
    # NOTE(review): this template appears to have lost its HTML markup (only
    # bare text remains); restore the original page from version control.
    return render_template_string(''' Document Classification System

Document Classification System

Login

''')


@app.route("/api/login", methods=["POST"])
def login():
    """Exchange form-encoded ``username``/``password`` for a Bearer JWT."""
    username = request.form.get("username")
    password = request.form.get("password")
    user = authenticate_user(username, password)
    if not user:
        return jsonify({"detail": "Incorrect username or password"}), 401
    access_token = create_access_token(data={"sub": user["username"]})
    return jsonify({"access_token": access_token, "token_type": "bearer", "username": user["username"]})


@app.route("/api/upload-category", methods=["POST"])
def upload_category():
    """Add a reference image/PDF for a category ``label`` to the FAISS index."""
    _, error = _require_user()
    if error is not None:
        return error
    try:
        label = (request.form.get("label") or "").strip()
        file = request.files.get("file")
        if not label or not file:
            return jsonify({"error": "Missing label or file"}), 400

        file_content = file.read()
        image = _load_upload_image(file_content, file)
        if image is None:
            return jsonify({"error": "Failed to process image"}), 400

        embedding = get_clip_embedding(image)
        if embedding is None:
            return jsonify({"error": "Failed to generate embedding"}), 400

        index.add(np.array([embedding], dtype=np.float32))
        labels.append(label)
        save_index()
        return jsonify({"message": f"✅ Added category '{label}' (Total: {len(labels)} categories)", "status": "success"})
    except Exception as e:
        return jsonify({"error": str(e)}), 500


@app.route("/api/classify-document", methods=["POST"])
def classify_document():
    """Classify an uploaded document against the stored categories.

    Returns the best match plus up to three ranked matches. When the best
    score reaches CONFIDENCE_THRESHOLD, the file is saved, OCR'd and recorded
    in the ``documents`` table.
    """
    _, error = _require_user()
    if error is not None:
        return error
    try:
        if len(labels) == 0:
            return jsonify({"error": "No categories in database. Please add some first."}), 400

        file = request.files.get("file")
        if not file:
            return jsonify({"error": "Missing file"}), 400

        file_content = file.read()
        image = _load_upload_image(file_content, file)
        if image is None:
            return jsonify({"error": "Failed to process image"}), 400

        embedding = get_clip_embedding(image)
        if embedding is None:
            return jsonify({"error": "Failed to generate embedding"}), 400

        k = min(3, len(labels))
        distances, ids = index.search(np.array([embedding], dtype=np.float32), k=k)

        # FAISS pads missing results with id -1; treat those (and any id with
        # no label) as "not found" rather than indexing labels[-1].
        best_id = int(ids[0][0])
        if not 0 <= best_id < len(labels):
            return jsonify({"error": "Document not recognized"}), 400

        # NOTE(review): IndexFlatL2 returns *squared* L2 distances. For the
        # unit-normalised CLIP vectors used here that is 2 - 2*cos, so cosine
        # similarity would be 1 - d/2. `1 - d` (= 2*cos - 1) is kept so the
        # tuned threshold and previously stored scores stay comparable.
        # float() is required: numpy.float32 is neither JSON-serialisable by
        # Flask nor bindable by sqlite3.
        similarity = round(float(1 - distances[0][0]), 3)
        best_match = labels[best_id]
        matches = [
            {"category": labels[idx], "similarity": round(float(1 - dist), 3)}
            for dist, idx in zip(distances[0], ids[0])
            if 0 <= idx < len(labels)
        ]

        if similarity < CONFIDENCE_THRESHOLD:
            return jsonify({
                "status": "low_confidence",
                "category": best_match,
                "similarity": similarity,
                "confidence": "low",
                "matches": matches,
                "document_saved": False,
            })

        # Save classified document to SQLite
        saved_filename = save_uploaded_file(file_content, file.filename)
        ocr_text = extract_text(image)
        document_id = str(uuid.uuid4())
        conn = sqlite3.connect(DATABASE_PATH)
        try:
            conn.execute(
                '''
                INSERT INTO documents (id, filename, original_filename, category, similarity,
                                       ocr_text, upload_date, file_path)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                ''',
                (document_id, saved_filename, file.filename, best_match, similarity,
                 ocr_text, datetime.now().isoformat(), os.path.join(UPLOADS_DIR, saved_filename)),
            )
            conn.commit()
        finally:
            conn.close()

        return jsonify({
            "status": "success",
            "category": best_match,
            "similarity": similarity,
            "confidence": "high",
            "matches": matches,
            "document_saved": True,
            "document_id": document_id,
        })
    except Exception as e:
        return jsonify({"error": str(e)}), 500


@app.route("/api/documents", methods=["GET"])
def get_all_documents():
    """List every stored document, newest ``upload_date`` first."""
    _, error = _require_user()
    if error is not None:
        return error
    conn = sqlite3.connect(DATABASE_PATH)
    try:
        rows = conn.execute(
            'SELECT id, filename, original_filename, category, similarity, '
            'ocr_text, upload_date, file_path FROM documents ORDER BY upload_date DESC'
        ).fetchall()
    finally:
        conn.close()
    documents = [
        {
            "id": row[0],
            "filename": row[1],
            "original_filename": row[2],
            "category": row[3],
            "similarity": row[4],
            "ocr_text": row[5],
            "upload_date": row[6],
            "file_path": row[7],
        }
        for row in rows
    ]
    return jsonify({"documents": documents, "count": len(documents)})


if __name__ == "__main__":
    # The Werkzeug debugger allows arbitrary code execution, so it must not be
    # exposed on 0.0.0.0 by default. Set FLASK_DEBUG=1 to opt in locally.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").lower() in ("1", "true")
    app.run(host="0.0.0.0", port=7860, debug=debug_enabled)