|
import torch |
|
import torch.nn as nn |
|
from flask import Flask, request, jsonify, render_template |
|
from flask_cors import CORS |
|
import io |
|
import os |
|
from PIL import Image |
|
from transformers import CLIPProcessor, CLIPModel |
|
import numpy as np |
|
|
|
|
|
class MIDM(nn.Module):
    """Small two-layer MLP whose output is squashed by a sigmoid.

    Pipeline: ``Linear(input_dim -> hidden_dim) -> ReLU ->
    Linear(hidden_dim -> output_dim) -> Sigmoid``.

    Args:
        input_dim: Width of the incoming feature vector (last dimension of ``x``).
        hidden_dim: Width of the hidden layer.
        output_dim: Number of sigmoid outputs.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Attribute names are part of the checkpoint format: they become the
        # state_dict keys ("fc1.weight", "fc2.bias", ...), so keep them stable.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Map ``x`` (last dim == input_dim) to sigmoid scores (last dim == output_dim)."""
        hidden = self.relu(self.fc1(x))
        return self.sigmoid(self.fc2(hidden))
|
|
|
# Flask application: static assets are served from ./static and Jinja
# templates are looked up in ./templates.
app = Flask(__name__, template_folder='templates', static_folder='static')

# Enable CORS on the app using flask_cors' default configuration.
CORS(app)

# Model handles. They start empty and are filled in by load_models(), which
# the /api/check-membership handler calls when any of them is still None.
processor = clip_model = model = None
|
|
|
def load_models(weights_path=None):
    """Load the CLIP feature extractor and the MIDM classifier into module globals.

    Sets the module-level ``processor``, ``clip_model`` and ``model``.

    Args:
        weights_path: Optional path to a ``state_dict`` saved for the MIDM
            classifier (``torch.save(model.state_dict(), path)``). When omitted,
            the ``MIDM_WEIGHTS`` environment variable is consulted. If neither
            is set, the classifier keeps its random initialisation and a warning
            is logged, because its predictions are then meaningless.

    Raises:
        Whatever ``torch.load`` / ``load_state_dict`` raise when the weights
        file is missing, unreadable, or does not match the MIDM architecture.
        On failure ``model`` is left untouched, so the caller can retry.
    """
    global processor, clip_model, model

    clip_name = "openai/clip-vit-base-patch32"
    processor = CLIPProcessor.from_pretrained(clip_name)
    clip_model = CLIPModel.from_pretrained(clip_name)
    clip_model.eval()

    # The classifier consumes the first `input_dim` CLIP feature dimensions
    # and emits a single sigmoid score.
    input_dim = 10
    hidden_dim = 64
    output_dim = 1
    classifier = MIDM(input_dim, hidden_dim, output_dim)

    if weights_path is None:
        weights_path = os.environ.get("MIDM_WEIGHTS")
    if weights_path:
        # torch.load unpickles the file: only point this at trusted checkpoints.
        state_dict = torch.load(weights_path, map_location="cpu")
        classifier.load_state_dict(state_dict)
    else:
        app.logger.warning(
            "No MIDM weights provided (pass weights_path or set MIDM_WEIGHTS); "
            "the classifier is randomly initialised and its predictions are "
            "not meaningful."
        )

    classifier.eval()
    # Publish only a fully-initialised classifier, so a failed weight load
    # leaves `model` as None and the next request retries the load.
    model = classifier
|
|
|
|
|
def get_image_features(image):
    """Embed an image with the globally loaded CLIP model.

    Args:
        image: Passed straight to ``processor(images=...)``; the request
            handler supplies a PIL image.

    Returns:
        The tensor produced by ``clip_model.get_image_features`` for the
        preprocessed pixels, computed with gradient tracking disabled.

    Note:
        Relies on ``processor`` and ``clip_model`` having been populated by
        ``load_models()``; this function does not load them itself.
    """
    encoded = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        return clip_model.get_image_features(**encoded)
|
|
|
@app.route('/')
def index():
    """Render the ``index.html`` template for the site root."""
    page = 'index.html'
    return render_template(page)
|
|
|
def _decode_upload(file_storage):
    """Read an uploaded file object and return a fully decoded PIL image."""
    image = Image.open(io.BytesIO(file_storage.read()))
    # Image.open only parses the header. Force the pixel decode here so that
    # truncated or corrupt uploads fail now, as a client error, instead of
    # deep inside CLIP preprocessing.
    image.load()
    return image


def _predict_probability(image):
    """Return the MIDM classifier's probability for ``image`` as a Python float."""
    features = get_image_features(image).reshape(1, -1)
    # Feed the classifier exactly as many CLIP dimensions as its first layer
    # accepts; reading the width from the model keeps this in sync with
    # load_models() instead of duplicating a magic number.
    features = features[:, :model.fc1.in_features]
    with torch.no_grad():
        output = model(features)
    return output.item()


@app.route('/api/check-membership', methods=['POST'])
def check_membership():
    """Score an uploaded image with the MIDM classifier.

    Expects the image as the ``image`` field of the request's uploaded files.
    Models are loaded on first use.

    Returns:
        200 JSON with ``probability`` (float), ``predicted_class``
        (1 if probability > 0.5 else 0), ``message``, and
        ``is_in_training_data`` ("Likely" / "Unlikely").
        400 JSON ``{"error": ...}`` if the field is missing or the upload
        cannot be decoded as an image.
        500 JSON ``{"error": ...}`` if model loading or inference fails.
    """
    if 'image' not in request.files:
        return jsonify({'error': 'No image found in request'}), 400

    try:
        image = _decode_upload(request.files['image'])
    except (OSError, Image.DecompressionBombError) as e:
        # OSError covers PIL's UnidentifiedImageError and truncated files;
        # DecompressionBombError guards against oversized images.
        return jsonify({'error': f'Invalid image: {e}'}), 400

    try:
        # Loading happens inside the try so failures still yield a JSON error.
        if processor is None or clip_model is None or model is None:
            load_models()
        probability = _predict_probability(image)
    except Exception as e:
        app.logger.exception('Membership check failed')
        return jsonify({'error': str(e)}), 500

    predicted = int(probability > 0.5)
    return jsonify({
        'probability': probability,
        'predicted_class': predicted,
        'message': f"Predicted membership probability: {probability}",
        'is_in_training_data': "Likely" if predicted == 1 else "Unlikely"
    })
|
|
|
if __name__ == '__main__':
    # Listen on all interfaces; the port comes from $PORT, falling back to 7860.
    app.run(
        host='0.0.0.0',
        port=int(os.environ.get('PORT', 7860)),
    )