"""Train a small CNN pest classifier and serve it through a Gradio UI.

Steps performed at import/run time:
  1. Download a zipped image dataset from Google Drive and extract it.
  2. Build train/validation splits from the extracted directory.
  3. Train a small convolutional network with early stopping.
  4. Report validation metrics and plot the training curves.
  5. Launch a Gradio interface that classifies a single uploaded image.
"""
import os
import pathlib
import zipfile

import gdown
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import PIL
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import callbacks, layers

# ---------------------------------------------------------------------------
# Dataset download and extraction
# ---------------------------------------------------------------------------

# Define the Google Drive shareable link
gdrive_url = 'https://drive.google.com/file/d/1HjHYlQyRz5oWt8kehkt1TiOGRRlKFsv8/view?usp=drive_link'
# Pull the id out of ".../d/<id>/view..." and build a direct-download URL.
file_id = gdrive_url.split('/d/')[1].split('/view')[0]
direct_download_url = f'https://drive.google.com/uc?id={file_id}'

# Define the local filename to save the ZIP file
local_zip_file = 'file.zip'

# Download the ZIP file. Depending on the gdown version, a failed download
# may return None instead of raising; treat that as fatal too.
downloaded = gdown.download(direct_download_url, local_zip_file, quiet=False)
if downloaded is None:
    raise RuntimeError(f"Download failed: {direct_download_url}")

# Directory to extract files
extracted_path = 'extracted_files'

# Verify that the downloaded file is a ZIP file and extract it.
try:
    with zipfile.ZipFile(local_zip_file, 'r') as zip_ref:
        zip_ref.extractall(extracted_path)
except zipfile.BadZipFile:
    print("Error: The downloaded file is not a valid ZIP file.")
    # Nothing below can work without the dataset. Stop here rather than
    # failing later with a confusing "directory not found" error. The
    # file is left on disk for inspection.
    raise
print("Extraction successful!")

# The archive is no longer needed once extracted.
os.remove(local_zip_file)

# Convert the extracted directory path to a pathlib.Path object
data_dir = pathlib.Path(extracted_path) / 'Pest_Dataset'

# ---------------------------------------------------------------------------
# Data loading
# ---------------------------------------------------------------------------

img_height, img_width = 180, 180
batch_size = 32

# Same seed and validation_split for both calls, so the two subsets are
# complementary rather than overlapping.
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size,
)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size,
)

# Class names (subdirectory names), captured before any dataset transforms,
# which would drop the attribute.
class_names = train_ds.class_names
num_classes = len(class_names)

# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------

# Random augmentation layers; Keras applies them only in training mode.
data_augmentation = keras.Sequential([
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.1),
    layers.RandomZoom(0.1),
])

# `Sequential` was never imported at module level; use keras.Sequential.
model = keras.Sequential([
    data_augmentation,
    layers.Rescaling(1. / 255),
    layers.Conv2D(16, 3, padding='same', activation='relu'),
    layers.MaxPooling2D(),
    layers.Conv2D(32, 3, padding='same', activation='relu'),
    layers.MaxPooling2D(),
    layers.Conv2D(64, 3, padding='same', activation='relu'),
    layers.MaxPooling2D(),
    layers.Dropout(0.2),
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    # No activation: this layer emits raw logits (see from_logits=True below).
    layers.Dense(num_classes, name="outputs"),
])

# Compile the model
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

# Stop when validation loss has not improved for 5 epochs and roll back to
# the best weights seen.
early_stopping = callbacks.EarlyStopping(
    monitor='val_loss',
    patience=5,
    restore_best_weights=True,
)

# ---------------------------------------------------------------------------
# Training and evaluation
# ---------------------------------------------------------------------------

epochs = 50
# Prefetching overlaps input loading with training; train_ds/val_ds
# themselves are left unchanged.
history = model.fit(
    train_ds.prefetch(tf.data.AUTOTUNE),
    validation_data=val_ds.prefetch(tf.data.AUTOTUNE),
    epochs=epochs,
    callbacks=[early_stopping],
)

# Evaluate the model on validation data; results is [loss, accuracy].
results = model.evaluate(val_ds, verbose=0)
print("Validation Loss: {:.5f}".format(results[0]))
print("Validation Accuracy: {:.2f}%".format(results[1] * 100))

# Plot training history
plt.figure(figsize=(12, 6))

plt.subplot(1, 2, 1)
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('Training and Validation Loss')

plt.subplot(1, 2, 2)
plt.plot(history.history['accuracy'], label='Training Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.title('Training and Validation Accuracy')

plt.show()

# ---------------------------------------------------------------------------
# Inference
# ---------------------------------------------------------------------------


def predict_image(img):
    """Classify one uploaded image with the trained model.

    Args:
        img: The value produced by ``gr.Image``. Assumed to be an
            array-like image, (H, W) or (H, W, C) (Gradio's default numpy
            mode; TODO confirm for the installed Gradio version). It is
            ``None`` when the user submits without uploading anything.

    Returns:
        A dict mapping each class name to its predicted probability (the
        values sum to 1), as expected by ``gr.Label``. Returns ``None`` when
        no image was provided.
    """
    if img is None:
        return None

    img = np.asarray(img)
    # The network was trained on 3-channel RGB input. Coerce grayscale
    # (H, W) / (H, W, 1) and RGBA (H, W, 4) uploads to (H, W, 3).
    if img.ndim == 2:
        img = img[..., np.newaxis]
    if img.shape[-1] == 1:
        img = np.repeat(img, 3, axis=-1)
    elif img.shape[-1] == 4:
        img = img[..., :3]

    img_resized = tf.image.resize(img, (img_height, img_width))
    img_4d = tf.expand_dims(img_resized, axis=0)

    # A direct call avoids model.predict()'s per-call batching overhead for
    # a single sample. training=False keeps augmentation and dropout off.
    logits = model(img_4d, training=False)[0]
    # The output layer produces raw logits, so convert to probabilities
    # before presenting them as confidences.
    probs = tf.nn.softmax(logits).numpy()
    return {name: float(p) for name, p in zip(class_names, probs)}


# ---------------------------------------------------------------------------
# Interface
# ---------------------------------------------------------------------------

image = gr.Image()
label = gr.Label(num_top_classes=num_classes)

# NOTE(review): the background URL below is a path relative to the working
# directory. Gradio probably does not serve arbitrary local files at that
# URL, so the image may not load. Confirm; it may need launch(allowed_paths=...)
# and a Gradio file-route URL.
custom_css = """
body {
    background-image: url('extracted_files/Pest_Dataset/bees/bees (444).jpg');
    background-size: cover;
    background-repeat: no-repeat;
    background-attachment: fixed;
    color: white;
}
"""

gr.Interface(
    fn=predict_image,
    inputs=image,
    outputs=label,
    title="Pest Classification",
    description="Upload an image of a pest to classify it into one of the predefined categories.",
    css=custom_css,
).launch(debug=True)