yahiab commited on
Commit
fb8456d
·
1 Parent(s): f311e6e
Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -15,22 +15,23 @@ MODEL_LIST = {
15
  # Global variables
16
  current_model = None
17
  current_preprocessor = None
 
18
 
19
  # Load model and preprocessor
20
  def load_model_and_preprocessor(model_name):
21
  """Load model and preprocessor for a given model name."""
22
  global current_model, current_preprocessor
23
- print(f"Loading model and preprocessor for: {model_name}")
24
- current_model = AutoModelForImageClassification.from_pretrained(MODEL_LIST[model_name]).cuda().eval()
25
  current_preprocessor = AutoFeatureExtractor.from_pretrained(MODEL_LIST[model_name])
26
- return f"Model {model_name} loaded successfully."
27
 
28
  # Predict function
29
  def predict(image, model, preprocessor):
30
  """Make a prediction on the given image patch using the loaded model."""
31
  if model is None or preprocessor is None:
32
  raise ValueError("Model and preprocessor are not loaded.")
33
- inputs = preprocessor(images=image, return_tensors="pt").to("cuda")
34
  with torch.no_grad():
35
  outputs = model(**inputs)
36
  predicted_class = torch.argmax(outputs.logits, dim=1).item()
 
15
  # Global variables
16
  current_model = None
17
  current_preprocessor = None
18
+ device = "cuda" if torch.cuda.is_available() else "cpu" # Dynamically set device
19
 
20
  # Load model and preprocessor
21
  def load_model_and_preprocessor(model_name):
22
  """Load model and preprocessor for a given model name."""
23
  global current_model, current_preprocessor
24
+ print(f"Loading model and preprocessor for: {model_name} on {device}")
25
+ current_model = AutoModelForImageClassification.from_pretrained(MODEL_LIST[model_name]).to(device).eval()
26
  current_preprocessor = AutoFeatureExtractor.from_pretrained(MODEL_LIST[model_name])
27
+ return f"Model {model_name} loaded successfully on {device}."
28
 
29
  # Predict function
30
  def predict(image, model, preprocessor):
31
  """Make a prediction on the given image patch using the loaded model."""
32
  if model is None or preprocessor is None:
33
  raise ValueError("Model and preprocessor are not loaded.")
34
+ inputs = preprocessor(images=image, return_tensors="pt").to(device)
35
  with torch.no_grad():
36
  outputs = model(**inputs)
37
  predicted_class = torch.argmax(outputs.logits, dim=1).item()