import datetime
import os
import shutil
from collections import Counter
from pathlib import Path

import pandas as pd
import yaml
from azure.storage.blob import BlobServiceClient
from sklearn.model_selection import KFold
from ultralytics import YOLO

from utils.path_utils import read_coco_json

# Credentials are read from the environment instead of being hardcoded
# (AZURE_STORAGE_ACCOUNT / AZURE_STORAGE_KEY are assumed variable names).
STORAGE_ACCOUNT_NAME = os.getenv("AZURE_STORAGE_ACCOUNT")
STORAGE_ACCOUNT_KEY  = os.getenv("AZURE_STORAGE_KEY")
CONNECTION_STRING    = f"DefaultEndpointsProtocol=https;AccountName={STORAGE_ACCOUNT_NAME};AccountKey={STORAGE_ACCOUNT_KEY};EndpointSuffix=core.windows.net"
CONTAINER_NAME       = "upload"

# Resolve the YAML file containing the training hyperparameters.
# Both APP_HOME and APP_TRAIN_HP_YAML must be set in the environment,
# otherwise os.path.join fails on a None argument.
HOME = os.getenv("APP_HOME")
APP_TRAIN_HP_YAML = os.path.join(HOME, os.getenv("APP_TRAIN_HP_YAML"))

def azure_upload(local_fname, blob_fname, overwrite=True):
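    """Upload a local file to CONTAINER_NAME on the configured storage account."""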
    blob_service_client = BlobServiceClient.from_connection_string(CONNECTION_STRING)
    blob_client = blob_service_client.get_blob_client(
        container = CONTAINER_NAME,
        blob = blob_fname
    )
    with open(local_fname, "rb") as data:
        blob_client.upload_blob(data, overwrite=overwrite)


if __name__ == "__main__":
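    # Expected hyperparameter YAML layout (values below are only illustrative):
    #   ksplit: 5
    #   epochs: 100
    #   model: yolov8n.pt
    #   data_path: /data/my_coco_export
    #   batch_size: 16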
    with open(APP_TRAIN_HP_YAML, "r") as f:
        y = yaml.safe_load(f)
        KSPLIT     = y['ksplit']
        EPOCHS     = y['epochs']
        MODEL      = y['model']
        DATA_PATH  = y['data_path']
        BATCH_SIZE = y['batch_size']

    # Load the merged COCO annotations exported for this dataset
    coco_dataset_path = Path(DATA_PATH)
    coco_dict = read_coco_json(coco_dataset_path / "merged.json")
    
    # COCO category ids are 1-based, YOLO class ids are 0-based; hence the -1
    classes = {cat['id'] - 1: cat['name'] for cat in coco_dict['categories']}
    cls_idx = sorted(classes.keys())

    labels = sorted((coco_dataset_path / "labels").rglob("*.txt"))
    indx = [l.stem for l in labels]
    labels_df = pd.DataFrame([], columns=cls_idx, index=indx)
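    # labels_df: one row per label-file stem, one column per class id; the loop
    # below fills it with per-image instance counts.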
    
    for label in labels:
        label_counter = Counter()
        with open(label, 'r') as lf:
            lines = lf.readlines()
        
        for l in lines:
            label_counter[int(l.split(' ')[0])] += 1
        labels_df.loc[label.stem] = label_counter

    labels_df = labels_df.fillna(0.0)

    # KFOLD
    kf = KFold(
        n_splits = KSPLIT,
        shuffle = True,
        random_state = 42
    )
    kfolds = list(kf.split(labels_df))
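    # Each element of kfolds is a (train_indices, val_indices) pair of index arrays.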

    folds = [f'split_{n}' for n in range(1, KSPLIT + 1)]
    folds_df = pd.DataFrame(index=indx, columns=folds)
    for idx, (train, val) in enumerate(kfolds, start=1):
        # Single .loc indexer avoids chained-assignment pitfalls in pandas
        folds_df.loc[labels_df.iloc[train].index, f'split_{idx}'] = 'train'
        folds_df.loc[labels_df.iloc[val].index, f'split_{idx}'] = 'val'

    # Check the per-class label distribution of each fold: are the splits balanced?
    fold_lbl_distrb = pd.DataFrame(index=folds, columns=cls_idx)
    for n, (train_indices, val_indices) in enumerate(kfolds, start=1):
        train_totals = labels_df.iloc[train_indices].sum()
        val_totals = labels_df.iloc[val_indices].sum()

        ratio = val_totals / (train_totals + 1E-7)
        fold_lbl_distrb.loc[f'split_{n}'] = ratio
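    # For well-balanced folds every per-class ratio should sit near
    # 1 / (KSPLIT - 1), e.g. 0.25 for a 5-fold split.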

    # datasets for each fold
    save_path = coco_dataset_path / f'{datetime.date.today().isoformat()}_{KSPLIT}-Fold_Cross-val'
    save_path.mkdir(parents=True, exist_ok=True)

    # Assumes every image shares the extension of the first file found under images/
    suffix = sorted((coco_dataset_path / 'images').rglob("*.*"))[0].suffix
    images = [coco_dataset_path / "images" / l.with_suffix(suffix).name for l in labels]
    ds_yamls = []
    
    for split in folds_df.columns:
        # create directories
        split_dir = save_path / split
        split_dir.mkdir(parents=True, exist_ok=True)
        (split_dir / 'train' / 'images').mkdir(parents=True, exist_ok=True)
        (split_dir / 'train' / 'labels').mkdir(parents=True, exist_ok=True)
        (split_dir / 'val' / 'images').mkdir(parents=True, exist_ok=True)
        (split_dir / 'val' / 'labels').mkdir(parents=True, exist_ok=True)

        # create yaml files
        dataset_yaml = split_dir / f'{split}_dataset.yaml'
        ds_yamls.append(dataset_yaml)

        with open(dataset_yaml, 'w') as ds_y:
            yaml.safe_dump({
                'path' : split_dir.resolve().as_posix(),
                'train': 'train',
                'val'  : 'val',
                'names': classes
            }, ds_y)
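        # The dumped file looks roughly like this (path is machine-specific):
        #   path: /abs/.../<date>_<K>-Fold_Cross-val/split_1
        #   train: train
        #   val: val
        #   names: {0: <class_a>, 1: <class_b>, ...}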

    for image, label in zip(images, labels):
        for split, k_split in folds_df.loc[image.stem].items():
            # destination directory 
            img_to_path = save_path / split / k_split / 'images'
            lbl_to_path = save_path / split / k_split / 'labels'

            # copy image and label file to new directory
            shutil.copy(image, img_to_path / image.name)
            shutil.copy(label, lbl_to_path / label.name)
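
    # Each image/label pair is copied into every split, so the cross-val
    # tree needs roughly KSPLIT times the original dataset's disk space.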

    folds_df.to_csv(save_path / "kfold_datasplit.csv")
    fold_lbl_distrb.to_csv(save_path / "kfold_label_distributions.csv")

    for dataset_yaml in ds_yamls:
        # Re-initialise the model for every fold so each split trains from the
        # same starting weights instead of continuing from the previous fold's
        model = YOLO(MODEL)
        model.train(
            data = dataset_yaml,
            epochs = EPOCHS,
            batch = BATCH_SIZE,
            plots = False
        )

    # Upload the final fold's best weights; the trainer reports its own save
    # location, which is more robust than guessing the runs/detect/trainN name
    local_fname = str(model.trainer.best)
    blob_fname = f"kohberg/host_train_{MODEL}"
    azure_upload(local_fname, blob_fname, overwrite=True)