import os
import zipfile
import gdown
import pathlib
import tensorflow as tf
from tensorflow.keras.utils import image_dataset_from_directory
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dropout, Flatten, Dense, BatchNormalization, Rescaling
from tensorflow.keras.callbacks import EarlyStopping, LearningRateScheduler
import gradio as gr
import numpy as np
# Define the Google Drive shareable link
gdrive_url = 'https://drive.google.com/file/d/1HjHYlQyRz5oWt8kehkt1TiOGRRlKFsv8/view?usp=drive_link'
# Extract the file ID from the URL
file_id = gdrive_url.split('/d/')[1].split('/view')[0]
direct_download_url = f'https://drive.google.com/uc?id={file_id}'
# Define the local filename to save the ZIP file
local_zip_file = 'file.zip'
# Download the ZIP file
gdown.download(direct_download_url, local_zip_file, quiet=False)
# Directory to extract files
extracted_path = 'extracted_files'
# Verify that the downloaded file is a valid ZIP file and extract it
try:
    with zipfile.ZipFile(local_zip_file, 'r') as zip_ref:
        zip_ref.extractall(extracted_path)
    print("Extraction successful!")
except (zipfile.BadZipFile, FileNotFoundError):
    print("Error: the download failed or the file is not a valid ZIP archive.")
# Optionally, delete the ZIP file after extraction (guarded in case the download failed)
if os.path.exists(local_zip_file):
    os.remove(local_zip_file)
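# Optional sanity check (an addition, not part of the original script):
# list the extracted top-level entries to confirm the expected dataset layout.
for entry in sorted(pathlib.Path(extracted_path).iterdir()):
    print(entry)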
# Convert the extracted directory path to a pathlib.Path object
data_dir = pathlib.Path('extracted_files/Pest_Dataset')
# Set image dimensions and batch size
img_height, img_width = 180, 180
batch_size = 32
# Create training and validation datasets
train_ds = image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size
)
val_ds = image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size
)
class_names = train_ds.class_names
print(class_names)
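# Optional performance tweak (an addition to the original script): cache,
# shuffle, and prefetch so the input pipeline overlaps with training.
# Note that class_names must be read before this step, since the transformed
# datasets no longer expose that attribute.
AUTOTUNE = tf.data.AUTOTUNE
train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)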
# Augmentation pipeline, applied on-the-fly during training only
data_augmentation = tf.keras.Sequential(
    [
        layers.RandomFlip("horizontal", input_shape=(img_height, img_width, 3)),
        layers.RandomRotation(0.2),
        layers.RandomZoom(0.2),
        layers.RandomContrast(0.2),
        layers.RandomBrightness(0.2),
    ]
)
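# Optional preview (an addition; assumes matplotlib is installed): apply the
# augmentation pipeline to one batch and save a grid of results. training=True
# is required because these layers are inactive at inference time.
import matplotlib.pyplot as plt
for images, _ in train_ds.take(1):
    plt.figure(figsize=(6, 6))
    for i in range(9):
        augmented = data_augmentation(images, training=True)
        plt.subplot(3, 3, i + 1)
        plt.imshow(augmented[0].numpy().astype("uint8"))
        plt.axis("off")
    plt.savefig("augmentation_preview.png")  # arbitrary filename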
num_classes = len(class_names)
# CNN: five Conv-BatchNorm-MaxPool blocks with widening filters, then a dense classifier head
model = Sequential()
model.add(data_augmentation)
model.add(Rescaling(1./255))
model.add(Conv2D(32, 3, padding='same', activation='relu'))
model.add(BatchNormalization())
model.add(MaxPooling2D())
model.add(Conv2D(64, 3, padding='same', activation='relu'))
model.add(BatchNormalization())
model.add(MaxPooling2D())
model.add(Conv2D(128, 3, padding='same', activation='relu'))
model.add(BatchNormalization())
model.add(MaxPooling2D())
model.add(Conv2D(256, 3, padding='same', activation='relu'))
model.add(BatchNormalization())
model.add(MaxPooling2D())
model.add(Conv2D(512, 3, padding='same', activation='relu'))
model.add(BatchNormalization())
model.add(MaxPooling2D())
model.add(Dropout(0.5))
model.add(Flatten())
model.add(Dense(256, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes, activation='softmax', name="outputs"))
# Softmax output, so the loss takes probabilities (from_logits=False)
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['accuracy'])
model.summary()
# Implement early stopping
early_stopping = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
# Learning rate scheduler: hold the LR for 10 epochs, then decay exponentially
def scheduler(epoch, lr):
    if epoch < 10:
        return lr
    # Cast to float so the callback works across Keras versions
    return float(lr * tf.math.exp(-0.1))
lr_scheduler = LearningRateScheduler(scheduler)
# Train the model
epochs = 30
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs,
    callbacks=[early_stopping, lr_scheduler]
)
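# Optional additions (not in the original script): persist the trained model
# and plot learning curves. Filenames are arbitrary; the native .keras format
# needs a reasonably recent TensorFlow, and matplotlib is assumed installed.
model.save('pest_classifier.keras')
import matplotlib.pyplot as plt
plt.figure()
plt.plot(history.history['accuracy'], label='train accuracy')
plt.plot(history.history['val_accuracy'], label='val accuracy')
plt.xlabel('epoch')
plt.ylabel('accuracy')
plt.legend()
plt.savefig('training_curves.png')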
def predict_image(img):
    # Resize the incoming image and add a batch dimension before inference
    img = np.array(img)
    img_resized = tf.image.resize(img, (img_height, img_width))
    img_4d = tf.expand_dims(img_resized, axis=0)
    prediction = model.predict(img_4d)[0]
    predicted_class = int(np.argmax(prediction))
    predicted_label = class_names[predicted_class]
    # gr.Label expects numeric confidences, so return a float rather than a string
    return {predicted_label: float(prediction[predicted_class])}
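# Optional sanity check (an addition): run the helper on one validation image
# before wiring up the UI.
for images, labels in val_ds.take(1):
    sample = images[0].numpy().astype("uint8")
    print(predict_image(sample), "| true label:", class_names[int(labels[0])])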
image = gr.Image()
label = gr.Label(num_top_classes=1)
# Define custom CSS for a background image
custom_css = """
body {
    background-image: url('extracted_files/Pest_Dataset/bees/bees (444).jpg');
    background-size: cover;
    background-repeat: no-repeat;
    background-attachment: fixed;
    color: white;
}
"""
gr.Interface(
    fn=predict_image,
    inputs=image,
    outputs=label,
    title="Welcome to Agricultural Pest Image Classification",
    description="The image dataset, obtained from Kaggle, covers 12 types of agricultural pests: Ants, Bees, Beetles, Caterpillars, Earthworms, Earwigs, Grasshoppers, Moths, Slugs, Snails, Wasps, and Weevils.",
    css=custom_css
).launch(debug=True)