ANCKEM commited on
Commit
6919ed7
·
verified ·
1 Parent(s): ac6da60

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -358
app.py CHANGED
@@ -1,365 +1,19 @@
1
- # Import Data Science Libraries
2
- import gradio as gr
3
- import os
4
- import requests
5
- import gdown
6
- import zipfile
7
- import pandas as pd
8
- from pathlib import Path
9
- from PIL import Image, UnidentifiedImageError
10
- import numpy as np
11
- import tensorflow as tf
12
- from sklearn.model_selection import train_test_split
13
- import itertools
14
- import random
15
-
16
- # Import visualization libraries
17
- import matplotlib.pyplot as plt
18
- import matplotlib.cm as cm
19
- import cv2
20
- import seaborn as sns
21
-
22
- # Tensorflow Libraries
23
- from tensorflow import keras
24
- from tensorflow.keras import layers, models
25
- from tensorflow.keras.preprocessing.image import ImageDataGenerator
26
- from tensorflow.keras.layers import Dense, Dropout
27
- from tensorflow.keras.callbacks import Callback, EarlyStopping, ModelCheckpoint
28
- from tensorflow.keras.optimizers import Adam
29
- from tensorflow.keras.applications import MobileNetV2
30
- from tensorflow.keras import Model
31
-
32
- from keras.layers import Dense, Flatten, Dropout, BatchNormalization
33
-
34
- # System libraries
35
- from pathlib import Path
36
- import os.path
37
-
38
- # Metrics
39
- from sklearn.metrics import classification_report, confusion_matrix
40
-
41
- sns.set(style='darkgrid')
42
-
43
-
44
-
45
- # Seed Everything to reproduce results for future use cases
46
- def seed_everything(seed=42):
47
- # Seed value for TensorFlow
48
- tf.random.set_seed(seed)
49
-
50
- # Seed value for NumPy
51
- np.random.seed(seed)
52
-
53
- # Seed value for Python's random library
54
- random.seed(seed)
55
-
56
- # Force TensorFlow to use single thread
57
- # Multiple threads are a potential source of non-reproducible results.
58
- session_conf = tf.compat.v1.ConfigProto(
59
- intra_op_parallelism_threads=1,
60
- inter_op_parallelism_threads=1
61
- )
62
-
63
- # Make sure that TensorFlow uses a deterministic operation wherever possible
64
- tf.compat.v1.set_random_seed(seed)
65
-
66
- sess = tf.compat.v1.Session(graph=tf.compat.v1.get_default_graph(), config=session_conf)
67
- tf.compat.v1.keras.backend.set_session(sess)
68
-
69
- seed_everything()
70
-
71
-
72
-
73
- # URL of the file you want to download
74
- url = "https://raw.githubusercontent.com/mrdbourke/tensorflow-deep-learning/main/extras/helper_functions.py"
75
-
76
- # Send a GET request to the URL
77
- response = requests.get(url)
78
-
79
- # Check if the request was successful (status code 200)
80
- if response.status_code == 200:
81
- # Save the content of the response (the file) to a local file
82
- with open("helper_functions.py", "wb") as f:
83
- f.write(response.content)
84
- print("File downloaded successfully!")
85
- else:
86
- print("Failed to download file")
87
-
88
-
89
- # Import series of helper functions for our notebook
90
- from helper_functions import create_tensorboard_callback, plot_loss_curves, unzip_data, compare_historys, walk_through_dir, pred_and_plot
91
-
92
- BATCH_SIZE = 32
93
- TARGET_SIZE = (224, 224)
94
-
95
- # Define the Google Drive shareable link
96
- gdrive_url = 'https://drive.google.com/file/d/1HjHYlQyRz5oWt8kehkt1TiOGRRlKFsv8/view?usp=drive_link'
97
-
98
- # Extract the file ID from the URL
99
- file_id = gdrive_url.split('/d/')[1].split('/view')[0]
100
- direct_download_url = f'https://drive.google.com/uc?id={file_id}'
101
-
102
- # Define the local filename to save the ZIP file
103
- local_zip_file = 'file.zip'
104
-
105
- # Download the ZIP file
106
- gdown.download(direct_download_url, local_zip_file, quiet=False)
107
-
108
- # Directory to extract files
109
- extracted_path = 'extracted_files'
110
-
111
- # Verify if the downloaded file is a ZIP file and extract it
112
- try:
113
- with zipfile.ZipFile(local_zip_file, 'r') as zip_ref:
114
- zip_ref.extractall(extracted_path)
115
- print("Extraction successful!")
116
- except zipfile.BadZipFile:
117
- print("Error: The downloaded file is not a valid ZIP file.")
118
-
119
- # Optionally, you can delete the ZIP file after extraction
120
- os.remove(local_zip_file)
121
-
122
- # Convert the extracted directory path to a pathlib.Path object
123
- data_dir = Path(extracted_path)
124
-
125
- # Print the directory structure to debug
126
- for root, dirs, files in os.walk(extracted_path):
127
- level = root.replace(extracted_path, '').count(os.sep)
128
- indent = ' ' * 4 * (level)
129
- print(f"{indent}{os.path.basename(root)}/")
130
- subindent = ' ' * 4 * (level + 1)
131
- for f in files:
132
- print(f"{subindent}{f}")
133
-
134
- # Function to convert the directory path to a DataFrame
135
- def convert_path_to_df(dataset):
136
- image_dir = Path(dataset)
137
-
138
- # Get filepaths and labels
139
- filepaths = list(image_dir.glob(r'**/*.JPG')) + list(image_dir.glob(r'**/*.jpg')) + list(image_dir.glob(r'**/*.png')) + list(image_dir.glob(r'**/*.PNG'))
140
-
141
- labels = list(map(lambda x: os.path.split(os.path.split(x)[0])[1], filepaths))
142
-
143
- filepaths = pd.Series(filepaths, name='Filepath').astype(str)
144
- labels = pd.Series(labels, name='Label')
145
-
146
- # Concatenate filepaths and labels
147
- image_df = pd.concat([filepaths, labels], axis=1)
148
- return image_df
149
-
150
- # Path to the dataset directory
151
- data_dir = Path('extracted_files/Pest_Dataset')
152
- image_df = convert_path_to_df(data_dir)
153
-
154
- # Check for corrupted images within the dataset
155
- for img_p in data_dir.rglob("*.jpg"):
156
- try:
157
- img = Image.open(img_p)
158
- except UnidentifiedImageError:
159
- print(f"Corrupted image file: {img_p}")
160
-
161
- # You can save the DataFrame to a CSV for further use
162
- image_df.to_csv('image_dataset.csv', index=False)
163
- print("DataFrame created and saved successfully!")
164
-
165
- label_counts = image_df['Label'].value_counts()
166
-
167
- plt.figure(figsize=(10, 6))
168
- sns.barplot(x=label_counts.index, y=label_counts.values, alpha=0.8, palette='rocket')
169
- plt.title('Distribution of Labels in Image Dataset', fontsize=16)
170
- plt.xlabel('Label', fontsize=14)
171
- plt.ylabel('Count', fontsize=14)
172
- plt.xticks(rotation=45)
173
- plt.show()
174
-
175
- # Display 16 picture of the dataset with their labels
176
- random_index = np.random.randint(0, len(image_df), 16)
177
- fig, axes = plt.subplots(nrows=4, ncols=4, figsize=(10, 10),
178
- subplot_kw={'xticks': [], 'yticks': []})
179
-
180
- for i, ax in enumerate(axes.flat):
181
- ax.imshow(plt.imread(image_df.Filepath[random_index[i]]))
182
- ax.set_title(image_df.Label[random_index[i]])
183
- plt.tight_layout()
184
- plt.show()
185
-
186
- # Function to return a random image path from a given directory
187
- def random_sample(directory):
188
- images = [os.path.join(directory, img) for img in os.listdir(directory) if img.endswith(('.jpg', '.jpeg', '.png'))]
189
- return random.choice(images)
190
-
191
- # Function to compute the Error Level Analysis (ELA) of an image
192
- def compute_ela_cv(path, quality):
193
- temp_filename = 'temp.jpg'
194
- orig = cv2.imread(path)
195
- cv2.imwrite(temp_filename, orig, [int(cv2.IMWRITE_JPEG_QUALITY), quality])
196
- compressed = cv2.imread(temp_filename)
197
- ela_image = cv2.absdiff(orig, compressed)
198
- ela_image = np.clip(ela_image * 10, 0, 255).astype(np.uint8)
199
- return ela_image
200
-
201
- # View random sample from the dataset
202
- p = random_sample('extracted_files/Pest_Dataset/beetle')
203
- orig = cv2.imread(p)
204
- orig = cv2.cvtColor(orig, cv2.COLOR_BGR2RGB) / 255.0
205
- init_val = 100
206
- columns = 3
207
- rows = 3
208
-
209
- fig=plt.figure(figsize=(15, 10))
210
- for i in range(1, columns*rows +1):
211
- quality=init_val - (i-1) * 8
212
- img = compute_ela_cv(path=p, quality=quality)
213
- if i == 1:
214
- img = orig.copy()
215
- ax = fig.add_subplot(rows, columns, i)
216
- ax.title.set_text(f'q: {quality}')
217
- plt.imshow(img)
218
- plt.show()
219
-
220
- # Separate in train and test data
221
- train_df, test_df = train_test_split(image_df, test_size=0.2, shuffle=True, random_state=42)
222
-
223
- train_generator = ImageDataGenerator(
224
- preprocessing_function=tf.keras.applications.efficientnet_v2.preprocess_input,
225
- validation_split=0.2
226
- )
227
-
228
- test_generator = ImageDataGenerator(
229
- preprocessing_function=tf.keras.applications.efficientnet_v2.preprocess_input
230
- )
231
-
232
- # Split the data into three categories.
233
- train_images = train_generator.flow_from_dataframe(
234
- dataframe=train_df,
235
- x_col='Filepath',
236
- y_col='Label',
237
- target_size=(224, 224),
238
- color_mode='rgb',
239
- class_mode='categorical',
240
- batch_size=32,
241
- shuffle=True,
242
- seed=42,
243
- subset='training'
244
- )
245
-
246
- val_images = train_generator.flow_from_dataframe(
247
- dataframe=train_df,
248
- x_col='Filepath',
249
- y_col='Label',
250
- target_size=(224, 224),
251
- color_mode='rgb',
252
- class_mode='categorical',
253
- batch_size=32,
254
- shuffle=True,
255
- seed=42,
256
- subset='validation'
257
- )
258
-
259
- test_images = test_generator.flow_from_dataframe(
260
- dataframe=test_df,
261
- x_col='Filepath',
262
- y_col='Label',
263
- target_size=(224, 224),
264
- color_mode='rgb',
265
- class_mode='categorical',
266
- batch_size=32,
267
- shuffle=False
268
- )
269
-
270
- # Data Augmentation Step
271
- augment = tf.keras.Sequential([
272
- tf.keras.layers.Resizing(224, 224),
273
- tf.keras.layers.Rescaling(1./255),
274
- tf.keras.layers.RandomFlip("horizontal"),
275
- tf.keras.layers.RandomRotation(0.1),
276
- tf.keras.layers.RandomZoom(0.1),
277
- tf.keras.layers.RandomContrast(0.1),
278
- ])
279
-
280
- # Load the pretained model
281
- pretrained_model = tf.keras.applications.efficientnet_v2.EfficientNetV2L(
282
- input_shape=(224, 224, 3),
283
- include_top=False,
284
- weights='imagenet',
285
- pooling='max'
286
- )
287
-
288
- pretrained_model.trainable = False
289
-
290
- # Create checkpoint callback
291
- checkpoint_path = "pests_cats_classification_model_checkpoint"
292
- checkpoint_callback = ModelCheckpoint(checkpoint_path,
293
- save_weights_only=True,
294
- monitor="val_accuracy",
295
- save_best_only=True)
296
-
297
- # Setup EarlyStopping callback to stop training if model's val_loss doesn't improve for 3 epochs
298
- early_stopping = EarlyStopping(monitor = "val_loss", # watch the val loss metric
299
- patience = 5,
300
- restore_best_weights = True) # if val loss decreases for 3 epochs in a row, stop training
301
-
302
- inputs = pretrained_model.input
303
- x = augment(inputs)
304
-
305
- # Add new classification layers
306
- x = Flatten()(pretrained_model.output)
307
- x = Dense(256, activation='relu')(x)
308
- x = Dropout(0.5)(x)
309
- x = BatchNormalization()(x)
310
- x = Dense(128, activation='relu')(x)
311
- x = Dropout(0.5)(x)
312
-
313
- outputs = Dense(12, activation='softmax')(x)
314
-
315
- model = Model(inputs=inputs, outputs=outputs)
316
-
317
- model.compile(
318
- optimizer=Adam(0.00001),
319
- loss='categorical_crossentropy',
320
- metrics=['accuracy']
321
- )
322
-
323
- history = model.fit(
324
- train_images,
325
- steps_per_epoch=len(train_images),
326
- validation_data=val_images,
327
- validation_steps=len(val_images),
328
- epochs=60, # Adjusted to 30 epochs
329
- callbacks=[
330
- early_stopping,
331
- create_tensorboard_callback("training_logs",
332
- "pests_cats_classification"),
333
- checkpoint_callback,
334
- ]
335
- )
336
-
337
-
338
- results = model.evaluate(test_images, verbose=0)
339
-
340
- print(" Test Loss: {:.5f}".format(results[0]))
341
- print("Test Accuracy: {:.2f}%".format(results[1] * 100))
342
-
343
-
344
- class_names = train_images.class_indices
345
- class_names = {v: k for k, v in class_names.items()}
346
-
347
- # Gradio Interface for Prediction
348
- def predict_image(img):
349
- img = np.array(img)
350
- img_resized = tf.image.resize(img, (TARGET_SIZE[0], TARGET_SIZE[1]))
351
- img_4d = tf.expand_dims(img_resized, axis=0)
352
- prediction = model.predict(img_4d)[0]
353
- return {class_names[i]: float(prediction[i]) for i in range(len(class_names))}
354
-
355
- # Launch Gradio interface
356
- image = gr.Image()
357
- label = gr.Label(num_top_classes=12)
358
 
359
  gr.Interface(
360
  fn=predict_image,
361
  inputs=image,
362
  outputs=label,
363
- title="Pest Classification",
364
  description="Upload an image of a pest to classify it into one of the predefined categories.",
 
365
  ).launch(debug=True)
 
1
+ # Define custom CSS for background image
2
+ custom_css = """
3
+ body {
4
+ background-image: url('/extracted_files/Pest_Dataset/bees/bees (444).jpg');
5
+ background-size: cover;
6
+ background-repeat: no-repeat;
7
+ background-attachment: fixed;
8
+ color: white;
9
+ }
10
+ """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  gr.Interface(
13
  fn=predict_image,
14
  inputs=image,
15
  outputs=label,
16
+ title="PestScout: An Agricultural Pest Image Classification System Using Deep Conventional Neural Networks",
17
  description="Upload an image of a pest to classify it into one of the predefined categories.",
18
+ css=custom_css
19
  ).launch(debug=True)