s-ahal commited on
Commit
7e1e741
·
verified ·
1 Parent(s): 8bb6a0e

Update server.py

Browse files
Files changed (1) hide show
  1. server.py +22 -22
server.py CHANGED
@@ -5,8 +5,7 @@ from flask_cors import CORS
5
  import io
6
  import os
7
  from PIL import Image
8
- from transformers import CLIPProcessor, CLIPModel
9
- import numpy as np
10
 
11
  # Define the MIDM model
12
  class MIDM(nn.Module):
@@ -28,19 +27,18 @@ app = Flask(__name__, static_folder='static', template_folder='templates')
28
  CORS(app)
29
 
30
  # Load models once when the app starts to avoid reloading for each request
31
- processor = None
32
- clip_model = None
33
  model = None
34
 
35
  def load_models():
36
- global processor, clip_model, model
37
 
38
- # Load CLIP model and processor
39
- processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch14")
40
- clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch14")
41
 
42
  # Initialize MIDM model
43
- input_dim = 10 # Using first 10 features as in your notebook
44
  hidden_dim = 64
45
  output_dim = 1
46
  model = MIDM(input_dim, hidden_dim, output_dim)
@@ -49,17 +47,19 @@ def load_models():
49
  # model.load_state_dict(torch.load('path/to/your/model.pth'))
50
  model.eval()
51
 
52
- # Function to get image features using CLIP
53
- def get_image_features(image):
54
  """
55
- Extracts image features using the CLIP model.
56
  """
57
- # Preprocess the image and get features
58
- inputs = processor(images=image, return_tensors="pt")
59
- # Only use the image encoder to get the image features
 
60
  with torch.no_grad():
61
- image_features = clip_model.get_image_features(**inputs)
62
- return image_features
 
63
 
64
  @app.route('/')
65
  def index():
@@ -68,7 +68,7 @@ def index():
68
  @app.route('/api/check-membership', methods=['POST'])
69
  def check_membership():
70
  # Ensure models are loaded
71
- if processor is None or clip_model is None or model is None:
72
  load_models()
73
 
74
  if 'image' not in request.files:
@@ -79,12 +79,12 @@ def check_membership():
79
  file = request.files['image']
80
  image_bytes = file.read()
81
  image = Image.open(io.BytesIO(image_bytes))
82
-
83
- # Get image features using CLIP
84
- image_features = get_image_features(image)
85
 
86
  # Preprocess the features for MIDM model
87
- processed_features = image_features.reshape(1, -1)[:, :10] # Select first 10 features
88
 
89
  # Perform inference
90
  with torch.no_grad():
 
5
  import io
6
  import os
7
  from PIL import Image
8
+ from diffusers import StableDiffusionPipeline
 
9
 
10
  # Define the MIDM model
11
  class MIDM(nn.Module):
 
27
  CORS(app)
28
 
29
  # Load models once when the app starts to avoid reloading for each request
30
+ stable_diff_pipe = None
 
31
  model = None
32
 
33
  def load_models():
34
+ global stable_diff_pipe, model
35
 
36
+ # Load Stable Diffusion model pipeline
37
+ stable_diff_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4-original")
38
+ stable_diff_pipe.to("cuda" if torch.cuda.is_available() else "cpu")
39
 
40
  # Initialize MIDM model
41
+ input_dim = 10 # Example dimension, adjust based on how you process the features
42
  hidden_dim = 64
43
  output_dim = 1
44
  model = MIDM(input_dim, hidden_dim, output_dim)
 
47
  # model.load_state_dict(torch.load('path/to/your/model.pth'))
48
  model.eval()
49
 
50
+ # Function to extract features from the image using Stable Diffusion
51
+ def extract_image_features(image):
52
  """
53
+ Extracts image features using the Stable Diffusion pipeline.
54
  """
55
+ # Preprocess the image and get the feature vector
56
+ image_input = stable_diff_pipe.feature_extractor(image, return_tensors="pt").pixel_values.to(stable_diff_pipe.device)
57
+
58
+ # Generate the image embedding using the model
59
  with torch.no_grad():
60
+ generated_features = stable_diff_pipe.vae.encode(image_input).latent_dist.mean
61
+
62
+ return generated_features
63
 
64
  @app.route('/')
65
  def index():
 
68
  @app.route('/api/check-membership', methods=['POST'])
69
  def check_membership():
70
  # Ensure models are loaded
71
+ if stable_diff_pipe is None or model is None:
72
  load_models()
73
 
74
  if 'image' not in request.files:
 
79
  file = request.files['image']
80
  image_bytes = file.read()
81
  image = Image.open(io.BytesIO(image_bytes))
82
+
83
+ # Get image features using Stable Diffusion
84
+ image_features = extract_image_features(image)
85
 
86
  # Preprocess the features for MIDM model
87
+ processed_features = image_features.reshape(1, -1)[:, :10] # Select first 10 features (example)
88
 
89
  # Perform inference
90
  with torch.no_grad():