NORLIE JHON MALAGDAO committed on
Commit 8b884e6 · verified · 1 Parent(s): d44395b

Update app.py

Files changed (1)
  1. app.py +121 -63
app.py CHANGED
@@ -4,98 +4,151 @@ import numpy as np
 import os
 import PIL
 import tensorflow as tf
+
 from tensorflow import keras
 from tensorflow.keras import layers
 from tensorflow.keras.models import Sequential
+
 from PIL import Image
 import gdown
 import zipfile
 import pathlib
 
-# Download and extract dataset
+# Define the Google Drive shareable link
 gdrive_url = 'https://drive.google.com/file/d/1HjHYlQyRz5oWt8kehkt1TiOGRRlKFsv8/view?usp=drive_link'
+
+# Extract the file ID from the URL
 file_id = gdrive_url.split('/d/')[1].split('/view')[0]
 direct_download_url = f'https://drive.google.com/uc?id={file_id}'
+
+# Define the local filename to save the ZIP file
 local_zip_file = 'file.zip'
+
+# Download the ZIP file
 gdown.download(direct_download_url, local_zip_file, quiet=False)
+
+# Directory to extract files
 extracted_path = 'extracted_files'
+
+# Verify if the downloaded file is a ZIP file and extract it
 try:
     with zipfile.ZipFile(local_zip_file, 'r') as zip_ref:
         zip_ref.extractall(extracted_path)
     print("Extraction successful!")
 except zipfile.BadZipFile:
     print("Error: The downloaded file is not a valid ZIP file.")
+
+# Optionally, you can delete the ZIP file after extraction
 os.remove(local_zip_file)
-data_dir = pathlib.Path(extracted_path) / 'Pest_Dataset'
 
-# Data loading and preprocessing
+# Convert the extracted directory path to a pathlib.Path object
+data_dir = pathlib.Path(extracted_path)
+
+# Print the directory structure to debug
+for root, dirs, files in os.walk(extracted_path):
+    level = root.replace(extracted_path, '').count(os.sep)
+    indent = ' ' * 4 * (level)
+    print(f"{indent}{os.path.basename(root)}/")
+    subindent = ' ' * 4 * (level + 1)
+    for f in files:
+        print(f"{subindent}{f}")
+
+import pathlib
+# Path to the dataset directory
+data_dir = pathlib.Path('extracted_files/Pest_Dataset')
+data_dir = pathlib.Path(data_dir)
+
+bees = list(data_dir.glob('bees/*'))
+print(bees[0])
+PIL.Image.open(str(bees[0]))
+
 img_height, img_width = 180, 180
 batch_size = 32
 train_ds = tf.keras.preprocessing.image_dataset_from_directory(
-    data_dir,
-    validation_split=0.2,
-    subset="training",
-    seed=123,
-    image_size=(img_height, img_width),
-    batch_size=batch_size
+    data_dir,
+    validation_split=0.2,
+    subset="training",
+    seed=123,
+    image_size=(img_height, img_width),
+    batch_size=batch_size
 )
+
 val_ds = tf.keras.preprocessing.image_dataset_from_directory(
-    data_dir,
-    validation_split=0.2,
-    subset="validation",
-    seed=123,
-    image_size=(img_height, img_width),
-    batch_size=batch_size
+    data_dir,
+    validation_split=0.2,
+    subset="validation",
+    seed=123,
+    image_size=(img_height, img_width),
+    batch_size=batch_size
 )
+
 class_names = train_ds.class_names
+print(class_names)
+
+plt.figure(figsize=(10, 10))
+for images, labels in train_ds.take(1):
+    for i in range(9):
+        ax = plt.subplot(3, 3, i + 1)
+        plt.imshow(images[i].numpy().astype("uint8"))
+        plt.title(class_names[labels[i]])
+        plt.axis("off")
 
-# Data augmentation
 data_augmentation = keras.Sequential(
-    [
-        layers.RandomFlip("horizontal", input_shape=(img_height, img_width, 3)),
-        layers.RandomRotation(0.1),
-        layers.RandomZoom(0.1),
-        layers.RandomBrightness(0.2),
-        layers.RandomContrast(0.2),
-    ]
+    [
+        layers.RandomFlip("horizontal",
+                          input_shape=(img_height,
+                                       img_width,
+                                       3)),
+        layers.RandomRotation(0.1),
+        layers.RandomZoom(0.1),
+        layers.RandomContrast(0.1),
+        layers.RandomBrightness(0.1)
+    ]
 )
 
-# Model definition
+plt.figure(figsize=(10, 10))
+for images, _ in train_ds.take(1):
+    for i in range(9):
+        augmented_images = data_augmentation(images)
+        ax = plt.subplot(3, 3, i + 1)
+        plt.imshow(augmented_images[0].numpy().astype("uint8"))
+        plt.axis("off")
+
 num_classes = len(class_names)
 model = Sequential([
-    data_augmentation,
-    layers.Rescaling(1./255, input_shape=(img_height, img_width, 3)),
-    layers.Conv2D(16, 3, padding='same', activation='relu'),
-    layers.MaxPooling2D(),
-    layers.Conv2D(32, 3, padding='same', activation='relu'),
-    layers.MaxPooling2D(),
-    layers.Conv2D(64, 3, padding='same', activation='relu'),
-    layers.MaxPooling2D(),
-    layers.Conv2D(128, 3, padding='same', activation='relu'),
-    layers.MaxPooling2D(),
-    layers.Dropout(0.5),
-    layers.Flatten(),
-    layers.Dense(256, activation='relu'),
-    layers.Dense(num_classes, activation='softmax', name="outputs")
+    data_augmentation,
+    layers.Rescaling(1./255),
+    layers.Conv2D(32, 3, padding='same', activation='relu'),
+    layers.MaxPooling2D(),
+    layers.Conv2D(64, 3, padding='same', activation='relu'),
+    layers.MaxPooling2D(),
+    layers.Conv2D(128, 3, padding='same', activation='relu'),
+    layers.MaxPooling2D(),
+    layers.Dropout(0.5),
+    layers.Flatten(),
+    layers.Dense(256, activation='relu'),
+    layers.Dropout(0.5),
+    layers.Dense(num_classes, activation='softmax', name="outputs")
 ])
 
-optimizer = keras.optimizers.Adam(learning_rate=0.001)
-lr_scheduler = keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=3)
-early_stopping = keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
-
-model.compile(optimizer=optimizer,
+model.compile(optimizer='adam',
               loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
               metrics=['accuracy'])
 
 model.summary()
 
-# Train the model
-epochs = 15
+# Learning rate scheduler
+lr_scheduler = keras.callbacks.LearningRateScheduler(lambda epoch: 1e-3 * 10**(epoch / 20))
+
+# Early stopping
+early_stopping = keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
+
+epochs = 20
 history = model.fit(
-    train_ds,
-    validation_data=val_ds,
-    epochs=epochs,
-    callbacks=[lr_scheduler, early_stopping]
+    train_ds,
+    validation_data=val_ds,
+    epochs=epochs,
+    callbacks=[lr_scheduler, early_stopping]
 )
 
 # Define category descriptions
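
Note on the training change in this hunk: the commit swaps ReduceLROnPlateau for a LearningRateScheduler that computes 1e-3 * 10**(epoch / 20), so the learning rate rises by roughly 12% per epoch and reaches 10x the base rate at epoch 20. That is the shape of a learning-rate range test rather than a decay, and with EarlyStopping watching val_loss the run will typically halt once the growing rate destabilizes validation loss. If a conventional decay was the intent, a minimal sketch (an assumption on my part, not code from this commit) would be:

    from tensorflow import keras

    # Step decay: halve the rate every 10 epochs, starting from the
    # Adam default of 1e-3 that optimizer='adam' uses.
    lr_scheduler = keras.callbacks.LearningRateScheduler(
        lambda epoch: 1e-3 * 0.5 ** (epoch // 10)
    )

Either callback drops into the same callbacks=[lr_scheduler, early_stopping] list unchanged.
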
@@ -114,32 +167,37 @@ category_descriptions = {
     "Weevils": "Weevils are a type of beetle with a long snout, known for being pests to crops and stored grains."
 }
 
-# Prediction function
+# Define the prediction function
 def predict_image(img):
     img = np.array(img)
     img_resized = tf.image.resize(img, (180, 180))
     img_4d = tf.expand_dims(img_resized, axis=0)
     prediction = model.predict(img_4d)[0]
-    top_3_indices = prediction.argsort()[-3:][::-1]
-    results = {}
-    for i in top_3_indices:
-        class_name = class_names[i]
-        results[class_name] = f"{float(prediction[i]):.2f} - {category_descriptions[class_name]}"
-    return results
-
-# Gradio interface setup
+    predicted_class = np.argmax(prediction)
+    predicted_label = class_names[predicted_class]
+    predicted_description = category_descriptions[predicted_label]
+    return {predicted_label: f"{float(prediction[predicted_class]):.2f} - {predicted_description}"}
+
+# Set up Gradio interface
 image = gr.Image()
-label = gr.Label(num_top_classes=3)
+label = gr.Label(num_top_classes=1)
+
+# Define custom CSS for background image
 custom_css = """
-body {background-color: #f5f5f5;}
-.gradio-container {border: 1px solid #ccc; border-radius: 10px; padding: 20px;}
+body {
+    background-image: url('extracted_files/Pest_Dataset/bees/bees (444).jpg');
+    background-size: cover;
+    background-repeat: no-repeat;
+    background-attachment: fixed;
+    color: white;
+}
 """
 
 gr.Interface(
     fn=predict_image,
     inputs=image,
     outputs=label,
-    title="Agricultural Pest Image Classification",
-    description="Identify 12 types of agricultural pests from images. This model was trained on a dataset from Kaggle.",
+    title="Welcome to Agricultural Pest Image Classification",
+    description="The image data set used was obtained from Kaggle and has a collection of 12 different types of agricultural pests: Ants, Bees, Beetles, Caterpillars, Earthworms, Earwigs, Grasshoppers, Moths, Slugs, Snails, Wasps, and Weevils",
     css=custom_css
 ).launch(debug=True)
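
A caveat worth checking before running the new version: the added visualization blocks call plt.figure and plt.imshow, but no matplotlib import appears anywhere in the diff. Both hunks start at line 4, so the import may already sit in the untouched lines 1-3 of app.py, next to the gradio and numpy imports implied by the gr.* and np.* calls below; if it does not, the script fails with a NameError before training starts. The header block would need to look something like this sketch (the first two lines are assumed from usage, the third is the one to verify):

    import gradio as gr              # used below as gr.Image(), gr.Label(), gr.Interface()
    import numpy as np               # used below as np.array(), np.argmax()
    import matplotlib.pyplot as plt  # required by the new plt.figure()/plt.imshow() blocks
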