Spaces:
Runtime error
Runtime error
File size: 3,664 Bytes
05e9cff e5b5b7e 83e4be4 05e9cff 83e4be4 8b884e6 05e9cff 83e4be4 05e9cff 83e4be4 77af281 05e9cff e5b5b7e 05e9cff e5b5b7e a2718c9 05e9cff 68d5b48 1f0eeb1 77af281 68d5b48 05e9cff 68d5b48 05e9cff 8e0a53d 68d5b48 a2718c9 68d5b48 05e9cff 8b884e6 c0cb430 05e9cff 8b884e6 bf81dde 05e9cff bf81dde a2718c9 68d5b48 8e0a53d 5f80dfa a2718c9 05e9cff |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
# Third-party imports for the training pipeline. They have to run before the
# first `tf.` reference below; the file's other TensorFlow/NumPy imports sit
# further down, after these names have already been used.
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential

# Mini-batch size and the resolution every image is resized to when loaded.
batch_size = 32
img_height = 180
img_width = 180

# NOTE(review): `data_dir` is not defined anywhere in the visible file. It
# must be bound to the dataset root before this point -- TODO confirm where
# it is supposed to come from.

# Both subsets must be requested with the same split fraction and seed,
# so the shared arguments are written once and reused for the two calls.
_split_options = dict(
    validation_split=0.2,
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size,
)

# 80% of the images for training ...
train_ds = tf.keras.utils.image_dataset_from_directory(
    data_dir,
    subset="training",
    **_split_options,
)

# ... and the remaining 20% for validation.
val_ds = tf.keras.utils.image_dataset_from_directory(
    data_dir,
    subset="validation",
    **_split_options,
)

# Capture the class names now: later code rebinds `train_ds` to a
# cached/prefetched pipeline and reads `class_names` from this variable.
class_names = train_ds.class_names
print(class_names)
import matplotlib.pyplot as plt
# Preview: draw the first nine images of one training batch in a 3x3 grid,
# each titled with the class name its label maps to.
plt.figure(figsize=(10, 10))
for images, labels in train_ds.take(1):
    for cell in range(1, 10):
        sample = cell - 1  # subplot cells are 1-based, batch indices 0-based
        ax = plt.subplot(3, 3, cell)
        pixels = images[sample].numpy().astype("uint8")
        plt.imshow(pixels)
        plt.title(class_names[labels[sample]])
        plt.axis("off")
# Print the shapes of the image and label tensors of a single training
# batch; only the first batch is inspected.
for image_batch, labels_batch in train_ds.take(1):
    print(image_batch.shape)
    print(labels_batch.shape)
# Input-pipeline tuning: cache both datasets and prefetch with a buffer size
# chosen by tf.data (AUTOTUNE). Only the training set is shuffled, using a
# 1000-element shuffle buffer.
AUTOTUNE = tf.data.AUTOTUNE

cached_train = train_ds.cache()
shuffled_train = cached_train.shuffle(1000)
train_ds = shuffled_train.prefetch(buffer_size=AUTOTUNE)

cached_val = val_ds.cache()
val_ds = cached_val.prefetch(buffer_size=AUTOTUNE)
# Demonstration of pixel rescaling: multiply every pixel by 1/255 and print
# the min/max of one rescaled image.
normalization_layer = layers.Rescaling(1./255)


def _rescale_images(x, y):
    """Rescale the image batch ``x``; pass the labels ``y`` through unchanged."""
    return (normalization_layer(x), y)


normalized_ds = train_ds.map(_rescale_images)
normalized_batches = iter(normalized_ds)
image_batch, labels_batch = next(normalized_batches)
first_image = image_batch[0]
# Notice the pixel values are now in `[0,1]`.
print(np.min(first_image), np.max(first_image))
# Augmentation stack used as the first stage of the classifier: horizontal
# flip, then rotation and zoom, each configured with a factor of 0.1.
# The first layer declares the expected input shape (height, width, 3).
_augmentation_input_shape = (img_height, img_width, 3)
data_augmentation = keras.Sequential(
    [
        layers.RandomFlip("horizontal", input_shape=_augmentation_input_shape),
        layers.RandomRotation(0.1),
        layers.RandomZoom(0.1),
    ]
)
# Preview of the augmentation stack: run one training batch through
# `data_augmentation` nine separate times and plot the first image of each
# result in a 3x3 grid.
plt.figure(figsize=(10, 10))
for images, _ in train_ds.take(1):
    for cell in range(1, 10):
        augmented_images = data_augmentation(images)
        ax = plt.subplot(3, 3, cell)
        first_augmented = augmented_images[0].numpy().astype("uint8")
        plt.imshow(first_augmented)
        plt.axis("off")
# Classifier: augmentation -> rescale to [0, 1] -> three Conv2D/MaxPooling2D
# stages (16, 32, 64 filters) -> dropout -> dense head.
num_classes = len(class_names)

_layer_stack = [
    data_augmentation,
    layers.Rescaling(1./255),
]
for _filters in (16, 32, 64):
    _layer_stack.append(
        layers.Conv2D(_filters, 3, padding='same', activation='relu')
    )
    _layer_stack.append(layers.MaxPooling2D())
_layer_stack.extend(
    [
        layers.Dropout(0.2),
        layers.Flatten(),
        layers.Dense(128, activation='relu'),
        # No activation here: the output layer produces one raw score per class.
        layers.Dense(num_classes, name="outputs"),
    ]
)

model = Sequential(_layer_stack)
# Training setup: Adam optimizer, sparse categorical cross-entropy computed
# from logits, and accuracy as the reported metric.
classification_loss = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True
)
model.compile(
    optimizer='adam',
    loss=classification_loss,
    metrics=['accuracy'],
)
model.summary()

# Fit for a fixed number of epochs, evaluating on the validation split
# during training; the returned history object is kept in `history`.
epochs = 15
history = model.fit(train_ds, validation_data=val_ds, epochs=epochs)
import gradio as gr
import numpy as np
import tensorflow as tf
def predict_image(img):
    """Classify one image and return a confidence for every class.

    Args:
        img: The image handed over by the Gradio input component; it is
            converted with ``np.array`` before resizing.
            NOTE(review): presumably an (H, W, 3) RGB array -- confirm
            against the ``gr.Image`` configuration.

    Returns:
        dict: Maps each entry of ``class_names`` to its softmax probability
        as a plain ``float``.
    """
    pixels = np.array(img)
    # Resize to the resolution the training images were loaded at, using the
    # module-level constants so the two cannot drift apart.
    resized = tf.image.resize(pixels, (img_height, img_width))
    # The model expects a batch dimension, so wrap the single image.
    batch = tf.expand_dims(resized, axis=0)
    logits = model.predict(batch)[0]
    # The output layer has no activation and the loss is configured with
    # from_logits=True, so predict() returns raw scores. Softmax turns them
    # into probabilities that sum to 1 before they are shown as confidences.
    probabilities = tf.nn.softmax(logits).numpy()
    return {
        name: float(probability)
        for name, probability in zip(class_names, probabilities)
    }
# Input widget: the image to classify.
image = gr.Image()
# Output widget: list every class. The count is derived from `class_names`
# (previously a hard-coded 12) so it always matches the model's output size.
label = gr.Label(num_top_classes=len(class_names))
# Define custom CSS for background image
# The URL is written with forward slashes. With backslashes in a non-raw
# Python string, `\b` became a backspace character and `\e` / `\P` were
# invalid escape sequences; a backslash is also an escape character in CSS.
# NOTE(review): the page can only show this image if the Gradio server
# actually serves that path -- confirm the URL resolves in the deployed app.
custom_css = """
body {
    background-image: url('/extracted_files/Pest_Dataset/bees/bees (444).jpg');
    background-size: cover;
    background-repeat: no-repeat;
    background-attachment: fixed;
    color: white;
}
"""
# Build the demo UI around `predict_image` (image in, class confidences out)
# with the custom stylesheet, then start the Gradio server.
# The description text is user-facing; its spelling mistakes
# ("obtaied", "agricultral") are corrected here.
gr.Interface(
    fn=predict_image,
    inputs=image,
    outputs=label,
    title="Welcome to Agricultural Pest Image Classification",
    description=(
        "The image data set used was obtained from Kaggle and has a "
        "collection of 12 different types of agricultural pests: Ants, Bees, "
        "Beetles, Caterpillars, Earthworms, Earwigs, Grasshoppers, Moths, "
        "Slugs, Snails, Wasps, and Weevils"
    ),
    css=custom_css,
).launch(debug=True)