import datetime
import json
import os

import numpy as np
import torch
import yaml
from datasets import load_dataset
from huggingface_hub import login
from torch.utils.data import DataLoader

from .utils.data import FFTDataset, SplitDataset
from .utils.data_utils import *
from .utils.models import CNNKan, KanEncoder
from .utils.train import Trainer

current_date = datetime.date.today().strftime("%Y-%m-%d")
datetime_dir = f"frugal_{current_date}"

# Load the config once and build the per-component argument containers from it.
args_dir = 'tasks/utils/config.yaml'
with open(args_dir, 'r') as f:
    config = yaml.safe_load(f)
data_args = Container(**config['Data'])
exp_num = data_args.exp_num
model_name = data_args.model_name
model_args = Container(**config['CNNEncoder'])
model_args_f = Container(**config['CNNEncoder_f'])
conformer_args = Container(**config['Conformer'])
kan_args = Container(**config['KAN'])

os.makedirs(f"{data_args.log_dir}/{datetime_dir}", exist_ok=True)

with open("../logs/token.txt", "r") as f:
    api_key = f.read().strip()  # strip trailing newline so the HF login doesn't fail

# local_rank, world_size, gpus_per_node = setup()
local_rank = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
login(api_key)
dataset = load_dataset("rfcx/frugalai", streaming=True)

train_ds = SplitDataset(FFTDataset(dataset["train"]), is_train=True)
train_dl = DataLoader(train_ds, batch_size=data_args.batch_size, collate_fn=collate_fn)

val_ds = SplitDataset(FFTDataset(dataset["train"]), is_train=False)
val_dl = DataLoader(val_ds, batch_size=data_args.batch_size, collate_fn=collate_fn)

test_ds = FFTDataset(dataset["test"])
test_dl = DataLoader(test_ds, batch_size=data_args.batch_size, collate_fn=collate_fn)

# for i, batch in enumerate(train_dl):
#     x, x_f, y = batch['audio']['array'], batch['audio']['fft'], batch['label']
#     print(x.shape, x_f.shape, y.shape)
#     if i > 10:
#         break
# exit()

# model = DualEncoder(model_args, model_args_f, conformer_args)
# model = FasterKAN([18000,64,64,16,1])
model = CNNKan(model_args, conformer_args, kan_args.get_dict())
# model.kan.speed()
# model = KanEncoder(kan_args.get_dict())
model = model.to(local_rank)
# model = DDP(model, device_ids=[local_rank], output_device=local_rank)
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Number of parameters: {num_params}")

loss_fn = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
total_steps = int(data_args.num_epochs) * 1000
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                       T_max=total_steps,
                                                       eta_min=5e-5)

# missing, unexpected = model.load_state_dict(torch.load(model_args.checkpoint_path))
# print(f"Missing keys: {missing}")
# print(f"Unexpected keys: {unexpected}")

trainer = Trainer(model=model, optimizer=optimizer,
                  criterion=loss_fn, output_dim=model_args.output_dim, scaler=None,
                  scheduler=None, train_dataloader=train_dl,
                  val_dataloader=val_dl, device=local_rank,
                  exp_num=datetime_dir, log_path=data_args.log_dir,
                  range_update=None,
                  accumulation_step=1, max_iter=np.inf,
                  exp_name=f"frugal_kan_{exp_num}")
fit_res = trainer.fit(num_epochs=int(data_args.num_epochs), device=local_rank,
                      early_stopping=10, only_p=False, best='loss', conf=True)
output_filename = f'{data_args.log_dir}/{datetime_dir}/{model_name}_frugal_{exp_num}.json'
with open(output_filename, "w") as f:
    json.dump(fit_res, f, indent=2)
preds, acc = trainer.predict(test_dl, local_rank)
print(f"Accuracy: {acc}")