s-ahal commited on
Commit
c176b63
·
verified ·
1 Parent(s): af95279

Update server.py

Browse files
Files changed (1) hide show
  1. server.py +3 -7
server.py CHANGED
@@ -41,20 +41,16 @@ def load_models():
41
  stable_diff_pipe.to("cuda" if torch.cuda.is_available() else "cpu")
42
 
43
  # Initialize MIDM model
44
- input_dim = 10 # Example dimension, adjust based on how you process the features
45
  hidden_dim = 64
46
  output_dim = 1
47
  model = MIDM(input_dim, hidden_dim, output_dim)
48
 
49
- # For a real application, you would load your trained weights here
50
- # model.load_state_dict(torch.load('path/to/your/model.pth'))
51
  model.eval()
52
 
53
  # Function to extract features from the image using Stable Diffusion
54
  def extract_image_features(image):
55
- """
56
- Extracts image features using the Stable Diffusion pipeline.
57
- """
58
  # Preprocess the image and get the feature vector
59
  image_input = stable_diff_pipe.feature_extractor(image, return_tensors="pt").pixel_values.to(stable_diff_pipe.device)
60
 
@@ -87,7 +83,7 @@ def check_membership():
87
  image_features = extract_image_features(image)
88
 
89
  # Preprocess the features for MIDM model
90
- processed_features = image_features.reshape(1, -1)[:, :10] # Select first 10 features (example)
91
 
92
  # Perform inference
93
  with torch.no_grad():
 
41
  stable_diff_pipe.to("cuda" if torch.cuda.is_available() else "cpu")
42
 
43
  # Initialize MIDM model
44
+ input_dim = 10
45
  hidden_dim = 64
46
  output_dim = 1
47
  model = MIDM(input_dim, hidden_dim, output_dim)
48
 
 
 
49
  model.eval()
50
 
51
  # Function to extract features from the image using Stable Diffusion
52
  def extract_image_features(image):
53
+ #Extracts image features using the Stable Diffusion pipeline.
 
 
54
  # Preprocess the image and get the feature vector
55
  image_input = stable_diff_pipe.feature_extractor(image, return_tensors="pt").pixel_values.to(stable_diff_pipe.device)
56
 
 
83
  image_features = extract_image_features(image)
84
 
85
  # Preprocess the features for MIDM model
86
+ processed_features = image_features.reshape(1, -1)[:, :10] # Select first 10 features
87
 
88
  # Perform inference
89
  with torch.no_grad():