Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import os | |
| import PIL | |
| import tensorflow as tf | |
| from tensorflow import keras | |
| from tensorflow.keras import layers | |
| from tensorflow.keras.models import Sequential | |
| from PIL import Image | |
| import gdown | |
| import zipfile | |
| import pathlib | |
# Google Drive shareable link to the zipped dataset.
gdrive_url = 'https://drive.google.com/file/d/1HjHYlQyRz5oWt8kehkt1TiOGRRlKFsv8/view?usp=drive_link'
# The file ID sits between "/d/" and "/view" in the shareable link; download
# it through the direct "uc?id=" URL.
file_id = gdrive_url.split('/d/')[1].split('/view')[0]
direct_download_url = f'https://drive.google.com/uc?id={file_id}'
# Local filename to save the ZIP file to, and the directory to extract into.
local_zip_file = 'file.zip'
extracted_path = 'extracted_files'

# Download the ZIP file. Depending on its version, gdown may return None
# instead of raising when Drive refuses the download, so verify the file
# actually exists before trying to open it.
downloaded = gdown.download(direct_download_url, local_zip_file, quiet=False)
if downloaded is None or not os.path.isfile(local_zip_file):
    raise RuntimeError(f"Error: could not download the dataset from {direct_download_url}")

# Extract the archive. Nothing after this point can work without the dataset,
# so an invalid archive is a hard error here instead of a printed message
# followed by an unrelated crash further down.
try:
    with zipfile.ZipFile(local_zip_file, 'r') as zip_ref:
        zip_ref.extractall(extracted_path)
except zipfile.BadZipFile as err:
    raise RuntimeError("Error: The downloaded file is not a valid ZIP file.") from err
else:
    print("Extraction successful!")
finally:
    # The archive is not needed after extraction (or once found invalid).
    os.remove(local_zip_file)
# Path view of the extraction directory.
data_dir = pathlib.Path(extracted_path)

# Dump the extracted directory tree to the logs (debug aid): each directory is
# printed with a trailing "/", its files indented one level deeper.
for dirpath, _subdirs, filenames in os.walk(extracted_path):
    depth = dirpath.replace(extracted_path, '').count(os.sep)
    print(' ' * (4 * depth) + os.path.basename(dirpath) + '/')
    file_pad = ' ' * (4 * (depth + 1))
    for filename in filenames:
        print(file_pad + filename)
import pathlib
# Path to the dataset directory inside the extracted archive (one sub-folder
# per class, e.g. 'bees').
data_dir = pathlib.Path('extracted_files/Pest_Dataset')
if not data_dir.is_dir():
    raise FileNotFoundError(f"Dataset directory not found: {data_dir}")

# Sanity-check one class folder so a wrong archive layout fails with a clear
# message instead of an IndexError on an empty list.
bees = list(data_dir.glob('bees/*'))
if not bees:
    raise FileNotFoundError(f"No images found in {data_dir / 'bees'}")
print(bees[0])

# Open one sample to confirm PIL can decode it, and close the file handle
# (a bare PIL.Image.open() result was discarded without closing).
with PIL.Image.open(bees[0]) as sample_image:
    print(sample_image.size, sample_image.mode)
# Every image is resized to this resolution; batches hold `batch_size` images.
img_height, img_width = 180, 180
batch_size = 32

# Both splits share all loader settings; only `subset` differs, so seed and
# validation_split are guaranteed to agree between the two calls.
_split_kwargs = dict(
    validation_split=0.2,
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size,
)
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir, subset="training", **_split_kwargs
)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir, subset="validation", **_split_kwargs
)

# Class names as reported by the training dataset (used for labels below).
class_names = train_ds.class_names
print(class_names)
# Preview up to nine training images from the first batch, titled with their
# class names, in a 3x3 grid.
plt.figure(figsize=(10, 10))
for images, labels in train_ds.take(1):
    # Guard against a first batch smaller than the 9-cell grid.
    for i in range(min(9, len(labels))):
        plt.subplot(3, 3, i + 1)
        # Cast to uint8 so imshow renders the pixel values as 8-bit colour.
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[int(labels[i])])
        plt.axis("off")
# Random image augmentations, used as the first stage of the model below.
# The input_shape on the first layer declares the expected (height, width, 3)
# image size.
_augmentation_layers = [
    layers.RandomFlip("horizontal", input_shape=(img_height, img_width, 3)),
    layers.RandomRotation(0.1),
    layers.RandomZoom(0.1),
    layers.RandomContrast(0.1),
    layers.RandomBrightness(0.1),
]
data_augmentation = keras.Sequential(_augmentation_layers)
# Preview nine independent random augmentations of the same training image.
plt.figure(figsize=(10, 10))
for images, _ in train_ds.take(1):
    # Only image 0 is shown, so augment just that image (kept as a batch of 1)
    # instead of re-augmenting the whole batch on every iteration.
    first_image = images[:1]
    for i in range(9):
        augmented_images = data_augmentation(first_image)
        plt.subplot(3, 3, i + 1)
        plt.imshow(augmented_images[0].numpy().astype("uint8"))
        plt.axis("off")
num_classes = len(class_names)

# CNN: augmentation -> scale pixels by 1/255 -> three conv/max-pool stages of
# increasing width (32, 64, 128) -> dropout/dense head -> softmax over classes.
_model_layers = [data_augmentation, layers.Rescaling(1./255)]
for _filters in (32, 64, 128):
    _model_layers.append(layers.Conv2D(_filters, 3, padding='same', activation='relu'))
    _model_layers.append(layers.MaxPooling2D())
_model_layers += [
    layers.Dropout(0.5),
    layers.Flatten(),
    layers.Dense(256, activation='relu'),
    layers.Dropout(0.5),
    layers.Dense(num_classes, activation='softmax', name="outputs"),
]
model = Sequential(_model_layers)

# The output layer already applies softmax, so the loss is told it receives
# probabilities (from_logits=False), not raw logits.
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=['accuracy'],
)
model.summary()
# Learning-rate schedule: start at 1e-3 and decay by 10x every 20 epochs.
# The previous schedule, 1e-3 * 10**(epoch / 20), *increased* the LR every
# epoch (the shape of a learning-rate range test), which is almost certainly
# not what was intended for the actual training run.
def _decaying_lr(epoch):
    """Return the learning rate for 0-based `epoch`: 1e-3 * 10**(-epoch / 20)."""
    return 1e-3 * 10 ** (-epoch / 20)


lr_scheduler = keras.callbacks.LearningRateScheduler(_decaying_lr)
# Stop once val_loss has not improved for 5 epochs and roll back to the best weights.
early_stopping = keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
epochs = 20
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs,
    callbacks=[lr_scheduler, early_stopping],
)
# Human-readable blurb for each pest category, keyed by capitalised plural name.
# NOTE(review): the model's class names presumably come from the dataset folder
# names (e.g. the lowercase 'bees' folder), so they may not match these keys
# exactly — look them up tolerantly.
category_descriptions = dict(
    Ants="Ants are small insects known for their complex social structures and teamwork.",
    Bees="Bees are flying insects known for their role in pollination and producing honey.",
    Beetles="Beetles are a group of insects with hard exoskeletons and wings. They are the largest order of insects.",
    Caterpillars="Caterpillars are the larval stage of butterflies and moths, known for their voracious appetite.",
    Earthworms="Earthworms are segmented worms that are crucial for soil health and nutrient cycling.",
    Earwigs="Earwigs are insects with pincers on their abdomen and are known for their nocturnal activity.",
    Grasshoppers="Grasshoppers are insects known for their powerful hind legs, which they use for jumping.",
    Moths="Moths are nocturnal insects related to butterflies, known for their attraction to light.",
    Slugs="Slugs are soft-bodied mollusks that are similar to snails but lack a shell.",
    Snails="Snails are mollusks with a coiled shell, known for their slow movement and slimy trail.",
    Wasps="Wasps are stinging insects that can be solitary or social, and some species are important pollinators.",
    Weevils="Weevils are a type of beetle with a long snout, known for being pests to crops and stored grains.",
)
def _lookup_description(label):
    """Return the description for class `label`, ignoring case and a single trailing 's'.

    Class names come from dataset folder names (e.g. 'bees'), while
    `category_descriptions` uses capitalised plurals (e.g. 'Bees'), so an exact
    key lookup raised KeyError. Falls back to a generic text when nothing matches.
    """
    def _normalize(name):
        name = name.strip().casefold()
        return name[:-1] if name.endswith('s') else name

    wanted = _normalize(label)
    for key, text in category_descriptions.items():
        if _normalize(key) == wanted:
            return text
    return "No description available."


# Define the prediction function
def predict_image(img):
    """Classify one uploaded image and describe the predicted pest.

    Args:
        img: the value of the Gradio ``gr.Image`` input, presumably an
            (H, W, C) or (H, W) pixel array; ``None`` when no image was given.

    Returns:
        A one-entry ``{"<class> - <description>": confidence}`` dict for
        ``gr.Label``, where confidence is the predicted class probability as a
        float (``gr.Label`` expects numeric confidences, not formatted text).
        Returns ``{}`` when ``img`` is None.
    """
    if img is None:
        return {}

    pixels = np.asarray(img)
    # The model expects 3 colour channels: expand grayscale, drop alpha.
    if pixels.ndim == 2:
        pixels = pixels[..., np.newaxis]
    if pixels.shape[-1] == 1:
        pixels = np.repeat(pixels, 3, axis=-1)
    elif pixels.shape[-1] == 4:
        pixels = pixels[..., :3]

    # Resize to the training resolution and add a batch dimension of 1.
    img_resized = tf.image.resize(pixels, (img_height, img_width))
    img_4d = tf.expand_dims(img_resized, axis=0)
    prediction = model.predict(img_4d, verbose=0)[0]

    predicted_class = int(np.argmax(prediction))
    predicted_label = class_names[predicted_class]
    predicted_description = _lookup_description(predicted_label)
    return {f"{predicted_label} - {predicted_description}": float(prediction[predicted_class])}
# Gradio widgets: an image upload in, the single top prediction out.
image = gr.Image()
label = gr.Label(num_top_classes=1)

# Custom CSS setting a full-page background image.
# NOTE(review): url() points at a local filesystem path; the browser resolves
# it relative to the page URL, so Gradio probably does not serve this file
# as-is — confirm (Gradio exposes local files only via its file route /
# allowed_paths).
custom_css = """
body {
background-image: url('extracted_files/Pest_Dataset/bees/bees (444).jpg');
background-size: cover;
background-repeat: no-repeat;
background-attachment: fixed;
color: white;
}
"""

_app = gr.Interface(
    fn=predict_image,
    inputs=image,
    outputs=label,
    title="Welcome to Agricultural Pest Image Classification",
    description="The image data set used was obtained from Kaggle and has a collection of 12 different types of agricultural pests: Ants, Bees, Beetles, Caterpillars, Earthworms, Earwigs, Grasshoppers, Moths, Slugs, Snails, Wasps, and Weevils",
    css=custom_css,
)
_app.launch(debug=True)