# image_classification / train_model.py
# Author: abuzarAli. Commit ed9eb17 ("Create train_model.py", verified), 2.1 kB.
import os
import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.optimizers import Adam
# Dataset locations. Each directory should contain one subdirectory per class,
# which is the layout flow_from_directory reads below. The defaults can be
# overridden with the TRAIN_DIR / VALIDATION_DIR environment variables instead
# of editing this file.
train_dir = os.environ.get('TRAIN_DIR', './data/train')
validation_dir = os.environ.get('VALIDATION_DIR', './data/validation')
# Define the CNN model
def create_cnn_model(input_shape=(224, 224, 3), *, learning_rate=1e-3):
    """Build and compile a small CNN for binary image classification.

    The network has three Conv2D(3x3, ReLU) + MaxPooling2D(2x2) stages with
    32, 64 and 128 filters. These are flattened into a 128-unit ReLU dense
    layer and a single sigmoid output unit.

    Args:
        input_shape: Shape of one input image as (height, width, channels).
        learning_rate: Step size for the Adam optimizer. The default of 1e-3
            equals Keras' own Adam default, so existing callers see no change.

    Returns:
        A compiled ``Sequential`` model with binary cross-entropy loss and an
        accuracy metric. Because of that loss, its output is trained to
        estimate the probability that an image carries label 1.
    """
    model = Sequential()
    # Only the first layer needs the input shape; later shapes are inferred.
    model.add(Conv2D(32, (3, 3), activation='relu', input_shape=input_shape))
    model.add(MaxPooling2D((2, 2)))
    model.add(Conv2D(64, (3, 3), activation='relu'))
    model.add(MaxPooling2D((2, 2)))
    model.add(Conv2D(128, (3, 3), activation='relu'))
    model.add(MaxPooling2D((2, 2)))
    model.add(Flatten())
    model.add(Dense(128, activation='relu'))
    # Single sigmoid unit: binary classification (Normal vs Abnormal).
    model.add(Dense(1, activation='sigmoid'))
    model.compile(
        optimizer=Adam(learning_rate=learning_rate),
        loss='binary_crossentropy',
        metrics=['accuracy'],
    )
    return model
# Image size and batching are shared by the model's input layer and both data
# generators. Defining them once keeps those settings from drifting apart.
IMG_SIZE = (224, 224)
BATCH_SIZE = 32
EPOCHS = 10

# Create the CNN model. flow_from_directory loads RGB images by default,
# hence the trailing channel dimension of 3.
model = create_cnn_model(input_shape=IMG_SIZE + (3,))

# Training images are augmented on the fly (rotation, shifts, shear, zoom,
# horizontal flips). Pixel values are rescaled from [0, 255] to [0, 1].
train_datagen = ImageDataGenerator(
    rescale=1. / 255,
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest',
)
# Validation images get the same rescaling but no augmentation.
validation_datagen = ImageDataGenerator(rescale=1. / 255)

# Stream batches from class-per-subdirectory folders. With class_mode='binary',
# Keras derives the 0/1 labels from the subdirectory names.
train_generator = train_datagen.flow_from_directory(
    train_dir, target_size=IMG_SIZE, batch_size=BATCH_SIZE, class_mode='binary')
validation_generator = validation_datagen.flow_from_directory(
    validation_dir, target_size=IMG_SIZE, batch_size=BATCH_SIZE, class_mode='binary')

# If the two folders name their classes differently, the 0/1 labels would not
# mean the same thing and validation metrics would be silently wrong.
if validation_generator.class_indices != train_generator.class_indices:
    raise ValueError(
        f'Class mapping mismatch: train={train_generator.class_indices}, '
        f'validation={validation_generator.class_indices}')
# Record which class the sigmoid output (label 1) corresponds to.
print('Class indices:', train_generator.class_indices)

# Train the model
history = model.fit(train_generator, epochs=EPOCHS, validation_data=validation_generator)

# Save the trained model (the .h5 extension selects the HDF5 format).
model.save('classification_model.h5')