s-ahal commited on
Commit
ff13394
·
verified ·
1 Parent(s): 7acc52e

Upload 6 files

Browse files
Files changed (6) hide show
  1. Dockerfile.txt +20 -0
  2. requirements.txt +5 -0
  3. server.py +107 -0
  4. static/script.js +87 -0
  5. static/style.css +165 -0
  6. templates/index.html +66 -0
Dockerfile.txt ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10-slim
2
+
3
+ # Set working directory
4
+ WORKDIR /code
5
+
6
+ # Install OS-level dependencies
7
+ RUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/*
8
+
9
+ # Install Python dependencies
10
+ COPY requirements.txt .
11
+ RUN pip install --no-cache-dir -r requirements.txt
12
+
13
+ # Copy app files
14
+ COPY . .
15
+
16
+ # Expose the port
17
+ EXPOSE 7860
18
+
19
+ # Run the Flask app
20
+ CMD ["python", "server.py"]
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch
2
+ transformers
3
+ flask
4
+ flask-cors
5
+ pillow
server.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from flask import Flask, request, jsonify, render_template
4
+ from flask_cors import CORS
5
+ import io
6
+ import os
7
+ from PIL import Image
8
+ from transformers import CLIPProcessor, CLIPModel
9
+ import numpy as np
10
+
11
+ # Define the MIDM model
12
+ class MIDM(nn.Module):
13
+ def __init__(self, input_dim, hidden_dim, output_dim):
14
+ super(MIDM, self).__init__()
15
+ self.fc1 = nn.Linear(input_dim, hidden_dim)
16
+ self.relu = nn.ReLU()
17
+ self.fc2 = nn.Linear(hidden_dim, output_dim)
18
+ self.sigmoid = nn.Sigmoid()
19
+
20
+ def forward(self, x):
21
+ out = self.fc1(x)
22
+ out = self.relu(out)
23
+ out = self.fc2(out)
24
+ out = self.sigmoid(out)
25
+ return out
26
+
27
+ app = Flask(__name__, static_folder='static', template_folder='templates')
28
+ CORS(app)
29
+
30
+ # Load models once when the app starts to avoid reloading for each request
31
+ processor = None
32
+ clip_model = None
33
+ model = None
34
+
35
+ def load_models():
36
+ global processor, clip_model, model
37
+
38
+ # Load CLIP model and processor
39
+ processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
40
+ clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
41
+
42
+ # Initialize MIDM model
43
+ input_dim = 10 # Using first 10 features as in your notebook
44
+ hidden_dim = 64
45
+ output_dim = 1
46
+ model = MIDM(input_dim, hidden_dim, output_dim)
47
+
48
+ # For a real application, you would load your trained weights here
49
+ # model.load_state_dict(torch.load('path/to/your/model.pth'))
50
+ model.eval()
51
+
52
+ # Function to get image features using CLIP
53
+ def get_image_features(image):
54
+ """
55
+ Extracts image features using the CLIP model.
56
+ """
57
+ # Preprocess the image and get features
58
+ inputs = processor(images=image, return_tensors="pt")
59
+ # Only use the image encoder to get the image features
60
+ with torch.no_grad():
61
+ image_features = clip_model.get_image_features(**inputs)
62
+ return image_features
63
+
64
+ @app.route('/')
65
+ def index():
66
+ return render_template('index.html')
67
+
68
+ @app.route('/api/check-membership', methods=['POST'])
69
+ def check_membership():
70
+ # Ensure models are loaded
71
+ if processor is None or clip_model is None or model is None:
72
+ load_models()
73
+
74
+ if 'image' not in request.files:
75
+ return jsonify({'error': 'No image found in request'}), 400
76
+
77
+ try:
78
+ # Get the image from the request
79
+ file = request.files['image']
80
+ image_bytes = file.read()
81
+ image = Image.open(io.BytesIO(image_bytes))
82
+
83
+ # Get image features using CLIP
84
+ image_features = get_image_features(image)
85
+
86
+ # Preprocess the features for MIDM model
87
+ processed_features = image_features.reshape(1, -1)[:, :10] # Select first 10 features
88
+
89
+ # Perform inference
90
+ with torch.no_grad():
91
+ output = model(processed_features)
92
+ probability = output.item()
93
+ predicted = int(output > 0.5)
94
+
95
+ return jsonify({
96
+ 'probability': probability,
97
+ 'predicted_class': predicted,
98
+ 'message': f"Predicted membership probability: {probability}",
99
+ 'is_in_training_data': "Likely" if predicted == 1 else "Unlikely"
100
+ })
101
+
102
+ except Exception as e:
103
+ return jsonify({'error': str(e)}), 500
104
+
105
+ if __name__ == '__main__':
106
+ port = int(os.environ.get('PORT', 7860))
107
+ app.run(host='0.0.0.0', port=port)
static/script.js ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ document.addEventListener('DOMContentLoaded', function() {
2
+ const imageUpload = document.getElementById('image-upload');
3
+ const previewContainer = document.getElementById('preview-container');
4
+ const imagePreview = document.getElementById('image-preview');
5
+ const uploadPlaceholder = document.getElementById('upload-placeholder');
6
+ const submitButton = document.getElementById('submit-button');
7
+ const uploadForm = document.getElementById('upload-form');
8
+ const errorMessage = document.getElementById('error-message');
9
+ const resultContainer = document.getElementById('result-container');
10
+ const resultMessage = document.getElementById('result-message');
11
+ const membershipStatus = document.getElementById('membership-status');
12
+ const probabilityFill = document.getElementById('probability-fill');
13
+ const probabilityText = document.getElementById('probability-text');
14
+ const loading = document.getElementById('loading');
15
+
16
+ let selectedFile = null;
17
+
18
+ // Handle image selection
19
+ imageUpload.addEventListener('change', function(e) {
20
+ selectedFile = e.target.files[0];
21
+
22
+ if (selectedFile) {
23
+ const reader = new FileReader();
24
+
25
+ reader.onload = function(e) {
26
+ imagePreview.src = e.target.result;
27
+ previewContainer.classList.remove('hidden');
28
+ uploadPlaceholder.classList.add('hidden');
29
+ submitButton.disabled = false;
30
+ errorMessage.classList.add('hidden');
31
+ resultContainer.classList.add('hidden');
32
+ };
33
+
34
+ reader.readAsDataURL(selectedFile);
35
+ }
36
+ });
37
+
38
+ // Handle form submission
39
+ uploadForm.addEventListener('submit', function(e) {
40
+ e.preventDefault();
41
+
42
+ if (!selectedFile) {
43
+ errorMessage.textContent = 'Please select an image first';
44
+ errorMessage.classList.remove('hidden');
45
+ return;
46
+ }
47
+
48
+ // Show loading indicator
49
+ loading.classList.remove('hidden');
50
+ submitButton.disabled = true;
51
+ errorMessage.classList.add('hidden');
52
+ resultContainer.classList.add('hidden');
53
+
54
+ const formData = new FormData();
55
+ formData.append('image', selectedFile);
56
+
57
+ fetch('/api/check-membership', {
58
+ method: 'POST',
59
+ body: formData
60
+ })
61
+ .then(response => {
62
+ if (!response.ok) {
63
+ throw new Error(`Server responded with ${response.status}`);
64
+ }
65
+ return response.json();
66
+ })
67
+ .then(data => {
68
+ // Display results
69
+ resultMessage.textContent = data.message;
70
+ membershipStatus.innerHTML = `This image is <strong>${data.is_in_training_data}</strong> in the model's training data.`;
71
+
72
+ const probability = data.probability * 100;
73
+ probabilityFill.style.width = `${probability}%`;
74
+ probabilityText.textContent = `${probability.toFixed(2)}%`;
75
+
76
+ resultContainer.classList.remove('hidden');
77
+ })
78
+ .catch(error => {
79
+ errorMessage.textContent = `Error: ${error.message}`;
80
+ errorMessage.classList.remove('hidden');
81
+ })
82
+ .finally(() => {
83
+ loading.classList.add('hidden');
84
+ submitButton.disabled = false;
85
+ });
86
+ });
87
+ });
static/style.css ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ * {
2
+ box-sizing: border-box;
3
+ margin: 0;
4
+ padding: 0;
5
+ }
6
+
7
+ body {
8
+ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, 'Open Sans', 'Helvetica Neue', sans-serif;
9
+ line-height: 1.6;
10
+ color: #333;
11
+ background-color: #f5f5f5;
12
+ }
13
+
14
+ .app {
15
+ max-width: 800px;
16
+ margin: 0 auto;
17
+ padding: 20px;
18
+ }
19
+
20
+ .app-header {
21
+ text-align: center;
22
+ margin-bottom: 30px;
23
+ }
24
+
25
+ .app-header h1 {
26
+ margin-bottom: 10px;
27
+ color: #2c3e50;
28
+ }
29
+
30
+ .upload-container {
31
+ margin-bottom: 20px;
32
+ }
33
+
34
+ .upload-label {
35
+ display: block;
36
+ cursor: pointer;
37
+ }
38
+
39
+ .upload-placeholder {
40
+ border: 2px dashed #ccc;
41
+ border-radius: 8px;
42
+ padding: 60px;
43
+ text-align: center;
44
+ color: #777;
45
+ transition: all 0.3s ease;
46
+ }
47
+
48
+ .upload-placeholder:hover {
49
+ border-color: #4CAF50;
50
+ color: #4CAF50;
51
+ }
52
+
53
+ .image-preview {
54
+ max-width: 100%;
55
+ max-height: 300px;
56
+ border-radius: 8px;
57
+ display: block;
58
+ margin: 0 auto;
59
+ }
60
+
61
+ .file-input {
62
+ display: none;
63
+ }
64
+
65
+ .submit-button {
66
+ display: block;
67
+ width: 100%;
68
+ padding: 12px;
69
+ background-color: #4CAF50;
70
+ color: white;
71
+ border: none;
72
+ border-radius: 4px;
73
+ font-size: 16px;
74
+ cursor: pointer;
75
+ margin-bottom: 20px;
76
+ transition: background-color 0.3s;
77
+ }
78
+
79
+ .submit-button:hover {
80
+ background-color: #45a049;
81
+ }
82
+
83
+ .submit-button:disabled {
84
+ background-color: #cccccc;
85
+ cursor: not-allowed;
86
+ }
87
+
88
+ .error-message {
89
+ color: #e74c3c;
90
+ margin-bottom: 20px;
91
+ padding: 10px;
92
+ background-color: #fadbd8;
93
+ border-radius: 4px;
94
+ }
95
+
96
+ .result-container {
97
+ background-color: #f9f9f9;
98
+ border-radius: 8px;
99
+ padding: 20px;
100
+ margin-top: 20px;
101
+ box-shadow: 0 2px 4px rgba(0,0,0,0.1);
102
+ }
103
+
104
+ .result-container h2 {
105
+ margin-bottom: 15px;
106
+ color: #2c3e50;
107
+ }
108
+
109
+ .result-message {
110
+ margin-bottom: 15px;
111
+ }
112
+
113
+ .membership-status {
114
+ font-size: 18px;
115
+ margin-bottom: 15px;
116
+ }
117
+
118
+ .probability-bar {
119
+ height: 24px;
120
+ background-color: #e0e0e0;
121
+ border-radius: 12px;
122
+ position: relative;
123
+ overflow: hidden;
124
+ margin-top: 15px;
125
+ }
126
+
127
+ .probability-fill {
128
+ height: 100%;
129
+ background-color: #4CAF50;
130
+ transition: width 0.3s ease;
131
+ width: 0%;
132
+ }
133
+
134
+ .probability-text {
135
+ position: absolute;
136
+ top: 50%;
137
+ left: 50%;
138
+ transform: translate(-50%, -50%);
139
+ color: black;
140
+ font-weight: bold;
141
+ }
142
+
143
+ .loading {
144
+ text-align: center;
145
+ margin: 20px 0;
146
+ }
147
+
148
+ .spinner {
149
+ border: 4px solid #f3f3f3;
150
+ border-top: 4px solid #4CAF50;
151
+ border-radius: 50%;
152
+ width: 40px;
153
+ height: 40px;
154
+ animation: spin 1s linear infinite;
155
+ margin: 0 auto 10px;
156
+ }
157
+
158
+ @keyframes spin {
159
+ 0% { transform: rotate(0deg); }
160
+ 100% { transform: rotate(360deg); }
161
+ }
162
+
163
+ .hidden {
164
+ display: none;
165
+ }
templates/index.html ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>Image Membership Inference</title>
7
+ <link rel="stylesheet" href="{{ url_for('static', filename='style.css') }}">
8
+ </head>
9
+ <body>
10
+ <div class="app">
11
+ <header class="app-header">
12
+ <h1>Image Membership Inference</h1>
13
+ <p>Check if an image is likely within the model's training data</p>
14
+ </header>
15
+
16
+ <main>
17
+ <form id="upload-form">
18
+ <div class="upload-container">
19
+ <label for="image-upload" class="upload-label">
20
+ <div id="preview-container" class="hidden">
21
+ <img id="image-preview" src="" alt="Preview">
22
+ </div>
23
+ <div id="upload-placeholder" class="upload-placeholder">
24
+ <span>Click to upload an image</span>
25
+ </div>
26
+ </label>
27
+ <input
28
+ id="image-upload"
29
+ type="file"
30
+ accept="image/*"
31
+ class="file-input"
32
+ />
33
+ </div>
34
+
35
+ <button
36
+ type="submit"
37
+ id="submit-button"
38
+ class="submit-button"
39
+ disabled
40
+ >
41
+ Check Membership
42
+ </button>
43
+ </form>
44
+
45
+ <div id="error-message" class="error-message hidden"></div>
46
+
47
+ <div id="result-container" class="result-container hidden">
48
+ <h2>Result</h2>
49
+ <p id="result-message" class="result-message"></p>
50
+ <p id="membership-status" class="membership-status"></p>
51
+ <div class="probability-bar">
52
+ <div id="probability-fill" class="probability-fill"></div>
53
+ <span id="probability-text" class="probability-text"></span>
54
+ </div>
55
+ </div>
56
+
57
+ <div id="loading" class="loading hidden">
58
+ <div class="spinner"></div>
59
+ <p>Processing image...</p>
60
+ </div>
61
+ </main>
62
+ </div>
63
+
64
+ <script src="{{ url_for('static', filename='script.js') }}"></script>
65
+ </body>
66
+ </html>