yahiab commited on
Commit
f311e6e
·
1 Parent(s): 0d11696
Files changed (1) hide show
  1. app.py +23 -18
app.py CHANGED
@@ -12,18 +12,24 @@ MODEL_LIST = {
12
  'convnext': "facebook/convnext-tiny-224",
13
  }
14
 
15
- # Preprocessing transforms
16
- def get_preprocessor(model_name):
17
- extractor = AutoFeatureExtractor.from_pretrained(MODEL_LIST[model_name])
18
- return extractor
19
 
20
- # Load a model from Hugging Face
21
- def load_model(model_name):
22
- model = AutoModelForImageClassification.from_pretrained(MODEL_LIST[model_name]).cuda().eval()
23
- return model
 
 
 
 
24
 
25
- # Function to make predictions
26
  def predict(image, model, preprocessor):
 
 
 
27
  inputs = preprocessor(images=image, return_tensors="pt").to("cuda")
28
  with torch.no_grad():
29
  outputs = model(**inputs)
@@ -32,6 +38,7 @@ def predict(image, model, preprocessor):
32
 
33
  # Function to draw a rectangle on the image
34
  def draw_rectangle(image, x, y, size=224):
 
35
  image_pil = image.copy() # Create a copy to avoid modifying the original image
36
  draw = ImageDraw.Draw(image_pil)
37
  x1, y1 = x, y
@@ -41,6 +48,7 @@ def draw_rectangle(image, x, y, size=224):
41
 
42
  # Function to crop the image
43
  def crop_image(image, x, y, size=224):
 
44
  image_np = np.array(image)
45
  h, w, _ = image_np.shape
46
  x = min(max(x, 0), w - size)
@@ -48,10 +56,6 @@ def crop_image(image, x, y, size=224):
48
  cropped = image_np[y:y+size, x:x+size]
49
  return Image.fromarray(cropped)
50
 
51
- # Global variables
52
- current_model = None
53
- current_preprocessor = None
54
-
55
  # Gradio Interface
56
  with gr.Blocks() as demo:
57
  gr.Markdown("## Test Public Models for Coral Classification")
@@ -67,12 +71,9 @@ with gr.Blocks() as demo:
67
  cropped_image = gr.Image(label="Cropped Patch")
68
  label_output = gr.Textbox(label="Predicted Label")
69
 
70
- # Update the current model and preprocessor
71
  def update_model(model_name):
72
- global current_model, current_preprocessor
73
- current_model = load_model(model_name)
74
- current_preprocessor = get_preprocessor(model_name)
75
- return f"Model {model_name} loaded successfully."
76
 
77
  # Update the rectangle and crop the patch
78
  def update_selection(image, x, y):
@@ -82,6 +83,7 @@ with gr.Blocks() as demo:
82
 
83
  # Predict the label from the cropped patch
84
  def predict_from_cropped(cropped):
 
85
  return predict(cropped, current_model, current_preprocessor)
86
 
87
  # Buttons and interactions
@@ -102,4 +104,7 @@ with gr.Blocks() as demo:
102
 
103
  image_input.change(fn=update_sliders, inputs=image_input, outputs=[x_slider, y_slider])
104
 
 
 
 
105
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
12
  'convnext': "facebook/convnext-tiny-224",
13
  }
14
 
15
+ # Global variables
16
+ current_model = None
17
+ current_preprocessor = None
 
18
 
19
+ # Load model and preprocessor
20
+ def load_model_and_preprocessor(model_name):
21
+ """Load model and preprocessor for a given model name."""
22
+ global current_model, current_preprocessor
23
+ print(f"Loading model and preprocessor for: {model_name}")
24
+ current_model = AutoModelForImageClassification.from_pretrained(MODEL_LIST[model_name]).cuda().eval()
25
+ current_preprocessor = AutoFeatureExtractor.from_pretrained(MODEL_LIST[model_name])
26
+ return f"Model {model_name} loaded successfully."
27
 
28
+ # Predict function
29
  def predict(image, model, preprocessor):
30
+ """Make a prediction on the given image patch using the loaded model."""
31
+ if model is None or preprocessor is None:
32
+ raise ValueError("Model and preprocessor are not loaded.")
33
  inputs = preprocessor(images=image, return_tensors="pt").to("cuda")
34
  with torch.no_grad():
35
  outputs = model(**inputs)
 
38
 
39
  # Function to draw a rectangle on the image
40
  def draw_rectangle(image, x, y, size=224):
41
+ """Draw a rectangle on the image."""
42
  image_pil = image.copy() # Create a copy to avoid modifying the original image
43
  draw = ImageDraw.Draw(image_pil)
44
  x1, y1 = x, y
 
48
 
49
  # Function to crop the image
50
  def crop_image(image, x, y, size=224):
51
+ """Crop a region from the image."""
52
  image_np = np.array(image)
53
  h, w, _ = image_np.shape
54
  x = min(max(x, 0), w - size)
 
56
  cropped = image_np[y:y+size, x:x+size]
57
  return Image.fromarray(cropped)
58
 
 
 
 
 
59
  # Gradio Interface
60
  with gr.Blocks() as demo:
61
  gr.Markdown("## Test Public Models for Coral Classification")
 
71
  cropped_image = gr.Image(label="Cropped Patch")
72
  label_output = gr.Textbox(label="Predicted Label")
73
 
74
+ # Update the model and preprocessor
75
  def update_model(model_name):
76
+ return load_model_and_preprocessor(model_name)
 
 
 
77
 
78
  # Update the rectangle and crop the patch
79
  def update_selection(image, x, y):
 
83
 
84
  # Predict the label from the cropped patch
85
  def predict_from_cropped(cropped):
86
+ print(f"Type of cropped_image before prediction: {type(cropped)}")
87
  return predict(cropped, current_model, current_preprocessor)
88
 
89
  # Buttons and interactions
 
104
 
105
  image_input.change(fn=update_sliders, inputs=image_input, outputs=[x_slider, y_slider])
106
 
107
+ # Initialize model on app start
108
+ demo.load(fn=lambda: load_model_and_preprocessor('beit'), inputs=None, outputs=None)
109
+
110
  demo.launch(server_name="0.0.0.0", server_port=7860)