# Spaces: Runtime error
import gradio as gr | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import os | |
import PIL | |
import tensorflow as tf | |
from tensorflow import keras | |
from tensorflow.keras import layers | |
from tensorflow.keras.models import Sequential | |
from PIL import Image | |
import gdown | |
import zipfile | |
import pathlib | |
# Download and extract dataset
gdrive_url = 'https://drive.google.com/file/d/1HjHYlQyRz5oWt8kehkt1TiOGRRlKFsv8/view?usp=drive_link'
# Turn the Drive "share" link into the direct-download form gdown accepts.
file_id = gdrive_url.split('/d/')[1].split('/view')[0]
direct_download_url = f'https://drive.google.com/uc?id={file_id}'
local_zip_file = 'file.zip'
extracted_path = 'extracted_files'
data_dir = pathlib.Path(extracted_path) / 'Pest_Dataset'

# Skip the slow download when a previous run already extracted the dataset.
if not data_dir.is_dir():
    downloaded = gdown.download(direct_download_url, local_zip_file, quiet=False)
    if downloaded is None:
        raise RuntimeError(f"Download of {direct_download_url} failed.")
    try:
        with zipfile.ZipFile(local_zip_file, 'r') as zip_ref:
            zip_ref.extractall(extracted_path)
        print("Extraction successful!")
    except zipfile.BadZipFile as err:
        # Nothing below can run without the dataset: fail here with a clear
        # message instead of later with a confusing "directory not found".
        raise RuntimeError("The downloaded file is not a valid ZIP file.") from err
    finally:
        # The archive is not needed after extraction (or once proven invalid).
        if os.path.exists(local_zip_file):
            os.remove(local_zip_file)
    if not data_dir.is_dir():
        raise RuntimeError(f"Expected dataset directory {data_dir} was not found in the archive.")
# Data loading and preprocessing
img_height, img_width = 180, 180
batch_size = 32

# Settings shared by both subsets. The training and validation calls must use
# the same seed and validation_split so that they describe one consistent split.
_split_kwargs = dict(
    validation_split=0.2,
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size,
)
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir, subset="training", **_split_kwargs
)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir, subset="validation", **_split_kwargs
)

# Class labels as discovered by the training dataset loader.
class_names = train_ds.class_names
# Data augmentation: random horizontal flips, rotations, zooms and
# brightness/contrast jitter on (img_height, img_width, 3) images.
_augmentation_layers = [
    layers.RandomFlip("horizontal", input_shape=(img_height, img_width, 3)),
    layers.RandomRotation(0.1),
    layers.RandomZoom(0.1),
    layers.RandomBrightness(0.2),
    layers.RandomContrast(0.2),
]
data_augmentation = keras.Sequential(_augmentation_layers)
# Model definition
num_classes = len(class_names)

# Augmentation, then pixel rescaling to [0, 1].
model_layers = [data_augmentation, layers.Rescaling(1./255)]
# Four Conv2D + MaxPooling2D stages, doubling the filter count each stage.
for n_filters in (16, 32, 64, 128):
    model_layers.extend((
        layers.Conv2D(n_filters, 3, padding='same', activation='relu'),
        layers.MaxPooling2D(),
    ))
# Classifier head: softmax over the discovered classes.
model_layers.extend((
    layers.Dropout(0.5),
    layers.Flatten(),
    layers.Dense(256, activation='relu'),
    layers.Dense(num_classes, activation='softmax', name="outputs"),
))
model = Sequential(model_layers)

optimizer = keras.optimizers.Adam(learning_rate=0.001)
# Divide the learning rate by 10 after 3 epochs without val_loss improvement.
lr_scheduler = keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=3)
# Stop after 5 epochs without val_loss improvement and keep the best weights.
early_stopping = keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
model.compile(
    optimizer=optimizer,
    # from_logits=False: the output layer already applies softmax.
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=['accuracy'],
)
model.summary()
# Train the model
epochs = 15
# Both callbacks watch val_loss, so validation data must be supplied.
training_callbacks = [lr_scheduler, early_stopping]
history = model.fit(
    train_ds,
    epochs=epochs,
    validation_data=val_ds,
    callbacks=training_callbacks,
)
# Human-readable description for each pest class name, shown next to predictions.
category_descriptions = dict(
    Ants="Ants are small insects known for their complex social structures and teamwork.",
    Bees="Bees are flying insects known for their role in pollination and producing honey.",
    Beetles="Beetles are a group of insects with hard exoskeletons and wings. They are the largest order of insects.",
    Caterpillars="Caterpillars are the larval stage of butterflies and moths, known for their voracious appetite.",
    Earthworms="Earthworms are segmented worms that are crucial for soil health and nutrient cycling.",
    Earwigs="Earwigs are insects with pincers on their abdomen and are known for their nocturnal activity.",
    Grasshoppers="Grasshoppers are insects known for their powerful hind legs, which they use for jumping.",
    Moths="Moths are nocturnal insects related to butterflies, known for their attraction to light.",
    Slugs="Slugs are soft-bodied mollusks that are similar to snails but lack a shell.",
    Snails="Snails are mollusks with a coiled shell, known for their slow movement and slimy trail.",
    Wasps="Wasps are stinging insects that can be solitary or social, and some species are important pollinators.",
    Weevils="Weevils are a type of beetle with a long snout, known for being pests to crops and stored grains.",
)
# Prediction function
def predict_image(img):
    """Classify one uploaded image and return its three most likely classes.

    Args:
        img: The image passed in by the Gradio input. It is converted with
            ``np.asarray`` and may be ``None`` when the user submits nothing.

    Returns:
        ``None`` when no image is given. Otherwise a dict that maps a display
        label ``"<class> - <description>"`` to the predicted probability as a
        ``float``. ``gr.Label`` needs numeric confidences, so the description
        goes in the key rather than the value.
    """
    if img is None:
        return None
    img = np.asarray(img)
    # The model's input is (img_height, img_width, 3): coerce other channel layouts.
    if img.ndim == 2:
        img = np.stack([img] * 3, axis=-1)  # grayscale -> RGB
    elif img.shape[-1] == 4:
        img = img[..., :3]  # drop alpha channel
    img_resized = tf.image.resize(img, (img_height, img_width))
    img_4d = tf.expand_dims(img_resized, axis=0)  # add batch dimension
    prediction = model.predict(img_4d, verbose=0)[0]
    top_3_indices = prediction.argsort()[-3:][::-1]  # highest probability first
    results = {}
    for i in top_3_indices:
        class_name = class_names[i]
        # .get: a dataset folder name missing from the table must not crash the app.
        description = category_descriptions.get(class_name, "No description available.")
        results[f"{class_name} - {description}"] = float(prediction[i])
    return results
# Gradio interface setup: image in, top-3 class probabilities out.
image = gr.Image()
label = gr.Label(num_top_classes=3)
# Light page background plus a rounded, padded border around the app container.
custom_css = "\n".join([
    "",
    "body {background-color: #f5f5f5;}",
    ".gradio-container {border: 1px solid #ccc; border-radius: 10px; padding: 20px;}",
    "",
])
demo = gr.Interface(
    fn=predict_image,
    inputs=image,
    outputs=label,
    title="Agricultural Pest Image Classification",
    description="Identify 12 types of agricultural pests from images. This model was trained on a dataset from Kaggle.",
    css=custom_css,
)
demo.launch(debug=True)