Spaces:
Running
Running
Upload 18 files
Browse files
- .gitattributes +4 -35
- analyze_model.py +36 -0
- app.py +65 -0
- categorizer.py +58 -0
- check_categories.py +42 -0
- check_dataset.py +47 -0
- custom_image_model.pth +3 -0
- domain_config.json +92 -0
- download_images.py +40 -0
- downloadimages.py +33 -0
- huggingfacedownload.py.py +7 -0
- image_model.pth +3 -0
- predict.py +76 -0
- readingfile.py +36 -0
- removebadimage.py +15 -0
- scrapping.py +57 -0
- streamlit_app.py +42 -0
- train_model.py +104 -0
.gitattributes
CHANGED
@@ -1,35 +1,4 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
1 |
+
open-images-dataset-train0.tsv filter=lfs diff=lfs merge=lfs -text
|
2 |
+
News_Category_Dataset_v3.json filter=lfs diff=lfs merge=lfs -text
|
3 |
+
custom_image_model.pth filter=lfs diff=lfs merge=lfs -text
|
4 |
+
image_model.pth filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
analyze_model.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# analyze_model.py
|
2 |
+
import torch
|
3 |
+
from torchvision import models, transforms
|
4 |
+
from torch.utils.data import DataLoader
|
5 |
+
from torchvision.datasets import ImageFolder
|
6 |
+
|
7 |
+
# Evaluate the fine-tuned MobileNetV2 checkpoint on every image under
# "categorized_images" and print the top-1 accuracy.

# Size of the classification head the checkpoint was saved with.
NUM_CLASSES = 18

# Choose the device first so the checkpoint can be mapped onto it; without
# map_location, tensors saved from a GPU cannot be restored on a CPU-only host.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Rebuild the architecture (MobileNetV2 with a NUM_CLASSES-way final layer),
# then restore the fine-tuned weights.
model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1)
model.classifier[1] = torch.nn.Linear(1280, NUM_CLASSES)
model.load_state_dict(torch.load("custom_image_model.pth", map_location=device))
model.to(device)
model.eval()

# Deterministic preprocessing: resize, centre crop, ImageNet normalisation.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# NOTE(review): this scores the whole folder, not a held-out split, so the
# figure presumably includes images the model was trained on - confirm.
dataset = ImageFolder(root="categorized_images", transform=transform)
val_loader = DataLoader(dataset, batch_size=16, shuffle=False)

correct = 0
total = 0
with torch.no_grad():
    for images, labels in val_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# Guard the division so an empty loader reports 0% instead of raising.
accuracy = 100 * correct / total if total else 0.0
print(f"✅ Model Accuracy: {accuracy:.2f}% on {total} images")
|
app.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import os
|
3 |
+
import shutil
|
4 |
+
from PIL import Image
|
5 |
+
import torch
|
6 |
+
import torchvision.transforms as transforms
|
7 |
+
from torchvision import models
|
8 |
+
|
9 |
+
# Dataset root: every sub-directory is one category. The folder is created
# on first run so the listing below cannot fail on a fresh checkout.
DATASET_PATH = "categorized_images"
os.makedirs(DATASET_PATH, exist_ok=True)

# Class names come from the category folders. Only directories are counted:
# a stray file in the dataset root (e.g. a hidden OS metadata file) would
# otherwise become a bogus class and shift every label index.
class_names = sorted(
    entry
    for entry in os.listdir(DATASET_PATH)
    if os.path.isdir(os.path.join(DATASET_PATH, entry))
)
num_classes = len(class_names)

# Rebuild MobileNetV2 with a num_classes-way head and restore the fine-tuned
# weights on the CPU. load_state_dict raises if the number of category
# folders differs from the head size stored in the checkpoint.
model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1)
model.classifier[1] = torch.nn.Linear(1280, num_classes)
model.load_state_dict(torch.load("custom_image_model.pth", map_location=torch.device('cpu')))
model.eval()

# Inference preprocessing: fixed 224x224 resize plus ImageNet normalisation.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
|
29 |
+
|
30 |
+
def predict_and_save(image, filename):
    """Classify ``image`` and store it under the predicted category folder.

    Args:
        image: PIL image (the caller converts it to RGB).
        filename: Name to save the image under. Only its final path
            component is used, so a client-supplied name cannot escape
            the category folder.

    Returns:
        Tuple ``(predicted_category, confidence, image_save_path)`` where
        ``confidence`` is the softmax probability of the predicted class.
    """
    image_tensor = transform(image).unsqueeze(0)  # add the batch dimension

    with torch.no_grad():
        output = model(image_tensor)
        probabilities = torch.nn.functional.softmax(output, dim=1)
        predicted_index = torch.argmax(probabilities, dim=1).item()

    predicted_category = class_names[predicted_index]
    confidence = probabilities[0][predicted_index].item()

    # Ensure category folder exists
    category_path = os.path.join(DATASET_PATH, predicted_category)
    os.makedirs(category_path, exist_ok=True)

    # Drop any directory part (either separator style) from the uploaded name;
    # fall back to a fixed name when nothing is left.
    safe_name = os.path.basename(filename.replace("\\", "/")) or "upload.png"

    # Save image in the correct category folder
    image_save_path = os.path.join(category_path, safe_name)
    image.save(image_save_path)

    return predicted_category, confidence, image_save_path
|
51 |
+
|
52 |
+
# Streamlit UI: upload one or more images, classify each, show the result.
# NOTE(review): the leading "π" in the title and in the "saved to" message
# looks like a mis-decoded emoji; the original glyph is not recoverable here.
st.title("π Smart Image Categorizer")
st.write("Upload your images and let AI categorize them instantly!")

uploaded_files = st.file_uploader("Upload images (single or multiple)", type=["png", "jpg", "jpeg"], accept_multiple_files=True)

if uploaded_files:
    for uploaded_file in uploaded_files:
        # A corrupt upload should only skip that file, not abort the whole run.
        try:
            image = Image.open(uploaded_file).convert("RGB")
        except OSError as exc:
            st.error(f"Could not read {uploaded_file.name}: {exc}")
            continue

        category, confidence, saved_path = predict_and_save(image, uploaded_file.name)

        st.image(image, caption=f"{uploaded_file.name} → {category} ({confidence:.2%})", use_column_width=True)
        st.success(f"✅ Categorized as: **{category}** (Confidence: {confidence:.2%})")
        st.info(f"π Image saved to: {saved_path}")
|
categorizer.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import torchvision.transforms as transforms
|
4 |
+
from torchvision import models
|
5 |
+
from PIL import Image
|
6 |
+
|
7 |
+
# Class names come from the category folders under the dataset root. Only
# directories are counted: a stray file in the root would otherwise become a
# bogus class and shift every label index.
dataset_path = "categorized_images"
class_names = sorted(
    entry
    for entry in os.listdir(dataset_path)
    if os.path.isdir(os.path.join(dataset_path, entry))
)
num_classes = len(class_names)

# Rebuild MobileNetV2 with a num_classes-way head and restore the fine-tuned
# weights on the CPU. load_state_dict raises if the folder count differs from
# the head size stored in the checkpoint.
model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1)
model.classifier[1] = torch.nn.Linear(1280, num_classes)
model.load_state_dict(torch.load("custom_image_model.pth", map_location=torch.device("cpu")))
model.eval()

# Inference preprocessing: fixed 224x224 resize plus ImageNet normalisation.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
|
24 |
+
|
25 |
+
def predict_category(image_path):
    """Return the category name the model assigns to the image at ``image_path``."""
    rgb_image = Image.open(image_path).convert("RGB")
    batch = transform(rgb_image).unsqueeze(0)  # single-image batch

    with torch.no_grad():
        logits = model(batch)
        class_probs = torch.nn.functional.softmax(logits, dim=1)
        best_index = torch.argmax(class_probs, dim=1).item()

    return class_names[best_index]
|
36 |
+
|
37 |
+
def categorize_images(image_folder="uncategorized_images", output_folder="categorized_images"):
    """Move every file in ``image_folder`` into ``output_folder/<predicted category>/``.

    Sub-directories of ``image_folder`` are ignored. A file that cannot be
    opened as an image is reported and left in place instead of aborting the
    whole batch.

    Args:
        image_folder: Directory holding the images to classify.
        output_folder: Root directory that receives one sub-folder per category.

    Returns:
        None. Progress is printed to stdout.
    """
    import shutil  # local import: shutil is not among this file's imports

    if not os.path.exists(image_folder):
        print("❌ Image folder not found!")
        return

    for img_name in os.listdir(image_folder):
        img_path = os.path.join(image_folder, img_name)
        if not os.path.isfile(img_path):
            continue

        try:
            category = predict_category(img_path)
        except OSError as exc:
            # Pillow signals unreadable or non-image files with OSError subclasses.
            print(f"⚠️ Skipping {img_name}: {exc}")
            continue

        category_folder = os.path.join(output_folder, category)
        os.makedirs(category_folder, exist_ok=True)

        new_path = os.path.join(category_folder, img_name)
        # shutil.move also works when source and destination are on
        # different filesystems, where os.rename raises.
        shutil.move(img_path, new_path)
        print(f"✅ Moved {img_name} to {category}/")
|
55 |
+
|
56 |
+
# Script entry point: sort everything in the default input folder.
if __name__ == "__main__":
    categorize_images()
    print("✅ Categorization complete!")
|
check_categories.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
from collections import Counter
|
4 |
+
from torchvision.datasets import ImageFolder
|
5 |
+
|
6 |
+
# Sanity check: compare the dataset's class folders with the categories in
# domain_config.json and print how many images each class has.

# Paths
dataset_path = "categorized_images"
domain_config_path = "domain_config.json"

# Load dataset using ImageFolder (no transform: samples are never read here)
dataset = ImageFolder(root=dataset_path)

# Count images in each class; dataset.targets holds one class index per image
category_counts = Counter(dataset.classes[class_idx] for class_idx in dataset.targets)

# Load domain_config.json
with open(domain_config_path, "r", encoding="utf-8") as f:
    domain_config = json.load(f)

# Print dataset classes and domain config keys
print("\n✅ Dataset Classes from ImageFolder:", dataset.classes)
print("\n✅ Categories in domain_config.json:", list(domain_config.keys()))

# Check if classes match (order-insensitive) and say what differs if not
dataset_classes = set(dataset.classes)
config_classes = set(domain_config.keys())
if dataset_classes == config_classes:
    print("\n✅ Class labels MATCH between dataset and domain_config.json!")
else:
    print("\n⚠️ WARNING: Mismatch between dataset classes and domain_config.json!")
    print("   Only in dataset:", sorted(dataset_classes - config_classes))
    print("   Only in domain_config.json:", sorted(config_classes - dataset_classes))

# Print category counts
# NOTE(review): the leading "π" looks like a mis-decoded emoji; the original
# glyph is not recoverable here.
print("\nπ Image Count Per Category:")
for category, count in category_counts.items():
    print(f" - {category}: {count} images")

# Check for empty categories (a Counter returns 0 for classes it never saw)
empty_categories = [c for c in dataset.classes if category_counts[c] == 0]
if empty_categories:
    print("\n⚠️ WARNING: Some categories have 0 images:", empty_categories)
else:
    print("\n✅ All categories have images!")
|
check_dataset.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
from torchvision.datasets import ImageFolder
|
4 |
+
from torchvision import transforms
|
5 |
+
|
6 |
+
# Sanity check: compare dataset class folders with domain_config.json and
# list how many files each category folder contains.

# Load domain configuration
config_path = "domain_config.json"
with open(config_path, "r", encoding="utf-8") as f:
    domain_config = json.load(f)

# Extract category names from domain_config.json
config_categories = list(domain_config.keys())

# Path to categorized images folder
dataset_path = "categorized_images"

# Augmentation pipeline handed to ImageFolder. No sample is read in this
# script, so it is never actually applied.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),  # Flip images randomly
    transforms.RandomRotation(degrees=15),  # Rotate images by up to 15 degrees
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),  # Adjust colors
    transforms.RandomResizedCrop(size=(224, 224), scale=(0.8, 1.0)),  # Random crop
    transforms.ToTensor(),
])

# Load dataset with augmentation
dataset = ImageFolder(root=dataset_path, transform=transform)

# Extract dataset categories
dataset_categories = dataset.classes

# Check for inconsistencies
print("\n✅ Dataset Classes from ImageFolder:", dataset_categories)
print("\n✅ Categories in domain_config.json:", config_categories)

if set(dataset_categories) != set(config_categories):
    print("\n⚠️ WARNING: Mismatch between dataset classes and domain_config.json!")

# Count directory entries per category, remembering the empty ones so the
# closing message reflects what was found instead of always claiming success.
# NOTE(review): the leading "π" looks like a mis-decoded emoji; the original
# glyph is not recoverable here.
print("\nπ Image Count Per Category:")
empty_categories = []
for category in dataset.class_to_idx:
    category_path = os.path.join(dataset_path, category)
    num_images = len(os.listdir(category_path)) if os.path.exists(category_path) else 0
    print(f" - {category}: {num_images} images")
    if num_images == 0:
        empty_categories.append(category)

if empty_categories:
    print("\n⚠️ WARNING: Some categories have 0 images:", empty_categories)
else:
    print("\n✅ All categories have images!")
|
custom_image_model.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:355eb6b4476cbb18d2ce0a104816f85eb12992a023dd223606ccea6b6624bf4e
|
3 |
+
size 9231080
|
domain_config.json
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"Technology": {
|
3 |
+
"descriptions": ["a screenshot of a computer interface", "software application", "tech product"],
|
4 |
+
"keywords": ["code", "programming", "tech", "software"],
|
5 |
+
"weight": 1.0
|
6 |
+
},
|
7 |
+
"Finance": {
|
8 |
+
"descriptions": ["financial dashboard", "banking interface", "stock market"],
|
9 |
+
"keywords": ["bank", "money", "finance", "trading"],
|
10 |
+
"weight": 1.0
|
11 |
+
},
|
12 |
+
"Education": {
|
13 |
+
"descriptions": ["online learning", "educational website", "study materials"],
|
14 |
+
"keywords": ["study", "education", "course", "learn"],
|
15 |
+
"weight": 1.0
|
16 |
+
},
|
17 |
+
"Travel": {
|
18 |
+
"descriptions": ["travel booking", "maps", "travel planning interface"],
|
19 |
+
"keywords": ["travel", "ticket", "destination", "booking"],
|
20 |
+
"weight": 1.0
|
21 |
+
},
|
22 |
+
"Entertainment": {
|
23 |
+
"descriptions": ["streaming platform", "game interface", "media website"],
|
24 |
+
"keywords": ["movie", "game", "netflix", "youtube"],
|
25 |
+
"weight": 1.0
|
26 |
+
},
|
27 |
+
"E-commerce": {
|
28 |
+
"descriptions": ["shopping website", "product page", "marketplace"],
|
29 |
+
"keywords": ["buy", "cart", "shop", "product"],
|
30 |
+
"weight": 1.0
|
31 |
+
},
|
32 |
+
"Social Media": {
|
33 |
+
"descriptions": ["social app", "messaging platform", "social network interface"],
|
34 |
+
"keywords": ["facebook", "chat", "twitter", "instagram"],
|
35 |
+
"weight": 1.0
|
36 |
+
},
|
37 |
+
"News": {
|
38 |
+
"descriptions": ["news article", "news website", "online newspaper"],
|
39 |
+
"keywords": ["headline", "report", "news", "article"],
|
40 |
+
"weight": 1.0
|
41 |
+
},
|
42 |
+
"Productivity": {
|
43 |
+
"descriptions": ["task management", "project tracking", "productivity tool"],
|
44 |
+
"keywords": ["todo", "project", "management", "task"],
|
45 |
+
"weight": 1.0
|
46 |
+
},
|
47 |
+
"Sports": {
|
48 |
+
"descriptions": ["sports news", "sports statistics dashboard", "match schedule application"],
|
49 |
+
"keywords": ["match", "league", "team", "tournament"],
|
50 |
+
"weight": 1.0
|
51 |
+
},
|
52 |
+
"Food & Dining": {
|
53 |
+
"descriptions": ["food delivery", "restaurant review website", "recipe platform"],
|
54 |
+
"keywords": ["restaurant", "recipe", "food", "meal"],
|
55 |
+
"weight": 1.0
|
56 |
+
},
|
57 |
+
"Automotive": {
|
58 |
+
"descriptions": ["car shopping", "vehicle rental", "automobile marketplace"],
|
59 |
+
"keywords": ["vehicle", "rental", "car", "auto"],
|
60 |
+
"weight": 1.0
|
61 |
+
},
|
62 |
+
"Government & Public Services": {
|
63 |
+
"descriptions": ["government website", "public service portal", "tax filing system"],
|
64 |
+
"keywords": ["public service", "tax", "government", "policy"],
|
65 |
+
"weight": 1.0
|
66 |
+
},
|
67 |
+
"Nature": {
|
68 |
+
"descriptions": ["natural scenery", "wildlife photography", "landscape images"],
|
69 |
+
"keywords": ["forest", "mountains", "ocean", "nature"],
|
70 |
+
"weight": 1.0
|
71 |
+
},
|
72 |
+
"Quotes": {
|
73 |
+
"descriptions": ["motivational quotes", "inspirational sayings", "daily wisdom"],
|
74 |
+
"keywords": ["motivation", "inspiration", "quote", "wisdom"],
|
75 |
+
"weight": 1.0
|
76 |
+
},
|
77 |
+
"Resources": {
|
78 |
+
"descriptions": ["learning materials", "skill development resources", "reference guides"],
|
79 |
+
"keywords": ["skills", "reference", "guide", "tutorial"],
|
80 |
+
"weight": 1.0
|
81 |
+
},
|
82 |
+
"Ronaldo": {
|
83 |
+
"descriptions": ["Cristiano Ronaldo images", "football highlights of Ronaldo", "Ronaldo fan pages"],
|
84 |
+
"keywords": ["ronaldo", "football", "soccer", "cr7"],
|
85 |
+
"weight": 1.0
|
86 |
+
},
|
87 |
+
"Motivation": {
|
88 |
+
"descriptions": ["motivational content", "inspirational images", "quote visuals"],
|
89 |
+
"keywords": ["motivation", "inspire", "quote", "uplift"],
|
90 |
+
"weight": 1.0
|
91 |
+
}
|
92 |
+
}
|
download_images.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import requests
|
3 |
+
import os
|
4 |
+
|
5 |
+
# Download the first rows of an Open Images TSV listing into one folder.
file_path = "open-images-dataset-train0.tsv"

# Read TSV file, skipping the first row
df = pd.read_csv(file_path, sep="\t", engine="python", skiprows=1, names=["ImageURL", "Subset", "ImageID"])

# Print first few rows to verify
print("First few rows of the cleaned dataset:")
print(df.head())

# Create a fixed category folder (since 'Subset' contains numbers, not real categories)
output_folder = "open_images_v7/dataset"
os.makedirs(output_folder, exist_ok=True)

# Limit downloads to the first 100 images
max_images = 100

# head() takes the first max_images rows by position, so the limit does not
# depend on the frame's index labels.
for _, row in df.head(max_images).iterrows():
    image_url = row["ImageURL"]
    image_id = row["ImageID"]

    # Ensure the image filename ends with ".jpg"
    image_path = os.path.join(output_folder, f"{image_id}.jpg")

    try:
        response = requests.get(image_url, timeout=10)
        if response.status_code == 200:
            with open(image_path, "wb") as f:
                f.write(response.content)
            print(f"✅ Downloaded: {image_id}.jpg")
        else:
            print(f"❌ Failed: {image_id}")
    except Exception as e:
        # Best effort: one bad URL or write error must not stop the batch.
        print(f"❌ Error downloading {image_id}: {e}")
|
downloadimages.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import pandas as pd
|
3 |
+
import requests
|
4 |
+
from tqdm import tqdm
|
5 |
+
|
6 |
+
# Download up to 80 images listed in a CSV column into the "Motivation" folder.

# Load the CSV file
csv_file = "insparation.csv"  # Make sure this is the correct file name
df = pd.read_csv(csv_file)
print("Column Names in CSV:", df.columns.tolist())

# Ensure the column name matches your file
url_column = "Image-link"  # Change this if the column name is different

# Destination folder
save_folder = "Motivation"
os.makedirs(save_folder, exist_ok=True)

# Set limit to 80 images
num_images = min(80, len(df))  # If there are less than 80 URLs, take all available

# Download images, counting successes so the summary reports what was saved.
downloaded = 0
for idx, url in tqdm(enumerate(df[url_column][:num_images]), total=num_images):
    try:
        # The timeout keeps one stalled server from hanging the run; the
        # context manager releases the streamed connection.
        with requests.get(url, stream=True, timeout=10) as response:
            if response.status_code != 200:
                print(f"Failed to download {url}: HTTP {response.status_code}")
                continue
            image_path = os.path.join(save_folder, f"motivation_{idx+1}.jpg")
            with open(image_path, "wb") as file:
                for chunk in response.iter_content(1024):
                    file.write(chunk)
        downloaded += 1
    except Exception as e:
        # Best effort: report the failure and continue with the next URL.
        print(f"Failed to download {url}: {e}")

print(f"Downloaded {downloaded} images to {save_folder}")
|
huggingfacedownload.py.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from huggingface_hub import snapshot_download
|
2 |
+
|
3 |
+
# Hugging Face Hub dataset repository to mirror (taken from its URL).
repo_id = "YashJain/UI-Elements-Detection-Dataset"

# Fetch a snapshot of the whole dataset repository into ./UI_Dataset.
snapshot_download(
    repo_id=repo_id,
    repo_type="dataset",
    local_dir="UI_Dataset",
)
|
image_model.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6ea9a835dbab3c1c0e5e37a1a2e15590f2b15a4a15ce3460b807462c3eebe83f
|
3 |
+
size 9233974
|
predict.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torchvision.transforms as transforms
|
3 |
+
from torchvision import models
|
4 |
+
from PIL import Image
|
5 |
+
import os
|
6 |
+
import shutil
|
7 |
+
import sys
|
8 |
+
|
9 |
+
# Class names come from the category folders under the dataset root. Only
# directories are counted: a stray file in the root would otherwise become a
# bogus class and shift every label index.
dataset_path = "categorized_images"
class_names = sorted(
    entry
    for entry in os.listdir(dataset_path)
    if os.path.isdir(os.path.join(dataset_path, entry))
)
num_classes = len(class_names)

# Rebuild MobileNetV2 with a num_classes-way head and restore the fine-tuned
# weights on the CPU. load_state_dict raises if the folder count differs from
# the head size stored in the checkpoint.
model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1)
model.classifier[1] = torch.nn.Linear(1280, num_classes)
model.load_state_dict(torch.load("custom_image_model.pth", map_location=torch.device('cpu')))
model.eval()

# Inference preprocessing: fixed 224x224 resize plus ImageNet normalisation.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
|
26 |
+
|
27 |
+
def predict_and_categorize(image_path, move=True):
    """Predict the category of one image and optionally file it away.

    Args:
        image_path: Path of the image to classify.
        move: When true, move the file into ``<dataset_path>/<category>/``.

    Returns:
        None. The prediction and the destination are printed; an image that
        cannot be loaded is reported and skipped.
    """
    try:
        image = Image.open(image_path).convert("RGB")
    except Exception as e:
        # Deliberately broad: any load failure only skips this one image.
        print(f"⚠️ Error loading image: {e}")
        return

    image_tensor = transform(image).unsqueeze(0)  # add the batch dimension

    with torch.no_grad():
        output = model(image_tensor)
        probabilities = torch.nn.functional.softmax(output, dim=1)
        predicted_index = torch.argmax(probabilities, dim=1).item()

    predicted_category = class_names[predicted_index]
    confidence = probabilities[0][predicted_index].item()

    print(f"✅ {image_path} -> **Predicted Category:** {predicted_category} ({confidence:.2%} confidence)")

    # Move image into the dataset folder the class names were read from
    if move:
        category_folder = os.path.join(dataset_path, predicted_category)
        os.makedirs(category_folder, exist_ok=True)
        shutil.move(image_path, os.path.join(category_folder, os.path.basename(image_path)))
        # NOTE(review): the leading "π" looks like a mis-decoded emoji; the
        # original glyph is not recoverable here.
        print(f"π Moved to: {category_folder}\n")
|
53 |
+
|
54 |
+
def process_folder(folder_path):
    """Classify every .png/.jpg/.jpeg file found directly inside ``folder_path``.

    Args:
        folder_path: Directory to scan (not recursive).

    Returns:
        None. A missing folder is reported and nothing is processed.
    """
    if not os.path.exists(folder_path):
        print(f"❌ Folder not found: {folder_path}")
        return

    # Sorted so repeated runs handle (and log) the files in the same order.
    for file in sorted(os.listdir(folder_path)):
        if file.lower().endswith((".png", ".jpg", ".jpeg")):
            predict_and_categorize(os.path.join(folder_path, file))
|
63 |
+
|
64 |
+
# Command-line entry point: the single argument is an image file or a folder.
if __name__ == "__main__":
    if len(sys.argv) > 1:
        input_path = sys.argv[1]

        if os.path.isdir(input_path):
            # NOTE(review): the leading "π" looks like a mis-decoded emoji;
            # the original glyph is not recoverable here.
            print(f"\nπ **Processing folder:** {input_path}\n")
            process_folder(input_path)
        elif os.path.isfile(input_path):
            predict_and_categorize(input_path)
        else:
            print("❌ Invalid path. Please provide an image or folder.")
    else:
        print("⚠️ Please provide an image or folder path.")
|
readingfile.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import requests
|
2 |
+
import os
|
3 |
+
|
4 |
+
# Destination directory for the downloads; created up front if it is missing.
save_folder = "categorized_images/News"
os.makedirs(save_folder, exist_ok=True)

# URLs to fetch. These two entries are placeholders: replace them with real
# image links and append more as needed.
image_urls = [
    "https://example.com/image1.jpg",
    "https://example.com/image2.jpg",
]
|
14 |
+
|
15 |
+
def download_image(url, folder):
    """Download ``url`` into ``folder`` if the server returns an image.

    The file is named after the last component of the URL path (query string
    excluded). Failures are printed, never raised.

    Args:
        url: Address of the image to fetch.
        folder: Existing directory to write the file into.

    Returns:
        None.
    """
    from urllib.parse import urlsplit  # local import: not among this file's imports

    try:
        response = requests.get(url, headers={"User-Agent": "Mozilla/5.0"}, timeout=10)
        # Treat HTTP error statuses as failures rather than saving their body.
        response.raise_for_status()

        # A missing Content-Type header counts as "not an image", not an error.
        if "image" in response.headers.get("Content-Type", ""):
            # Use only the URL path so "?size=large" etc. stays out of the name.
            name = os.path.basename(urlsplit(url).path) or "image.jpg"
            filename = os.path.join(folder, name)
            with open(filename, "wb") as file:
                file.write(response.content)
            print(f"✅ Downloaded: {filename}")
        else:
            print(f"❌ Not an image: {url}")

    except Exception as e:
        # Best effort: report and let the caller carry on with the next URL.
        print(f"⚠️ Error downloading {url}: {e}")
|
31 |
+
|
32 |
+
# Fetch at most the first 80 URLs of the list.
for link in image_urls[:80]:
    download_image(link, save_folder)

print("\nπ Done! Downloaded up to 80 images.")
|
removebadimage.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
import os
|
3 |
+
|
4 |
+
# Root of the dataset to clean; point this at your own dataset folder.
folder_path = "categorized_images"


def _verification_error(path):
    """Return the exception raised while verifying ``path`` as an image, or None."""
    try:
        with Image.open(path) as candidate:
            candidate.verify()  # Pillow's integrity check
    except Exception as exc:
        return exc
    return None


# Walk the whole tree and delete every file that fails verification.
for directory, _subdirs, filenames in os.walk(folder_path):
    for filename in filenames:
        file_path = os.path.join(directory, filename)
        problem = _verification_error(file_path)
        if problem is None:
            continue
        print(f"Corrupt image found: {file_path}, Error: {problem}")
        os.remove(file_path)
        print(f"Deleted: {file_path}")
|
scrapping.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import time
|
3 |
+
import requests
|
4 |
+
from selenium import webdriver
|
5 |
+
from selenium.webdriver.chrome.service import Service
|
6 |
+
from selenium.webdriver.common.by import By
|
7 |
+
from webdriver_manager.chrome import ChromeDriverManager
|
8 |
+
from tqdm import tqdm # Progress bar
|
9 |
+
|
10 |
+
# Scrape image URLs from a Pexels search page with headless Chrome and
# download up to 100 of them into the "Productivity" folder.

# Setup Chrome Driver
options = webdriver.ChromeOptions()
options.add_argument("--headless")  # Run in background
options.add_argument("--disable-gpu")  # Prevents rendering issues
driver = webdriver.Chrome(service=Service(ChromeDriverManager().install()), options=options)

search_url = "https://www.pexels.com/search/productivity/"

# The browser is only needed to collect URLs; try/finally makes sure it is
# shut down even when the page load or the scrolling fails.
try:
    # Open Pexels search page
    driver.get(search_url)

    # Wait for images to load
    time.sleep(5)

    # Scroll down multiple times to load more images
    for _ in range(10):
        driver.execute_script("window.scrollBy(0, 2000);")
        time.sleep(2)  # Wait for new images to load

    # Extract image URLs from every <img> element on the page
    image_urls = []
    for img in driver.find_elements(By.TAG_NAME, "img"):
        url = img.get_attribute("src")
        if url and "pexels.com" in url:  # Ensure it's a valid image link
            image_urls.append(url)
finally:
    # Close the browser
    driver.quit()

# Keep only the first 100 images
image_urls = image_urls[:100]

# Create folder if not exists
save_folder = "Productivity"
os.makedirs(save_folder, exist_ok=True)

# Download and save images, counting successes for an accurate summary
downloaded = 0
for idx, img_url in enumerate(tqdm(image_urls, desc="Downloading Images")):
    try:
        response = requests.get(img_url, timeout=10)
        response.raise_for_status()  # do not save an error page as a .jpg
        with open(os.path.join(save_folder, f"image_{idx+1}.jpg"), "wb") as f:
            f.write(response.content)
        downloaded += 1
    except Exception as e:
        # Best effort: report and continue with the next image.
        print(f"Error downloading image {idx+1}: {e}")

print(f"✅ {downloaded} images downloaded in the '{save_folder}' folder.")
|
streamlit_app.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import torch
|
3 |
+
import torchvision.transforms as transforms
|
4 |
+
from torchvision import models
|
5 |
+
from PIL import Image
|
6 |
+
import json
|
7 |
+
import os
|
8 |
+
|
9 |
+
# Streamlit front end: upload one image and show the model's predicted category.

with open("domain_config.json", "r", encoding="utf-8") as f:
    domain_config = json.load(f)

# The model's output index i refers to the i-th class in sorted order, because
# the training script takes its classes from sorted folder names. The JSON keys
# are not stored sorted, so using them in file order would attach the wrong
# label to every prediction.
# NOTE(review): assumes the dataset folder names equal these keys - confirm
# (check_categories.py compares the two).
class_names = sorted(domain_config.keys())

num_classes = len(class_names)

# Rebuild MobileNetV2 with a num_classes-way head and restore the fine-tuned
# weights on the CPU.
model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1)
model.classifier[1] = torch.nn.Linear(1280, num_classes)
model.load_state_dict(torch.load("custom_image_model.pth", map_location=torch.device('cpu')))
model.eval()

# Inference preprocessing: fixed 224x224 resize plus ImageNet normalisation.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# NOTE(review): the leading "π" looks like a mis-decoded emoji; the original
# glyph is not recoverable here.
st.title("π AI-Powered Image Categorization")

uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    image = Image.open(uploaded_file).convert("RGB")
    st.image(image, caption="Uploaded Image", use_column_width=True)

    if st.button("Categorize Image"):
        image_tensor = transform(image).unsqueeze(0)  # add the batch dimension
        with torch.no_grad():
            output = model(image_tensor)
            probabilities = torch.nn.functional.softmax(output, dim=1)
            predicted_index = torch.argmax(probabilities, dim=1).item()

        predicted_category = class_names[predicted_index]
        confidence = probabilities[0][predicted_index].item()
        st.success(f"✅ **Predicted Category:** {predicted_category} ({confidence:.2%} confidence)")
|
train_model.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torchvision.transforms as transforms
|
3 |
+
from torch.utils.data import DataLoader, random_split
|
4 |
+
import torchvision.models as models
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.optim as optim
|
7 |
+
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
8 |
+
from torchvision.datasets import ImageFolder
|
9 |
+
import os
|
10 |
+
|
11 |
+
def main(dataset_path="categorized_images", output_path="custom_image_model.pth",
         num_epochs=30, batch_size=16):
    """Fine-tune an ImageNet-pretrained MobileNetV2 on an image-folder dataset.

    The dataset is split 80/20 into training and validation subsets. The
    classifier head and the last three feature blocks are trained; the weights
    with the lowest validation loss are written to ``output_path``.

    Args:
        dataset_path: Root folder containing one sub-folder per class.
        output_path: File the best ``state_dict`` is saved to.
        num_epochs: Number of passes over the training subset.
        batch_size: Batch size for both data loaders.

    Raises:
        FileNotFoundError: If ``dataset_path`` is not an existing directory.
        ValueError: If the dataset is too small to yield a non-empty
            training and validation subset.
    """
    if not os.path.isdir(dataset_path):
        raise FileNotFoundError(f"β Dataset folder '{dataset_path}' not found!")

    # Data Augmentation & Normalization (training only)
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Deterministic preprocessing for validation
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Two views of the same folder: random_split subsets share the transform of
    # their parent dataset, so validation needs its own ImageFolder to actually
    # use val_transform. Both scan the same directory in the same sorted order,
    # which makes the split indices valid for either view.
    train_source = ImageFolder(root=dataset_path, transform=train_transform)
    val_source = ImageFolder(root=dataset_path, transform=val_transform)

    # Take the class count from ImageFolder (directories only) rather than
    # os.listdir, which would also count stray files in the dataset root.
    num_classes = len(train_source.classes)

    train_size = int(0.8 * len(train_source))
    val_size = len(train_source) - train_size
    if train_size == 0 or val_size == 0:
        raise ValueError(
            f"Dataset '{dataset_path}' has only {len(train_source)} image(s); "
            "need enough for a non-empty 80/20 train/validation split."
        )

    train_dataset, val_split = random_split(train_source, [train_size, val_size])
    val_dataset = torch.utils.data.Subset(val_source, val_split.indices)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pin_memory = device.type == "cuda"  # pinned host memory only helps GPU transfers

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              num_workers=4, pin_memory=pin_memory)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                            num_workers=4, pin_memory=pin_memory)

    # Load Pretrained Model
    model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1)

    # Freeze the feature extractor ...
    for param in model.features.parameters():
        param.requires_grad = False

    # ... replace the classifier head for our classes (new layer is trainable) ...
    model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)

    # ... and unfreeze the last 3 feature blocks to fine-tune them.
    for param in model.features[-3:].parameters():
        param.requires_grad = True

    model.to(device)

    criterion = nn.CrossEntropyLoss()
    # Hand the optimizer only the parameters that will receive gradients.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(trainable_params, lr=0.0001)
    scheduler = ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.1)

    best_val_loss = float('inf')
    for epoch in range(num_epochs):
        # ---- training pass ----
        model.train()
        train_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        avg_train_loss = train_loss / len(train_loader)

        # ---- validation pass ----
        model.eval()
        val_loss, correct, total = 0.0, 0, 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        avg_val_loss = val_loss / len(val_loader)
        val_accuracy = 100 * correct / total

        print(f"π’ Epoch [{epoch+1}/{num_epochs}] β Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | Val Accuracy: {val_accuracy:.2f}%")
        scheduler.step(avg_val_loss)

        # Keep only the checkpoint with the lowest validation loss so far.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), output_path)
            print("✅ Best model saved!")

    print("π Training Complete!")


if __name__ == '__main__':
    main()
|