# infer/server.py
import torch
import torch.nn as nn
from flask import Flask, request, jsonify, render_template
from flask_cors import CORS
import io
import os
from PIL import Image
from transformers import CLIPProcessor, CLIPModel

# Define the MIDM model: a small two-layer MLP that maps a feature
# vector to a membership probability in [0, 1].
class MIDM(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(MIDM, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        out = self.sigmoid(out)
        return out
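
# Quick interface check (illustrative only, not executed at import time):
# a batch of 10-dim feature vectors maps to per-sample probabilities in [0, 1].
#   m = MIDM(input_dim=10, hidden_dim=64, output_dim=1)
#   m(torch.randn(4, 10)).shape  # -> torch.Size([4, 1])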
app = Flask(__name__, static_folder='static', template_folder='templates')
CORS(app)

# Model handles, populated lazily by load_models() on the first request so
# the heavy CLIP weights are loaded once per process rather than per request.
processor = None
clip_model = None
model = None

def load_models():
    global processor, clip_model, model
    # Load the CLIP model and processor
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    # Initialize the MIDM model
    input_dim = 10  # the endpoint feeds MIDM only the first 10 CLIP features
    hidden_dim = 64
    output_dim = 1
    model = MIDM(input_dim, hidden_dim, output_dim)
    # For a real application, load the trained weights here:
    # model.load_state_dict(torch.load('path/to/your/model.pth'))
    model.eval()
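
# A minimal save/load round trip for trained MIDM weights (a sketch; the
# checkpoint path 'midm.pth' is a placeholder, not a file in this repo):
#   torch.save(model.state_dict(), 'midm.pth')  # after training
#   model.load_state_dict(torch.load('midm.pth', map_location='cpu'))
#   model.eval()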

# Function to get image features using CLIP
def get_image_features(image):
    """
    Extracts image features using the CLIP model.
    """
    # Preprocess the image into tensors for the CLIP image encoder
    inputs = processor(images=image, return_tensors="pt")
    # Only use the image encoder to get the image features; no gradients needed
    with torch.no_grad():
        image_features = clip_model.get_image_features(**inputs)
    return image_features
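
# Note: for openai/clip-vit-base-patch32, get_image_features returns a
# 512-dim embedding per image; the endpoint below keeps only the first
# 10 dimensions as the MIDM input.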

@app.route('/')
def index():
    return render_template('index.html')

@app.route('/api/check-membership', methods=['POST'])
def check_membership():
    # Ensure models are loaded (lazy initialization on the first request)
    if processor is None or clip_model is None or model is None:
        load_models()
    if 'image' not in request.files:
        return jsonify({'error': 'No image found in request'}), 400
    try:
        # Get the image from the request; convert to RGB so CLIP always
        # receives 3-channel input (handles grayscale and RGBA uploads)
        file = request.files['image']
        image_bytes = file.read()
        image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
        # Get image features using CLIP
        image_features = get_image_features(image)
        # Preprocess the features for the MIDM model: keep the first 10
        # dimensions to match the model's input_dim
        processed_features = image_features.reshape(1, -1)[:, :10]
        # Perform inference
        with torch.no_grad():
            output = model(processed_features)
        probability = output.item()
        predicted = int(probability > 0.5)
        return jsonify({
            'probability': probability,
            'predicted_class': predicted,
            'message': f"Predicted membership probability: {probability:.4f}",
            'is_in_training_data': "Likely" if predicted == 1 else "Unlikely"
        })
    except Exception as e:
        return jsonify({'error': str(e)}), 500

if __name__ == '__main__':
    # Hugging Face Spaces route traffic to port 7860 by default
    port = int(os.environ.get('PORT', 7860))
    app.run(host='0.0.0.0', port=port)
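
# Example request against a running server (illustrative; substitute any
# local image path for photo.jpg):
#   curl -X POST -F "image=@photo.jpg" http://localhost:7860/api/check-membership
# which returns JSON of the form (values illustrative):
#   {"probability": 0.53, "predicted_class": 1,
#    "message": "Predicted membership probability: 0.5300",
#    "is_in_training_data": "Likely"}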