s-ahal commited on
Commit
9b256ac
·
verified ·
1 Parent(s): c176b63

Update server.py

Browse files
Files changed (1) hide show
  1. server.py +10 -4
server.py CHANGED
@@ -33,11 +33,11 @@ CORS(app)
33
  stable_diff_pipe = None
34
  model = None
35
 
36
- def load_models():
37
  global stable_diff_pipe, model
38
 
39
  # Load Stable Diffusion model pipeline
40
- stable_diff_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
41
  stable_diff_pipe.to("cuda" if torch.cuda.is_available() else "cpu")
42
 
43
  # Initialize MIDM model
@@ -66,9 +66,15 @@ def index():
66
 
67
  @app.route('/api/check-membership', methods=['POST'])
68
  def check_membership():
69
- # Ensure models are loaded
 
 
 
70
  if stable_diff_pipe is None or model is None:
71
- load_models()
 
 
 
72
 
73
  if 'image' not in request.files:
74
  return jsonify({'error': 'No image found in request'}), 400
 
33
  stable_diff_pipe = None
34
  model = None
35
 
36
+ def load_models(model_name="CompVis/stable-diffusion-v1-4"):
37
  global stable_diff_pipe, model
38
 
39
  # Load Stable Diffusion model pipeline
40
+ stable_diff_pipe = StableDiffusionPipeline.from_pretrained(model_name)
41
  stable_diff_pipe.to("cuda" if torch.cuda.is_available() else "cpu")
42
 
43
  # Initialize MIDM model
 
66
 
67
  @app.route('/api/check-membership', methods=['POST'])
68
  def check_membership():
69
+ # Get the model name from the request
70
+ model_name = request.form.get('model', 'CompVis/stable-diffusion-v1-4')
71
+
72
+ # Ensure models are loaded with the selected model
73
  if stable_diff_pipe is None or model is None:
74
+ load_models(model_name)
75
+ elif stable_diff_pipe.name_or_path != model_name:
76
+ # Reload the model if a different one is selected
77
+ load_models(model_name)
78
 
79
  if 'image' not in request.files:
80
  return jsonify({'error': 'No image found in request'}), 400