"""Train a CNN-KAN classifier on the ``rfcx/frugalai`` dataset.

Reads hyper-parameters from ``tasks/utils/config.yaml``, streams the dataset
from the Hugging Face Hub, trains the model with ``Trainer``, dumps the fit
history to ``<log_dir>/frugal_<YYYY-MM-DD>/<model_name>_frugal_<exp_num>.json``
and prints the test accuracy returned by ``Trainer.predict``.
"""
# NOTE: import order is kept as in the original because of the star import
# below — names imported after it deliberately win over anything it exports.
from torch.utils.data import DataLoader
from .utils.data import FFTDataset, SplitDataset
from datasets import load_dataset
from .utils.train import Trainer, XGBoostTrainer
from .utils.models import CNNKan, KanEncoder, CNNKanFeaturesEncoder
from .utils.data_utils import *  # presumably provides Container and collate_fn — verify
from huggingface_hub import login
import yaml
import datetime
import json
import os
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import torch
from collections import OrderedDict

current_date = datetime.date.today().strftime("%Y-%m-%d")
datetime_dir = f"frugal_{current_date}"

args_dir = 'tasks/utils/config.yaml'
# Parse the YAML config once and close the file, instead of re-opening (and
# leaking a handle for) every section.
with open(args_dir, 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

data_args = Container(**config['Data'])
exp_num = data_args.exp_num
model_name = data_args.model_name
model_args = Container(**config['CNNEncoder'])
mlp_args = Container(**config['MLP'])
model_args_f = Container(**config['CNNEncoder_f'])
conformer_args = Container(**config['Conformer'])
kan_args = Container(**config['KAN'])
boost_args = Container(**config['XGBoost'])

# Output directory for this run's results (e.g. <log_dir>/frugal_2024-01-31).
os.makedirs(f"{data_args.log_dir}/{datetime_dir}", exist_ok=True)

# Strip surrounding whitespace (such as a trailing newline) so it is not sent
# to the Hub as part of the token.
with open("../logs//token.txt", "r", encoding="utf-8") as f:
    api_key = f.read().strip()

# Despite its name, `local_rank` holds a torch.device (single-process run).
local_rank = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
login(api_key)

dataset = load_dataset("rfcx/frugalai", streaming=True)
# Train and validation loaders are both built from the "train" split;
# SplitDataset(is_train=...) presumably selects the partition — verify.
train_ds = SplitDataset(FFTDataset(dataset["train"]), is_train=True)
train_dl = DataLoader(train_ds, batch_size=data_args.batch_size, collate_fn=collate_fn)
val_ds = SplitDataset(FFTDataset(dataset["train"]), is_train=False)
val_dl = DataLoader(val_ds, batch_size=data_args.batch_size, collate_fn=collate_fn)
test_ds = FFTDataset(dataset["test"])
test_dl = DataLoader(test_ds, batch_size=data_args.batch_size, collate_fn=collate_fn)

# Other model constructors that were previously kept here as comments:
#   DualEncoder(model_args, model_args_f, conformer_args)
#   CNNKanFeaturesEncoder(model_args, mlp_args, kan_args.get_dict())
#   KanEncoder(kan_args.get_dict())
model = CNNKan(model_args, conformer_args, kan_args.get_dict())
model = model.to(local_rank)

num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Number of parameters: {num_params}")

loss_fn = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
total_steps = int(data_args.num_epochs) * 1000
# NOTE(review): this scheduler is constructed but never used — `Trainer` below
# receives scheduler=None. Confirm whether it should be passed in.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                       T_max=total_steps,
                                                       eta_min=float((5e-4) / 10))

trainer = Trainer(model=model, optimizer=optimizer,
                  criterion=loss_fn, output_dim=model_args.output_dim, scaler=None,
                  scheduler=None, train_dataloader=train_dl,
                  val_dataloader=val_dl, device=local_rank,
                  exp_num=datetime_dir, log_path=data_args.log_dir,
                  range_update=None,
                  accumulation_step=1, max_iter=np.inf,
                  exp_name=f"frugal_kan_{exp_num}")
# NOTE(review): num_epochs is hard-coded to 100 here, while total_steps above is
# derived from data_args.num_epochs — confirm which value is intended.
fit_res = trainer.fit(num_epochs=100, device=local_rank,
                      early_stopping=10, only_p=False, best='loss', conf=True)

output_filename = f'{data_args.log_dir}/{datetime_dir}/{model_name}_frugal_{exp_num}.json'
with open(output_filename, "w") as f:
    json.dump(fit_res, f, indent=2)

preds, tru, acc = trainer.predict(test_dl, local_rank)
print(f"Accuracy: {acc}")