yahiab committed
Commit · fb8456d · 1 Parent(s): f311e6e
fix
app.py CHANGED
@@ -15,22 +15,23 @@ MODEL_LIST = {
 # Global variables
 current_model = None
 current_preprocessor = None
+device = "cuda" if torch.cuda.is_available() else "cpu"  # Dynamically set device
 
 # Load model and preprocessor
 def load_model_and_preprocessor(model_name):
     """Load model and preprocessor for a given model name."""
     global current_model, current_preprocessor
-    print(f"Loading model and preprocessor for: {model_name}")
-    current_model = AutoModelForImageClassification.from_pretrained(MODEL_LIST[model_name]).
+    print(f"Loading model and preprocessor for: {model_name} on {device}")
+    current_model = AutoModelForImageClassification.from_pretrained(MODEL_LIST[model_name]).to(device).eval()
     current_preprocessor = AutoFeatureExtractor.from_pretrained(MODEL_LIST[model_name])
-    return f"Model {model_name} loaded successfully."
+    return f"Model {model_name} loaded successfully on {device}."
 
 # Predict function
 def predict(image, model, preprocessor):
     """Make a prediction on the given image patch using the loaded model."""
     if model is None or preprocessor is None:
         raise ValueError("Model and preprocessor are not loaded.")
-    inputs = preprocessor(images=image, return_tensors="pt").to(
+    inputs = preprocessor(images=image, return_tensors="pt").to(device)
     with torch.no_grad():
         outputs = model(**inputs)
     predicted_class = torch.argmax(outputs.logits, dim=1).item()
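For context, below is a minimal, self-contained sketch of the patched code path. The MODEL_LIST entry ("ViT" / "google/vit-base-patch16-224") and the image path are illustrative placeholders rather than values taken from this Space, and the rest of app.py (the full MODEL_LIST and the UI wiring) is omitted.

# Minimal sketch of the patched load/predict path; not the full app.py.
import torch
from PIL import Image
from transformers import AutoFeatureExtractor, AutoModelForImageClassification

MODEL_LIST = {"ViT": "google/vit-base-patch16-224"}  # placeholder checkpoint

# Global variables
current_model = None
current_preprocessor = None
device = "cuda" if torch.cuda.is_available() else "cpu"  # Dynamically set device

def load_model_and_preprocessor(model_name):
    """Load model and preprocessor for a given model name."""
    global current_model, current_preprocessor
    print(f"Loading model and preprocessor for: {model_name} on {device}")
    current_model = AutoModelForImageClassification.from_pretrained(MODEL_LIST[model_name]).to(device).eval()
    current_preprocessor = AutoFeatureExtractor.from_pretrained(MODEL_LIST[model_name])
    return f"Model {model_name} loaded successfully on {device}."

def predict(image, model, preprocessor):
    """Make a prediction on the given image patch using the loaded model."""
    if model is None or preprocessor is None:
        raise ValueError("Model and preprocessor are not loaded.")
    inputs = preprocessor(images=image, return_tensors="pt").to(device)  # move tensors to the same device as the model
    with torch.no_grad():
        outputs = model(**inputs)
    return torch.argmax(outputs.logits, dim=1).item()

print(load_model_and_preprocessor("ViT"))
image = Image.open("patch.png").convert("RGB")  # placeholder image path
print(predict(image, current_model, current_preprocessor))

The key point of the commit is visible here: the model, and the inputs produced by the preprocessor, are both moved to the same dynamically selected device, so the Space no longer assumes CUDA is available.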