NORLIE JHON MALAGDAO commited on
Commit
e34e22a
·
verified ·
1 Parent(s): 27ebf6c

Create app.file

Browse files
Files changed (1) hide show
  1. app.file +175 -0
app.file ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import matplotlib.pyplot as plt
3
+ import numpy as np
4
+ import os
5
+ import PIL
6
+ import tensorflow as tf
7
+
8
+ from tensorflow import keras
9
+ from tensorflow.keras import layers
10
+ from tensorflow.keras.models import Sequential
11
+
12
+ from PIL import Image
13
+ import gdown
14
+ import zipfile
15
+ import pathlib
16
+
17
+ # Define the Google Drive shareable link
18
+ gdrive_url = 'https://drive.google.com/file/d/1HjHYlQyRz5oWt8kehkt1TiOGRRlKFsv8/view?usp=drive_link'
19
+
20
+ # Extract the file ID from the URL
21
+ file_id = gdrive_url.split('/d/')[1].split('/view')[0]
22
+ direct_download_url = f'https://drive.google.com/uc?id={file_id}'
23
+
24
+ # Define the local filename to save the ZIP file
25
+ local_zip_file = 'file.zip'
26
+
27
+ # Download the ZIP file
28
+ gdown.download(direct_download_url, local_zip_file, quiet=False)
29
+
30
+ # Directory to extract files
31
+ extracted_path = 'extracted_files'
32
+
33
+ # Verify if the downloaded file is a ZIP file and extract it
34
+ try:
35
+ with zipfile.ZipFile(local_zip_file, 'r') as zip_ref:
36
+ zip_ref.extractall(extracted_path)
37
+ print("Extraction successful!")
38
+ except zipfile.BadZipFile:
39
+ print("Error: The downloaded file is not a valid ZIP file.")
40
+
41
+ # Optionally, you can delete the ZIP file after extraction
42
+ os.remove(local_zip_file)
43
+
44
+ # Convert the extracted directory path to a pathlib.Path object
45
+ data_dir = pathlib.Path(extracted_path)
46
+
47
+ # Print the directory structure to debug
48
+ for root, dirs, files in os.walk(extracted_path):
49
+ level = root.replace(extracted_path, '').count(os.sep)
50
+ indent = ' ' * 4 * (level)
51
+ print(f"{indent}{os.path.basename(root)}/")
52
+ subindent = ' ' * 4 * (level + 1)
53
+ for f in files:
54
+ print(f"{subindent}{f}")
55
+
56
+ # Path to the dataset directory
57
+ data_dir = pathlib.Path('extracted_files/Pest_Dataset')
58
+ data_dir = pathlib.Path(data_dir)
59
+
60
+ # Verify if the path exists
61
+ assert data_dir.exists(), f"Path {data_dir} does not exist."
62
+
63
+ # Load the dataset
64
+ img_height, img_width = 180, 180
65
+ batch_size = 32
66
+
67
+ train_ds = tf.keras.preprocessing.image_dataset_from_directory(
68
+ data_dir,
69
+ validation_split=0.2,
70
+ subset="training",
71
+ seed=123,
72
+ image_size=(img_height, img_width),
73
+ batch_size=batch_size
74
+ )
75
+
76
+ val_ds = tf.keras.preprocessing.image_dataset_from_directory(
77
+ data_dir,
78
+ validation_split=0.2,
79
+ subset="validation",
80
+ seed=123,
81
+ image_size=(img_height, img_width),
82
+ batch_size=batch_size
83
+ )
84
+
85
+ class_names = train_ds.class_names
86
+ print(class_names)
87
+
88
+ # Plot some images from the training dataset
89
+ plt.figure(figsize=(10, 10))
90
+ for images, labels in train_ds.take(1):
91
+ for i in range(9):
92
+ ax = plt.subplot(3, 3, i + 1)
93
+ plt.imshow(images[i].numpy().astype("uint8"))
94
+ plt.title(class_names[labels[i]])
95
+ plt.axis("off")
96
+
97
+ # Define data augmentation
98
+ data_augmentation = keras.Sequential(
99
+ [
100
+ layers.RandomFlip("horizontal", input_shape=(img_height, img_width, 3)),
101
+ layers.RandomRotation(0.1),
102
+ layers.RandomZoom(0.1),
103
+ ]
104
+ )
105
+
106
+ # Plot augmented images
107
+ plt.figure(figsize=(10, 10))
108
+ for images, _ in train_ds.take(1):
109
+ for i in range(9):
110
+ augmented_images = data_augmentation(images)
111
+ ax = plt.subplot(3, 3, i + 1)
112
+ plt.imshow(augmented_images[0].numpy().astype("uint8"))
113
+ plt.axis("off")
114
+
115
+ # Define the model
116
+ num_classes = len(class_names)
117
+ model = Sequential([
118
+ data_augmentation,
119
+ layers.Rescaling(1./255),
120
+ layers.Conv2D(16, 3, padding='same', activation='relu'),
121
+ layers.MaxPooling2D(),
122
+ layers.Conv2D(32, 3, padding='same', activation='relu'),
123
+ layers.MaxPooling2D(),
124
+ layers.Conv2D(64, 3, padding='same', activation='relu'),
125
+ layers.MaxPooling2D(),
126
+ layers.Dropout(0.2),
127
+ layers.Flatten(),
128
+ layers.Dense(128, activation='relu'),
129
+ layers.Dense(num_classes, activation='softmax')
130
+ ])
131
+
132
+ model.compile(optimizer='adam',
133
+ loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
134
+ metrics=['accuracy'])
135
+
136
+ model.summary()
137
+
138
+ # Train the model
139
+ epochs = 15
140
+ history = model.fit(
141
+ train_ds,
142
+ validation_data=val_ds,
143
+ epochs=epochs
144
+ )
145
+
146
+ # Define the Gradio interface
147
+ def predict_image(img):
148
+ img = np.array(img)
149
+ img_resized = tf.image.resize(img, (img_height, img_width))
150
+ img_4d = tf.expand_dims(img_resized, axis=0)
151
+ prediction = model.predict(img_4d)[0]
152
+ return {class_names[i]: float(prediction[i]) for i in range(len(class_names))}
153
+
154
+ image = gr.Image()
155
+ label = gr.Label(num_top_classes=1)
156
+
157
+ # Define custom CSS for background image
158
+ custom_css = """
159
+ body {
160
+ background-image: url('extracted_files/Pest_Dataset/bees/bees (444).jpg');
161
+ background-size: cover;
162
+ background-repeat: no-repeat;
163
+ background-attachment: fixed;
164
+ color: white;
165
+ }
166
+ """
167
+
168
+ gr.Interface(
169
+ fn=predict_image,
170
+ inputs=image,
171
+ outputs=label,
172
+ title="Welcome to Agricultural Pest Image Classification",
173
+ description="The image data set used was obtained from Kaggle and has a collection of 12 different types of agricultural pests: Ants, Bees, Beetles, Caterpillars, Earthworms, Earwigs, Grasshoppers, Moths, Slugs, Snails, Wasps, and Weevils",
174
+ css=custom_css
175
+ ).launch(debug=True)