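# NOTE: the lines "Spaces: / Sleeping / Sleeping" appear to be Hugging Face Spaces
# page text captured together with this script; kept only as a comment so the file parses.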
import os
import torch.nn
from torch.utils.data import DataLoader
from utils.data import FFTDataset, SplitDataset, AudioINRDataset
from datasets import load_dataset
from utils.train import Trainer, INRTrainer
from utils.models import MultiGraph, ImplicitEncoder
from omegaconf import OmegaConf
# from .utils.models import CNNKan, KanEncoder
from utils.inr import INR
# NOTE: import order is deliberate. The star import below can rebind names imported
# above it, and imports after it take precedence over anything it exports. It
# presumably also supplies `Container` used by the config loading code.
from utils.data_utils import *
from huggingface_hub import login
import yaml
import datetime
import json
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from scipy.signal import savgol_filter as savgol
from utils.kan import FasterKAN
from utils.relational_transformer import RelationalTransformer
from collections import OrderedDict
def _minmax_scale(values):
    """Linearly rescale a numpy array into [0, 1]; a constant array maps to all zeros."""
    lo = np.min(values)
    hi = np.max(values)
    span = hi - lo
    if span == 0:
        # Avoid 0/0, which would turn the whole curve into NaN and plot nothing.
        return np.zeros_like(values, dtype=float)
    return (values - lo) / span


def plot_results(dims, i, data, losses, pred_values, window_length=250, polyorder=1):
    """Overlay a smoothed target signal and a model reconstruction on one plot.

    Both curves are min-max scaled to [0, 1] before plotting, and the plot is
    shown with ``plt.show()``.

    Args:
        dims: Unused in the body; kept for call-site compatibility.
        i: Unused in the body; kept for call-site compatibility.
        data: Target torch tensor. It is moved to CPU, detached, and smoothed
            with a Savitzky-Golay filter along its last axis.
        losses: Unused in the body; kept for call-site compatibility.
        pred_values: Model output torch tensor. Its last two axes are swapped,
            its last axis is unflattened to ``data.shape[-2:]``, and a leading
            size-1 axis is squeezed. Presumably it arrives as
            (1, num_points, channels); TODO confirm against the caller.
        window_length: Savitzky-Golay window length. It is clamped to the
            signal length.
        polyorder: Savitzky-Golay polynomial order.
    """
    data = data.cpu().detach().numpy()
    # In its default mode ("interp"), savgol_filter rejects a window longer than the
    # signal, so clamp the window. Skip smoothing when the clamped window cannot
    # exceed polyorder, which would be another invalid configuration.
    window = min(window_length, data.shape[-1])
    if window > polyorder:
        data = savgol(data, window_length=window, polyorder=polyorder)

    pred_values = (
        pred_values.transpose(-1, -2)
        .unflatten(-1, data.shape[-2:])
        .squeeze(0)
        .cpu()
        .detach()
        .numpy()
    )

    pred_values = _minmax_scale(pred_values)
    data = _minmax_scale(data)

    plt.plot(data.squeeze())
    plt.plot(pred_values.squeeze())
    plt.show()
# plt.plot(np.arange(len(losses)), losses) | |
# plt.xlabel('Iteration') | |
# plt.ylabel('Reconstruction MSE Error') | |
# plt.show() | |
# local_rank = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |
# Per-run directory name derived from today's date, e.g. "frugal_2024-01-31".
current_date = datetime.date.today().strftime("%Y-%m-%d")
datetime_dir = f"frugal_{current_date}"

# Parse the experiment config once and close the file. Each top-level section is
# then wrapped in a Container so its keys can be read as attributes.
args_dir = 'utils/config.yaml'
with open(args_dir, 'r', encoding='utf-8') as _config_file:
    _config = yaml.safe_load(_config_file)

data_args = Container(**_config['Data'])
exp_num = data_args.exp_num
model_name = data_args.model_name
rt_args = Container(**_config['RelationalTransformer'])
cnn_args = Container(**_config['CNNEncoder_f'])
conformer_args = Container(**_config['Conformer'])
kan_args = Container(**_config['KAN_INR'])
inr_args = Container(**_config['INR'])
# Per-run log directory <log_dir>/frugal_<YYYY-MM-DD>. exist_ok=True makes this a
# no-op when the directory already exists and avoids a check-then-create race.
os.makedirs(f"{data_args.log_dir}/{datetime_dir}", exist_ok=True)

# Hugging Face access token, read from a file outside the repository (path is
# relative to the working directory). Strip surrounding whitespace so a trailing
# newline in the file does not become part of the token.
with open("../../logs/token.txt", "r", encoding="utf-8") as f:
    api_key = f.read().strip()

# local_rank, world_size, gpus_per_node = setup()
# NOTE: despite its name, `local_rank` holds a torch.device (single-process run).
local_rank = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
login(api_key)
# Stream the rfcx/frugalai dataset from the Hugging Face Hub (nothing is downloaded up front).
dataset = load_dataset("rfcx/frugalai", streaming=True)
_batch_size = data_args.batch_size

# Training and validation data both come from the Hub "train" split. SplitDataset's
# is_train flag presumably selects the partition (see utils.data to confirm).
train_ds = SplitDataset(FFTDataset(dataset["train"]), is_train=True)
train_dl = DataLoader(train_ds, batch_size=_batch_size)

val_ds = SplitDataset(FFTDataset(dataset["train"]), is_train=False)
val_dl = DataLoader(val_ds, batch_size=_batch_size)

# The held-out "test" split is wrapped differently, in AudioINRDataset.
test_ds = AudioINRDataset(FFTDataset(dataset["test"]))
test_dl = DataLoader(test_ds, batch_size=_batch_size)
# for i, batch in enumerate(train_ds): | |
# fft_phase, fft_mag, audio = batch['audio']['fft_phase'], batch['audio']['fft_mag'], batch['audio']['array'] | |
# label = batch['label'] | |
# fig, axes = plt.subplots(nrows=1, ncols=3) | |
# axes = axes.flatten() | |
# axes[0].plot(fft_phase) | |
# axes[1].plot(fft_mag) | |
# axes[2].plot(audio) | |
# fig.suptitle(label) | |
# plt.tight_layout() | |
# plt.show() | |
# if i > 20: | |
# break | |
# model = DualEncoder(model_args, model_args_f, conformer_args) | |
# model = FasterKAN([18000,64,64,16,1]) | |
# model = INR(in_features=1) | |
# model.kan.speed() | |
# model = KanEncoder(kan_args.get_dict()) | |
# model = model.to(local_rank) | |
# state_dict = torch.load(data_args.checkpoint_path, map_location=torch.device('cpu')) | |
# new_state_dict = OrderedDict() | |
# for key, value in state_dict.items(): | |
# if key.startswith('module.'): | |
# key = key[7:] | |
# new_state_dict[key] = value | |
# missing, unexpected = model.load_state_dict(new_state_dict) | |
# model = DDP(model, device_ids=[local_rank], output_device=local_rank) | |
# num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) | |
# print(f"Number of parameters: {num_params}") | |
# | |
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) | |
# total_steps = int(data_args.num_epochs) * 1000 | |
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, | |
# T_max=total_steps, | |
# eta_min=float((5e-4) / 10)) | |
# Classification objective: binary cross-entropy computed directly on raw logits
# (the sigmoid is applied inside the loss), averaged over the batch.
loss_fn = torch.nn.BCEWithLogitsLoss(reduction="mean")
# Reconstruction objective (mean squared error). Visible code references it only
# in the commented-out INR fitting loop.
inr_criterion = torch.nn.MSELoss(reduction="mean")
# for i, batch in enumerate(train_ds): | |
# coords, fft, audio = batch['audio']['coords'], batch['audio']['fft_mag'], batch['audio']['array'] | |
# coords = coords.to(local_rank) | |
# fft = fft.to(local_rank) | |
# audio = audio.to(local_rank) | |
# values = torch.cat((audio.unsqueeze(-1), fft.unsqueeze(-1)), dim=-1) | |
# # model = INR(hidden_features=128, n_layers=3, | |
# # in_features=1, | |
# # out_features=1).to(local_rank) | |
# model = FasterKAN(**kan_args.get_dict()).to(local_rank) | |
# optimizer = torch.optim.Adam([{'params': model.parameters()}], lr=1e-3) | |
# pbar = tqdm(range(200)) | |
# losses = [] | |
# print(coords.shape) | |
# for t in pbar: | |
# optimizer.zero_grad() | |
# pred_values = model(coords.to(local_rank)).float() | |
# loss = inr_criterion(pred_values, values) | |
# loss.backward() | |
# optimizer.step() | |
# pbar.set_description(f'loss: {loss.item()}') | |
# losses.append(loss.item()) | |
# state_dict = model.state_dict() | |
# torch.save(state_dict, 'test') | |
# # print(f'Sample {i+offset} label {label} saved in {inr_path}') | |
# plot_results(1, i, fft, losses, pred_values) | |
# # | |
# exit() | |
# missing, unexpected = model.load_state_dict(torch.load(model_args.checkpoint_path)) | |
# print(f"Missing keys: {missing}") | |
# print(f"Unexpected keys: {unexpected}") | |
# Layer widths of the implicit network: input, n_layers hidden layers, output.
layer_layout = (
    [inr_args.in_features]
    + [inr_args.hidden_features] * inr_args.n_layers
    + [inr_args.out_features]
)

# Graph-constructor spec handed to RelationalTransformer. The `_target_` /
# `_recursive_` / `_convert_` keys look like a Hydra instantiate spec (confirm in
# utils.relational_transformer). The sine-embedding width follows rt_args.d_node.
_graph_constructor_spec = {
    "_target_": "utils.graph_constructor.GraphConstructor",
    "_recursive_": False,
    "_convert_": "all",
    "d_in": 1,
    "d_edge_in": 1,
    "zero_out_bias": False,
    "zero_out_weights": False,
    "sin_emb": True,
    "sin_emb_dim": rt_args.d_node,
    "use_pos_embed": False,
    "input_layers": 1,
    "inp_factor": 1,
    "num_probe_features": 0,
    "inr_model": None,
    "stats": None,
    "sparsify": False,
    "sym_edges": False,
}
graph_constructor = OmegaConf.create(_graph_constructor_spec)

rt_model = RelationalTransformer(
    layer_layout=layer_layout,
    graph_constructor=graph_constructor,
    **rt_args.get_dict(),
).to(local_rank)
# Replace the output projection with a no-op (presumably so MultiGraph receives
# the transformer's features unprojected).
rt_model.proj_out = torch.nn.Identity()

multi_graph = MultiGraph(rt_model, cnn_args)
implicit_net = INR(**inr_args.get_dict())
model = ImplicitEncoder(implicit_net, multi_graph).to(local_rank)

num_parameters = sum(param.numel() for param in model.parameters() if param.requires_grad)
print(f"Number of parameters: {num_parameters}")

# A single parameter group holding every model parameter, with a fixed learning rate.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Supervised training with `loss_fn` as the criterion. No AMP scaler, no LR
# scheduler, and gradients are not accumulated across steps.
# NOTE(review): exp_name says "kan" although the model built above is an
# ImplicitEncoder. Confirm the intended run name.
trainer = Trainer(
    model=model,
    optimizer=optimizer,
    criterion=loss_fn,
    output_dim=1,
    scaler=None,
    scheduler=None,
    train_dataloader=train_dl,
    val_dataloader=val_dl,
    device=local_rank,
    exp_num=datetime_dir,
    log_path=data_args.log_dir,
    range_update=None,
    accumulation_step=1,
    max_iter=100,
    exp_name=f"frugal_kan_{exp_num}",
)
fit_res = trainer.fit(
    num_epochs=100,
    device=local_rank,
    early_stopping=10,
    only_p=False,
    best='loss',
    conf=True,
)

# Write the value returned by fit() as JSON into the run's dated log directory.
# This assumes fit_res contains only JSON-serialisable values; TODO confirm.
output_filename = f'{data_args.log_dir}/{datetime_dir}/{model_name}_frugal_{exp_num}.json'
with open(output_filename, "w") as f:
    json.dump(fit_res, f, indent=2)

# Evaluate on the held-out test loader and report accuracy.
preds, acc = trainer.predict(test_dl, local_rank)
print(f"Accuracy: {acc}")