"""Flask service that scores an uploaded image with a MIDM head on CLIP features.

The app serves an HTML page at ``/`` and a JSON endpoint at
``/api/check-membership`` that accepts a multipart upload under the ``image``
field and returns a membership probability and a thresholded class.
"""

import io
import os

import numpy as np
import torch
import torch.nn as nn
from flask import Flask, request, jsonify, render_template
from flask_cors import CORS
from PIL import Image
from transformers import CLIPProcessor, CLIPModel

# Hugging Face checkpoint used for both the processor and the image encoder.
CLIP_MODEL_NAME = "openai/clip-vit-base-patch32"

# MIDM layer sizes. MIDM_INPUT_DIM is also the number of leading CLIP features
# kept at inference time, so the slice and the first linear layer cannot drift.
MIDM_INPUT_DIM = 10
MIDM_HIDDEN_DIM = 64
MIDM_OUTPUT_DIM = 1

# Probabilities strictly above this value are reported as class 1.
DECISION_THRESHOLD = 0.5


class MIDM(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear -> Sigmoid."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return sigmoid scores for ``x``, whose last dimension is ``input_dim``."""
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        out = self.sigmoid(out)
        return out


app = Flask(__name__, static_folder='static', template_folder='templates')
CORS(app)

# Populated lazily by load_models() on the first API request and then reused,
# so the models are not reloaded for every request.
processor = None
clip_model = None
model = None


def load_models():
    """Load the CLIP processor/model and build the MIDM head into module globals."""
    global processor, clip_model, model

    loaded_processor = CLIPProcessor.from_pretrained(CLIP_MODEL_NAME)
    loaded_clip_model = CLIPModel.from_pretrained(CLIP_MODEL_NAME)

    midm = MIDM(MIDM_INPUT_DIM, MIDM_HIDDEN_DIM, MIDM_OUTPUT_DIM)
    # For a real application, you would load your trained weights here:
    # midm.load_state_dict(torch.load('path/to/your/model.pth'))
    # NOTE(review): without that step the head keeps its random initial
    # weights, so the returned probabilities carry no learned signal.
    midm.eval()

    # Publish only after everything has been built, so a failure part-way
    # through never leaves the globals half-initialised.
    processor, clip_model, model = loaded_processor, loaded_clip_model, midm


def get_image_features(image):
    """Extract image features using the CLIP model.

    Args:
        image: Image (or images) accepted by the CLIP processor.

    Returns:
        Whatever ``clip_model.get_image_features`` returns for the
        preprocessed input, computed with gradients disabled.
    """
    inputs = processor(images=image, return_tensors="pt")
    # Only the image encoder is needed; no text inputs are supplied.
    with torch.no_grad():
        image_features = clip_model.get_image_features(**inputs)
    return image_features


@app.route('/')
def index():
    """Serve the front-end page."""
    return render_template('index.html')


@app.route('/api/check-membership', methods=['POST'])
def check_membership():
    """Score the uploaded ``image`` file and return the result as JSON.

    Returns:
        200 with ``probability``, ``predicted_class``, ``message`` and
        ``is_in_training_data`` on success; 400 with ``error`` when the upload
        is missing or cannot be decoded as an image; 500 with ``error`` for
        any other failure.
    """
    # Validate the request before touching the models, so a malformed request
    # never triggers the expensive first-time model load.
    if 'image' not in request.files:
        return jsonify({'error': 'No image found in request'}), 400

    try:
        image_bytes = request.files['image'].read()

        # Image.open is lazy; convert() forces a full decode so corrupt or
        # empty uploads fail here. The handler is deliberately scoped to the
        # decode only, because model loading can also raise OSError and that
        # must stay a server error.
        try:
            image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
        except OSError:
            return jsonify({'error': 'Uploaded file is not a valid image'}), 400

        # Ensure models are loaded (first request only).
        if processor is None or clip_model is None or model is None:
            load_models()

        image_features = get_image_features(image)

        # Flatten to one row and keep the leading features the MIDM head expects.
        processed_features = image_features.reshape(1, -1)[:, :MIDM_INPUT_DIM]

        with torch.no_grad():
            output = model(processed_features)

        probability = output.item()
        predicted = int(probability > DECISION_THRESHOLD)

        return jsonify({
            'probability': probability,
            'predicted_class': predicted,
            'message': f"Predicted membership probability: {probability}",
            'is_in_training_data': "Likely" if predicted == 1 else "Unlikely"
        })
    except Exception as e:
        # Top-level boundary: keep the JSON error contract but record the
        # traceback so failures are diagnosable from the server logs.
        app.logger.exception("Membership check failed")
        return jsonify({'error': str(e)}), 500


if __name__ == '__main__':
    port = int(os.environ.get('PORT', 7860))
    app.run(host='0.0.0.0', port=port)