import io
import logging
import os
import threading

import torch
import torch.nn as nn
from diffusers import StableDiffusionPipeline
from flask import Flask, jsonify, render_template, request
from flask_cors import CORS
from PIL import Image

logger = logging.getLogger(__name__)

# Optional Hugging Face access token; forwarded to from_pretrained so gated or
# private checkpoints can be downloaded when HF_TOKEN is set.
token = os.getenv("HF_TOKEN")


# Define the MIDM model
class MIDM(nn.Module):
    """Two-layer MLP mapping a feature vector to sigmoid-squashed outputs.

    Architecture: Linear(input_dim, hidden_dim) -> ReLU ->
    Linear(hidden_dim, output_dim) -> Sigmoid, so every output is in (0, 1).

    Args:
        input_dim: Size of the last dimension of the input tensor.
        hidden_dim: Width of the hidden layer.
        output_dim: Number of outputs.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``sigmoid(fc2(relu(fc1(x))))`` with shape ``(..., output_dim)``."""
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        out = self.sigmoid(out)
        return out


app = Flask(__name__, static_folder='static', template_folder='templates')
CORS(app)

# Models are loaded once (lazily, on first use) to avoid reloading for each
# request. load_models() assigns these only after each object is fully set up.
stable_diff_pipe = None
model = None
# Serializes lazy loading: Flask's server may handle requests in threads, and
# concurrent first requests would otherwise each load a full SD pipeline.
_model_lock = threading.Lock()


def load_models():
    """Load the Stable Diffusion pipeline and the MIDM classifier into globals.

    Both models are placed on CUDA when available (otherwise CPU) so tensors
    produced by the pipeline can be fed straight into MIDM.
    """
    global stable_diff_pipe, model
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load Stable Diffusion model pipeline
    pipe = StableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4", token=token
    )
    pipe = pipe.to(device)

    # Initialize MIDM model
    input_dim = 10  # Example dimension, adjust based on how you process the features
    hidden_dim = 64
    output_dim = 1
    midm = MIDM(input_dim, hidden_dim, output_dim)
    # For a real application, you would load your trained weights here
    # midm.load_state_dict(torch.load('path/to/your/model.pth', map_location=device))
    # Same device as the pipeline, otherwise inference fails on GPU hosts.
    midm.to(device)
    midm.eval()

    # Publish only fully initialized objects so other threads never observe a
    # half-configured model.
    stable_diff_pipe = pipe
    model = midm


def _ensure_models_loaded():
    """Load the models on first use; safe to call from concurrent requests."""
    if stable_diff_pipe is not None and model is not None:
        return
    with _model_lock:
        # Re-check: another thread may have finished loading while we waited.
        if stable_diff_pipe is None or model is None:
            load_models()


# Function to extract features from the image using Stable Diffusion
def extract_image_features(image):
    """Encode ``image`` with the pipeline's VAE and return the latent mean.

    The image is preprocessed with ``stable_diff_pipe.feature_extractor``, and
    the pixel tensor is passed through ``stable_diff_pipe.vae.encode``.
    Requires ``load_models()`` to have run.

    Args:
        image: A PIL image.

    Returns:
        ``latent_dist.mean`` of the VAE encoding, on the pipeline's device.
    """
    # NOTE(review): in SD pipelines ``feature_extractor`` is presumably the
    # safety checker's CLIP processor (CLIP resize/normalization), not the
    # [-1, 1] scaling the VAE was trained on -- confirm, and consider
    # ``stable_diff_pipe.image_processor.preprocess`` before training weights.
    image_input = stable_diff_pipe.feature_extractor(
        image, return_tensors="pt"
    ).pixel_values.to(stable_diff_pipe.device)

    # Generate the image embedding using the model
    with torch.no_grad():
        generated_features = stable_diff_pipe.vae.encode(image_input).latent_dist.mean
    return generated_features


def _decode_image(image_bytes):
    """Decode raw upload bytes into an RGB PIL image.

    Raises:
        OSError: If Pillow cannot identify or fully decode the data
            (``PIL.UnidentifiedImageError`` is a subclass of ``OSError``).
    """
    image = Image.open(io.BytesIO(image_bytes))
    # convert() forces a full decode (surfacing truncated files here) and
    # normalizes RGBA / L / P uploads to the 3-channel input the models expect.
    return image.convert("RGB")


@app.route('/')
def index():
    """Serve the front-end page."""
    return render_template('index.html')


@app.route('/api/check-membership', methods=['POST'])
def check_membership():
    """Score an uploaded image (multipart field ``image``) with MIDM.

    Returns JSON with ``probability``, ``predicted_class``, ``message`` and
    ``is_in_training_data``. Responds 400 when no image is sent or it cannot
    be decoded, and 500 with ``{"error": ...}`` on any other failure,
    including model-loading failures.
    """
    if 'image' not in request.files:
        return jsonify({'error': 'No image found in request'}), 400

    try:
        # Get the image from the request
        image_bytes = request.files['image'].read()
        try:
            image = _decode_image(image_bytes)
        except OSError:
            # Client error, not a server fault. Kept separate from the model
            # code below, which can also raise OSError (e.g. download failures).
            return jsonify({'error': 'Uploaded file is not a valid image'}), 400

        # Ensure models are loaded (inside try so failures still return JSON)
        _ensure_models_loaded()

        # Get image features using Stable Diffusion
        image_features = extract_image_features(image)

        # Preprocess the features for MIDM: flatten and keep as many leading
        # values as MIDM's first layer accepts (example preprocessing).
        n_features = model.fc1.in_features
        processed_features = image_features.reshape(1, -1)[:, :n_features]
        # Match MIDM's device and dtype so inference works on GPU and CPU alike.
        param = next(model.parameters())
        processed_features = processed_features.to(device=param.device, dtype=param.dtype)

        # Perform inference
        with torch.no_grad():
            output = model(processed_features)
        probability = output.item()
        predicted = int(probability > 0.5)

        return jsonify({
            'probability': probability,
            'predicted_class': predicted,
            'message': f"Predicted membership probability: {probability}",
            'is_in_training_data': "Likely" if predicted == 1 else "Unlikely"
        })
    except Exception as e:
        # Top-level request boundary: log the traceback, return a JSON error.
        logger.exception("Membership check failed")
        return jsonify({'error': str(e)}), 500


if __name__ == '__main__':
    port = int(os.environ.get('PORT', 7860))
    app.run(host='0.0.0.0', port=port)