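# Author: NORLIE JHON MALAGDAO
# Commit: fd3cb72 "Update app.py" (verified)
# File size: 4.52 kB
# (Header apparently copied from the hosting page's "raw | history | blame" file view.)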
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
import tensorflow as tf
import gdown
import zipfile
import pathlib
from tensorflow import keras
from tensorflow.keras import layers, callbacks
# Download the dataset archive from Google Drive and extract it locally.
gdrive_url = 'https://drive.google.com/file/d/1HjHYlQyRz5oWt8kehkt1TiOGRRlKFsv8/view?usp=drive_link'
# Pull the file id out of the ".../file/d/<id>/view?..." share link and build
# the direct-download URL that gdown fetches.
file_id = gdrive_url.split('/d/')[1].split('/view')[0]
direct_download_url = f'https://drive.google.com/uc?id={file_id}'

# Local filename for the downloaded ZIP archive
local_zip_file = 'file.zip'
gdown.download(direct_download_url, local_zip_file, quiet=False)

# Directory the archive is extracted into
extracted_path = 'extracted_files'
try:
    with zipfile.ZipFile(local_zip_file, 'r') as zip_ref:
        zip_ref.extractall(extracted_path)
except zipfile.BadZipFile as err:
    # Without the extracted dataset nothing below can run, so stop here with a
    # clear message instead of failing later on a missing directory.
    raise SystemExit("Error: The downloaded file is not a valid ZIP file.") from err
else:
    print("Extraction successful!")
finally:
    # Delete the archive whatever happened; the existence check avoids a
    # second error when the download never produced the file.
    if os.path.exists(local_zip_file):
        os.remove(local_zip_file)

# Root of the extracted dataset; its sub-directories are read as classes below.
data_dir = pathlib.Path(extracted_path) / 'Pest_Dataset'
# Load and preprocess data: one directory is split 80/20 into training and
# validation subsets. Both subsets share the same seed and split fraction so
# they partition the files consistently.
img_height, img_width = 180, 180
batch_size = 32


def _load_split(subset):
    """Return the requested subset ("training" or "validation") of data_dir."""
    return tf.keras.preprocessing.image_dataset_from_directory(
        data_dir,
        validation_split=0.2,
        subset=subset,
        seed=123,
        image_size=(img_height, img_width),
        batch_size=batch_size,
    )


train_ds = _load_split("training")
val_ds = _load_split("validation")

# Label names, in label-index order, as reported by the training dataset
class_names = train_ds.class_names
# Data augmentation: random horizontal flips, small rotations and small zooms,
# stacked as the first stage of the model below.
data_augmentation = keras.Sequential()
data_augmentation.add(layers.RandomFlip("horizontal"))
data_augmentation.add(layers.RandomRotation(0.1))
data_augmentation.add(layers.RandomZoom(0.1))
# Model: augmentation, rescaling, three conv/pool stages, then a dense head.
num_classes = len(class_names)
# `Sequential` was never imported on its own; use the one on the imported
# `keras` module (the bare name raised NameError).
model = keras.Sequential([
    data_augmentation,
    # Scale pixel values by 1/255 (presumably [0, 255] -> [0, 1]).
    layers.Rescaling(1./255),
    layers.Conv2D(16, 3, padding='same', activation='relu'),
    layers.MaxPooling2D(),
    layers.Conv2D(32, 3, padding='same', activation='relu'),
    layers.MaxPooling2D(),
    layers.Conv2D(64, 3, padding='same', activation='relu'),
    layers.MaxPooling2D(),
    layers.Dropout(0.2),
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    # No activation: this layer emits raw logits, which is why the loss below
    # uses from_logits=True and inference must apply softmax itself.
    layers.Dense(num_classes, name="outputs"),
])

# Compile the model with integer-label cross-entropy on logits.
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)
# Stop once validation loss has gone 5 epochs without improving, and roll the
# model back to the weights from its best validation-loss epoch.
early_stopping = callbacks.EarlyStopping(
    monitor='val_loss',
    patience=5,
    restore_best_weights=True,
)

# Train for at most `epochs` epochs; early stopping may end training sooner.
epochs = 50
history = model.fit(
    train_ds,
    epochs=epochs,
    validation_data=val_ds,
    callbacks=[early_stopping],
)
# Score the trained model on the validation split. `results` holds the loss
# first, then the compiled metric (accuracy).
results = model.evaluate(val_ds, verbose=0)
val_loss = results[0]
val_accuracy_pct = results[1] * 100
print("Validation Loss: {:.5f}".format(val_loss))
print("Validation Accuracy: {:.2f}%".format(val_accuracy_pct))
# Plot training history: loss curves on the left, accuracy curves on the right.
def _plot_panel(curves, position, train_key, val_key,
                train_label, val_label, y_label, title):
    """Draw one training-vs-validation panel of the 1x2 history figure."""
    plt.subplot(1, 2, position)
    plt.plot(curves[train_key], label=train_label)
    plt.plot(curves[val_key], label=val_label)
    plt.xlabel('Epoch')
    plt.ylabel(y_label)
    plt.legend()
    plt.title(title)


plt.figure(figsize=(12, 6))
_plot_panel(history.history, 1, 'loss', 'val_loss',
            'Training Loss', 'Validation Loss',
            'Loss', 'Training and Validation Loss')
_plot_panel(history.history, 2, 'accuracy', 'val_accuracy',
            'Training Accuracy', 'Validation Accuracy',
            'Accuracy', 'Training and Validation Accuracy')
plt.show()
# Prediction function
def predict_image(img):
    """Classify one uploaded image and return per-class probabilities.

    Args:
        img: The value from the Gradio ``Image`` input. It is converted with
            ``np.array``; presumably an HxW or HxWxC uint8 array (confirm
            against the installed Gradio version). ``None`` means nothing
            was uploaded.

    Returns:
        A dict mapping every class name to its softmax probability, which is
        the format ``gr.Label`` displays.

    Raises:
        gr.Error: If no image was provided.
    """
    if img is None:
        raise gr.Error("Please upload an image.")
    img = np.array(img)
    # The network was trained on 3-channel images. Expand grayscale input
    # and drop an alpha channel so the input channel count matches.
    if img.ndim == 2:
        img = img[..., np.newaxis]
    if img.shape[-1] == 1:
        img = np.repeat(img, 3, axis=-1)
    elif img.shape[-1] == 4:
        img = img[..., :3]
    img_resized = tf.image.resize(img, (img_height, img_width))
    img_4d = tf.expand_dims(img_resized, axis=0)
    # The output layer has no activation (the loss uses from_logits=True), so
    # the raw outputs are logits. Convert them to probabilities for the label.
    logits = model.predict(img_4d, verbose=0)[0]
    probabilities = tf.nn.softmax(logits).numpy()
    return {name: float(p) for name, p in zip(class_names, probabilities)}
# Interface: an image input and a label output that lists every class.
label = gr.Label(num_top_classes=num_classes)
image = gr.Image()

# NOTE(review): the background URL is a relative filesystem path. Whether
# Gradio serves it depends on the version (file= routes / allowed_paths);
# confirm the image actually loads.
custom_css = """
body {
background-image: url('extracted_files/Pest_Dataset/bees/bees (444).jpg');
background-size: cover;
background-repeat: no-repeat;
background-attachment: fixed;
color: white;
}
"""

demo = gr.Interface(
    fn=predict_image,
    inputs=image,
    outputs=label,
    title="Pest Classification",
    description="Upload an image of a pest to classify it into one of the predefined categories.",
    css=custom_css,
)
demo.launch(debug=True)