NORLIE JHON MALAGDAO commited on
Commit
23b0b6e
·
verified ·
1 Parent(s): 5b3e3dd

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +190 -0
app.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import matplotlib.pyplot as plt
3
+ import numpy as np
4
+ import os
5
+ import PIL
6
+ import tensorflow as tf
7
+
8
+ from tensorflow import keras
9
+ from tensorflow.keras import layers
10
+ from tensorflow.keras.models import Sequential
11
+
12
+
13
+ from PIL import Image
14
+ import gdown
15
+ import zipfile
16
+
17
+ import pathlib
18
+
19
+ # Define the Google Drive shareable link
20
+ gdrive_url = 'https://drive.google.com/file/d/1HjHYlQyRz5oWt8kehkt1TiOGRRlKFsv8/view?usp=drive_link'
21
+
22
+ # Extract the file ID from the URL
23
+ file_id = gdrive_url.split('/d/')[1].split('/view')[0]
24
+ direct_download_url = f'https://drive.google.com/uc?id={file_id}'
25
+
26
+ # Define the local filename to save the ZIP file
27
+ local_zip_file = 'file.zip'
28
+
29
+ # Download the ZIP file
30
+ gdown.download(direct_download_url, local_zip_file, quiet=False)
31
+
32
+ # Directory to extract files
33
+ extracted_path = 'extracted_files'
34
+
35
+ # Verify if the downloaded file is a ZIP file and extract it
36
+ try:
37
+ with zipfile.ZipFile(local_zip_file, 'r') as zip_ref:
38
+ zip_ref.extractall(extracted_path)
39
+ print("Extraction successful!")
40
+ except zipfile.BadZipFile:
41
+ print("Error: The downloaded file is not a valid ZIP file.")
42
+
43
+ # Optionally, you can delete the ZIP file after extraction
44
+ os.remove(local_zip_file)
45
+
46
+ # Convert the extracted directory path to a pathlib.Path object
47
+ data_dir = pathlib.Path(extracted_path)
48
+
49
+ # Print the directory structure to debug
50
+ for root, dirs, files in os.walk(extracted_path):
51
+ level = root.replace(extracted_path, '').count(os.sep)
52
+ indent = ' ' * 4 * (level)
53
+ print(f"{indent}{os.path.basename(root)}/")
54
+ subindent = ' ' * 4 * (level + 1)
55
+ for f in files:
56
+ print(f"{subindent}{f}")
57
+
58
+ import pathlib
59
+ # Path to the dataset directory
60
+ data_dir = pathlib.Path('extracted_files/Pest_Dataset')
61
+ data_dir = pathlib.Path(data_dir)
62
+
63
+
64
+ bees = list(data_dir.glob('bees/*'))
65
+ print(bees[0])
66
+ PIL.Image.open(str(bees[0]))
67
+
68
+
69
+ bees = list(data_dir.glob('bees/*'))
70
+ print(bees[0])
71
+ PIL.Image.open(str(bees[0]))
72
+
73
+
74
+ img_height,img_width=180,180
75
+ batch_size=32
76
+ train_ds = tf.keras.preprocessing.image_dataset_from_directory(
77
+ data_dir,
78
+ validation_split=0.2,
79
+ subset="training",
80
+ seed=123,
81
+ image_size=(img_height, img_width),
82
+ batch_size=batch_size)
83
+
84
+
85
+ val_ds = tf.keras.preprocessing.image_dataset_from_directory(
86
+ data_dir,
87
+ validation_split=0.2,
88
+ subset="validation",
89
+ seed=123,
90
+ image_size=(img_height, img_width),
91
+ batch_size=batch_size)
92
+
93
+
94
+ class_names = train_ds.class_names
95
+ print(class_names)
96
+
97
+
98
+ import matplotlib.pyplot as plt
99
+
100
+ plt.figure(figsize=(10, 10))
101
+ for images, labels in train_ds.take(1):
102
+ for i in range(9):
103
+ ax = plt.subplot(3, 3, i + 1)
104
+ plt.imshow(images[i].numpy().astype("uint8"))
105
+ plt.title(class_names[labels[i]])
106
+ plt.axis("off")
107
+
108
+
109
+ data_augmentation = keras.Sequential(
110
+ [
111
+ layers.RandomFlip("horizontal", input_shape=(img_height, img_width, 3)),
112
+ layers.RandomRotation(0.1),
113
+ layers.RandomZoom(0.1),
114
+ ]
115
+ )
116
+
117
+
118
+ plt.figure(figsize=(10, 10))
119
+ for images, _ in train_ds.take(1):
120
+ for i in range(9):
121
+ augmented_images = data_augmentation(images)
122
+ ax = plt.subplot(3, 3, i + 1)
123
+ plt.imshow(augmented_images[0].numpy().astype("uint8"))
124
+ plt.axis("off")
125
+
126
+
127
+ num_classes = len(class_names)
128
+ model = Sequential([
129
+ data_augmentation,
130
+ layers.Rescaling(1./255),
131
+ layers.Conv2D(16, 3, padding='same', activation='relu'),
132
+ layers.MaxPooling2D(),
133
+ layers.Conv2D(32, 3, padding='same', activation='relu'),
134
+ layers.MaxPooling2D(),
135
+ layers.Conv2D(64, 3, padding='same', activation='relu'),
136
+ layers.MaxPooling2D(),
137
+ layers.Dropout(0.2),
138
+ layers.Flatten(),
139
+ layers.Dense(128, activation='relu'),
140
+ layers.Dense(num_classes, activation='softmax', name="outputs") # Use softmax here
141
+ ])
142
+
143
+ model.compile(optimizer='adam',
144
+ loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False), # Change from_logits to False
145
+ metrics=['accuracy'])
146
+
147
+ model.summary()
148
+
149
+
150
+ epochs = 15
151
+ history = model.fit(
152
+ train_ds,
153
+ validation_data=val_ds,
154
+ epochs=epochs
155
+ )
156
+
157
+
158
+ import gradio as gr
159
+ import numpy as np
160
+ import tensorflow as tf
161
+
162
+ def predict_image(img):
163
+ img = np.array(img)
164
+ img_resized = tf.image.resize(img, (180, 180))
165
+ img_4d = tf.expand_dims(img_resized, axis=0)
166
+ prediction = model.predict(img_4d)[0]
167
+ return {class_names[i]: float(prediction[i]) for i in range(len(class_names))}
168
+
169
+ image = gr.Image()
170
+ label = gr.Label(num_top_classes=1)
171
+
172
+ # Define custom CSS for background image
173
+ custom_css = """
174
+ body {
175
+ background-image: url('extracted_files/Pest_Dataset/bees/bees (444).jpg');
176
+ background-size: cover;
177
+ background-repeat: no-repeat;
178
+ background-attachment: fixed;
179
+ color: white;
180
+ }
181
+ """
182
+
183
+ gr.Interface(
184
+ fn=predict_image,
185
+ inputs=image,
186
+ outputs=label,
187
+ title="Welcome to Agricultural Pest Image Classification",
188
+ description="The image data set used was obtained from Kaggle and has a collection of 12 different types of agricultural pests: Ants, Bees, Beetles, Caterpillars, Earthworms, Earwigs, Grasshoppers, Moths, Slugs, Snails, Wasps, and Weevils",
189
+ css=custom_css
190
+ ).launch(debug=True)