Spaces:
Running
Running
| import gradio as gr | |
| import os | |
| import matplotlib.pyplot as plt | |
| from scipy.integrate import odeint | |
| import torch | |
| from torch.utils import data | |
| from torch.utils.data import DataLoader, Dataset | |
| from torch import nn, optim | |
| from skimage.transform import rescale, resize | |
| from torch import nn, optim | |
| import torch.nn.functional as F | |
| from torch.utils.data import Subset | |
| from scipy.interpolate import interp1d | |
| import collections | |
| import numpy as np | |
| import random | |
| #for pvloop simulator: | |
| import pandas as pd | |
| from scipy.integrate import odeint | |
| import torchvision | |
| import echonet | |
| import matplotlib.animation as animation | |
| from functools import partial | |
# Run the network on the first CUDA GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# NOTE(review): these two lists are not referenced anywhere else in the visible file.
sequences_all = []
info_data_all = []
# Default dataset root; Echo falls back to this value when it is created with root=None.
path = 'EchoNet-Dynamic'
# NOTE(review): not referenced anywhere else in the visible file - confirm before removing.
output_path = ''
class Echo(torchvision.datasets.VisionDataset):
    """EchoNet-Dynamic Dataset.

    Args:
        root (string): Root directory of dataset (defaults to the module-level ``path``)
        split (string): One of {``train'', ``val'', ``test'', ``all'', or ``external_test''}
        target_type (string or list, optional): Type of target to use,
            ``Filename'', ``EF'', ``EDV'', ``ESV'', ``LargeIndex'',
            ``SmallIndex'', ``LargeFrame'', ``SmallFrame'', ``LargeTrace'',
            or ``SmallTrace''
            Can also be a list to output a tuple with all specified target types.
            The targets represent:
                ``Filename'' (string): filename of video
                ``EF'' (float): ejection fraction
                ``EDV'' (float): end-diastolic volume
                ``ESV'' (float): end-systolic volume
                ``LargeIndex'' (int): index of large (diastolic) frame in video
                ``SmallIndex'' (int): index of small (systolic) frame in video
                ``LargeFrame'' (np.array shape=(3, height, width)): normalized large (diastolic) frame
                ``SmallFrame'' (np.array shape=(3, height, width)): normalized small (systolic) frame
                ``LargeTrace'' (np.array shape=(height, width)): left ventricle large (diastolic) segmentation
                    value of 0 indicates pixel is outside left ventricle
                             1 indicates pixel is inside left ventricle
                ``SmallTrace'' (np.array shape=(height, width)): left ventricle small (systolic) segmentation
                    value of 0 indicates pixel is outside left ventricle
                             1 indicates pixel is inside left ventricle
            Defaults to ``EF''.
        mean (int, float, or np.array shape=(3,), optional): means for all (if scalar) or each (if np.array) channel.
            Used for normalizing the video. Defaults to 0 (video is not shifted).
        std (int, float, or np.array shape=(3,), optional): standard deviation for all (if scalar) or each (if np.array) channel.
            Used for normalizing the video. Defaults to 1 (video is not scaled).
        length (int or None, optional): Number of frames to clip from video. If ``None'', longest possible clip is returned.
            Defaults to 16.
        period (int, optional): Sampling period for taking a clip from the video (i.e. every ``period''-th frame is taken)
            Defaults to 2.
        max_length (int or None, optional): Maximum number of frames to clip from video (main use is for shortening excessively
            long videos when ``length'' is set to None). If ``None'', shortening is not applied to any video.
            Defaults to 250.
        clips (int, optional): Number of clips to sample. Main use is for test-time augmentation with random clips.
            Defaults to 1.
        pad (int or None, optional): Number of pixels to pad all frames on each side (used as augmentation).
            and a window of the original size is taken. If ``None'', no padding occurs.
            Defaults to ``None''.
        noise (float or None, optional): Fraction of pixels to black out as simulated noise. If ``None'', no simulated noise is added.
            Defaults to ``None''.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
        external_test_location (string): Path to videos to use for external testing.
    """

    def __init__(self, root=None,
                 split="train", target_type="EF",
                 mean=0., std=1.,
                 length=16, period=2,
                 max_length=250,
                 clips=1,
                 pad=None,
                 noise=None,
                 target_transform=None,
                 external_test_location=None):
        if root is None:
            root = path

        super().__init__(root, target_transform=target_transform)

        self.split = split.upper()
        if not isinstance(target_type, list):
            target_type = [target_type]
        self.target_type = target_type
        self.mean = mean
        self.std = std
        self.length = length
        self.max_length = max_length
        self.period = period
        self.clips = clips
        self.pad = pad
        self.noise = noise
        self.target_transform = target_transform
        self.external_test_location = external_test_location

        self.fnames, self.outcome = [], []

        if self.split == "EXTERNAL_TEST":
            self.fnames = sorted(os.listdir(self.external_test_location))
        else:
            # Load video-level labels
            with open(f"{self.root}/FileList.csv") as f:
                data = pd.read_csv(f)
            # Normalise the split labels so the comparison below is
            # case-insensitive (the result used to be computed and dropped).
            data["Split"] = data["Split"].str.upper()

            if self.split != "ALL":
                data = data[data["Split"] == self.split]

            self.header = data.columns.tolist()
            # Assume avi if no suffix. Names that already carry a suffix are
            # kept unchanged so that fnames stays aligned row-by-row with
            # self.outcome (they used to be silently dropped).
            self.fnames = [
                fn if os.path.splitext(fn)[1] else fn + ".avi"
                for fn in data["FileName"].tolist()
            ]
            self.outcome = data.values.tolist()

            # NOTE: the presence of the listed videos under <root>/Videos is
            # not checked here; a missing file only surfaces when
            # __getitem__ tries to load it.

            # Load traces
            self.frames = collections.defaultdict(list)
            self.trace = collections.defaultdict(_defaultdict_of_lists)

            with open(f"{self.root}/VolumeTracings.csv") as f:
                header = f.readline().strip().split(",")
                assert header == ["FileName", "X1", "Y1", "X2", "Y2", "Frame"]

                for line in f:
                    filename, x1, y1, x2, y2, frame = line.strip().split(',')
                    x1 = float(x1)
                    y1 = float(y1)
                    x2 = float(x2)
                    y2 = float(y2)
                    frame = int(frame)
                    # Remember each traced frame once, in file order
                    if frame not in self.trace[filename]:
                        self.frames[filename].append(frame)
                    self.trace[filename][frame].append((x1, y1, x2, y2))
            for filename in self.frames:
                for frame in self.frames[filename]:
                    self.trace[filename][frame] = np.array(self.trace[filename][frame])

            # A small number of videos are missing traces; remove these videos
            keep = [len(self.frames[f]) >= 2 for f in self.fnames]
            self.fnames = [f for (f, k) in zip(self.fnames, keep) if k]
            self.outcome = [f for (f, k) in zip(self.outcome, keep) if k]

    def __getitem__(self, index):
        """Return ``(video, target)`` for the video at position ``index``."""
        # Find filename of video
        if self.split == "EXTERNAL_TEST":
            video = os.path.join(self.external_test_location, self.fnames[index])
        elif self.split == "CLINICAL_TEST":
            video = os.path.join(self.root, "ProcessedStrainStudyA4c", self.fnames[index])
        else:
            video = os.path.join(self.root, "Videos", self.fnames[index])

        # Load video into np.array
        video = echonet.utils.loadvideo(video).astype(np.float32)

        # Add simulated noise (black out random pixels)
        # 0 represents black at this point (video has not been normalized yet)
        if self.noise is not None:
            n = video.shape[1] * video.shape[2] * video.shape[3]
            ind = np.random.choice(n, round(self.noise * n), replace=False)
            f = ind % video.shape[1]
            ind //= video.shape[1]
            i = ind % video.shape[2]
            ind //= video.shape[2]
            j = ind
            video[:, f, i, j] = 0

        # Apply normalization
        if isinstance(self.mean, (float, int)):
            video -= self.mean
        else:
            video -= self.mean.reshape(3, 1, 1, 1)

        if isinstance(self.std, (float, int)):
            video /= self.std
        else:
            video /= self.std.reshape(3, 1, 1, 1)

        # Set number of frames
        c, f, h, w = video.shape
        if self.length is None:
            # Take as many frames as possible
            length = f // self.period
        else:
            # Take specified number of frames
            length = self.length

        if self.max_length is not None:
            # Shorten videos to max_length
            length = min(length, self.max_length)

        if f < length * self.period:
            # Pad video with frames filled with zeros if too short
            # 0 represents the mean color (dark grey), since this is after normalization
            video = np.concatenate((video, np.zeros((c, length * self.period - f, h, w), video.dtype)), axis=1)
            c, f, h, w = video.shape  # pylint: disable=E0633

        if self.clips == "all":
            # Take all possible clips of desired length
            start = np.arange(f - (length - 1) * self.period)
        else:
            # Take random clips from video
            start = np.random.choice(f - (length - 1) * self.period, self.clips)

        # Gather targets
        target = []
        for t in self.target_type:
            key = self.fnames[index]
            if t == "Filename":
                target.append(self.fnames[index])
            elif t == "LargeIndex":
                # Traces are sorted by cross-sectional area
                # Largest (diastolic) frame is last
                target.append(int(self.frames[key][-1]))
            elif t == "SmallIndex":
                # Smallest (systolic) frame is first
                target.append(int(self.frames[key][0]))
            elif t == "LargeFrame":
                target.append(video[:, self.frames[key][-1], :, :])
            elif t == "SmallFrame":
                target.append(video[:, self.frames[key][0], :, :])
            elif t in ["LargeTrace", "SmallTrace"]:
                # Imported here because the file only imports skimage.transform,
                # which does not bind the name ``skimage``.
                import skimage.draw

                if t == "LargeTrace":
                    trace = self.trace[key][self.frames[key][-1]]
                else:
                    trace = self.trace[key][self.frames[key][0]]

                x1, y1, x2, y2 = trace[:, 0], trace[:, 1], trace[:, 2], trace[:, 3]
                x = np.concatenate((x1[1:], np.flip(x2[1:])))
                y = np.concatenate((y1[1:], np.flip(y2[1:])))

                # ``np.int`` was removed in NumPy 1.24; the builtin int is the same dtype.
                r, c = skimage.draw.polygon(np.rint(y).astype(int), np.rint(x).astype(int), (video.shape[2], video.shape[3]))
                mask = np.zeros((video.shape[2], video.shape[3]), np.float32)
                mask[r, c] = 1
                target.append(mask)
            else:
                if self.split == "CLINICAL_TEST" or self.split == "EXTERNAL_TEST":
                    target.append(np.float32(0))
                else:
                    # Any other target name is looked up as a FileList.csv column
                    target.append(np.float32(self.outcome[index][self.header.index(t)]))

        if target != []:
            target = tuple(target) if len(target) > 1 else target[0]
            if self.target_transform is not None:
                target = self.target_transform(target)

        # Select clips from video
        video = tuple(video[:, s + self.period * np.arange(length), :, :] for s in start)
        if self.clips == 1:
            video = video[0]
        else:
            video = np.stack(video)

        if self.pad is not None:
            # Add padding of zeros (mean color of videos)
            # Crop of original size is taken out
            # (Used as augmentation)
            c, l, h, w = video.shape
            temp = np.zeros((c, l, h + 2 * self.pad, w + 2 * self.pad), dtype=video.dtype)
            temp[:, :, self.pad:-self.pad, self.pad:-self.pad] = video  # pylint: disable=E1130
            i, j = np.random.randint(0, 2 * self.pad, 2)
            video = temp[:, :, i:(i + h), j:(j + w)]

        return video, target

    def __len__(self):
        return len(self.fnames)

    def extra_repr(self) -> str:
        """Additional information to add at end of __repr__."""
        lines = ["Target type: {target_type}", "Split: {split}"]
        return '\n'.join(lines).format(**self.__dict__)
| def _defaultdict_of_lists(): | |
| """Returns a defaultdict of lists. | |
| This is used to avoid issues with Windows (if this function is anonymous, | |
| the Echo dataset cannot be used in a dataloader). | |
| """ | |
| return collections.defaultdict(list) | |
##
# Status message written to stdout when the module is imported.
# NOTE(review): no dataset has been read at this point; in the visible file the
# Echo dataset is only instantiated inside generate_example().
print("Done loading training data!")
# define normalization layer to make sure output xi in an interval [ai, bi]:
class IntervalNormalizationLayer(torch.nn.Module):
    """Squash each of the 7 raw outputs into its own interval ``[a_i, b_i]``.

    The bounds are stored as non-persistent buffers: they follow the module
    through ``.to(device)`` / ``.double()`` etc. (plain tensor attributes do
    not, which breaks ``forward`` once the model is moved to a GPU), while
    staying out of ``state_dict`` so existing checkpoints load unchanged.
    """

    def __init__(self):
        super().__init__()
        # new_output = [Tc, start_p, Emax, Emin, Rm, Ra, Vd]
        lower = torch.tensor([0.4, 0., 0.5, 0.02, 0.005, 0.0001, 4.], dtype=torch.float32)
        upper = torch.tensor([1.7, 280., 3.5, 0.1, 0.1, 0.25, 16.], dtype=torch.float32)
        #taken out (initial conditions): a: 20, 5, 50; b: 400, 20, 100
        self.register_buffer("a", lower, persistent=False)
        self.register_buffer("b", upper, persistent=False)

    def forward(self, inputs):
        """Map ``inputs`` (last dimension of size 7) into ``[a, b]`` via a sigmoid."""
        sigmoid_output = torch.sigmoid(inputs)
        scaled_output = sigmoid_output * (self.b - self.a) + self.a
        return scaled_output
class NEW3DCNN(nn.Module):
    """3D CNN regressing ``num_parameters`` bounded values from a video clip.

    Five Conv3d + BatchNorm3d + ReLU stages (3 -> 8 -> 16 -> 32 -> 64 -> 128
    channels), with 2x max pooling after each of the first four, followed by
    global average pooling, two linear layers and an interval normalisation
    of the outputs.
    """

    def __init__(self, num_parameters):
        super().__init__()
        channels = (3, 8, 16, 32, 64, 128)
        # Registers conv1/batchnorm1 ... conv5/batchnorm5 in the same order as
        # writing them out by hand, so parameter names and init order match.
        for stage, (c_in, c_out) in enumerate(zip(channels, channels[1:]), start=1):
            setattr(self, f"conv{stage}", nn.Conv3d(c_in, c_out, kernel_size=3, padding=1))
            setattr(self, f"batchnorm{stage}", nn.BatchNorm3d(c_out))
        self.pool = nn.AdaptiveAvgPool3d(1)
        self.fc1 = nn.Linear(128, 512)
        self.fc2 = nn.Linear(512, num_parameters)
        self.norm1 = IntervalNormalizationLayer()

    def forward(self, x):
        last_stage = 5
        for stage in range(1, last_stage + 1):
            conv = getattr(self, f"conv{stage}")
            norm = getattr(self, f"batchnorm{stage}")
            x = F.relu(norm(conv(x)))
            # Every stage but the last halves each spatio-temporal dimension
            if stage != last_stage:
                x = F.max_pool3d(x, kernel_size=2, stride=2)
        pooled = self.pool(x)
        features = pooled.view(pooled.size(0), -1)
        hidden = F.relu(self.fc1(features))
        return self.norm1(self.fc2(hidden))
# Small fully connected network with a single hidden layer
class Interpolator(nn.Module):
    """Double-precision MLP: 6 inputs -> 250 hidden units (ReLU) -> 2 outputs."""

    def __init__(self):
        super().__init__()
        hidden_units = 250
        self.fc1 = nn.Linear(6, hidden_units).double()
        self.fc2 = nn.Linear(hidden_units, 2).double()

    def forward(self, x):
        hidden = torch.relu(self.fc1(x))
        return self.fc2(hidden)
# Initialize the interpolator network and load its weights.
# It is never moved off the CPU, so the checkpoint is mapped to the CPU too.
net = Interpolator()
net.load_state_dict(torch.load('final_model_weights/interp6_7param_weight.pt', map_location="cpu"))
net.eval()
print("Done loading interpolator!")

# Load the video-to-parameters network.
# map_location lets a checkpoint saved on a GPU be restored on a CPU-only host.
weights_path = 'final_model_weights/202_full_echonet_7param_Vloss_epoch_200_lr_0.001_weight_best_model.pt'
model = NEW3DCNN(num_parameters = 7)
model.load_state_dict(torch.load(weights_path, map_location=device))
model.to(device)
# Inference only: use the stored BatchNorm statistics instead of per-batch
# statistics, so a prediction does not depend on the other clips in the batch.
model.eval()
## PV loops
def Plv(volume, Emax, Emin, t, Tc, Vd):
    """Left-ventricular pressure at time ``t``.

    Computed as the elastance at ``t`` multiplied by ``volume - Vd``.
    """
    elastance_now = Elastance(Emax, Emin, t, Tc)
    return elastance_now * (volume - Vd)
def Elastance(Emax, Emin, t, Tc):
    """Time-varying elastance at time ``t`` for a cycle of duration ``Tc``.

    ``t`` is first wrapped into the current cycle, so the result repeats with
    period ``Tc``. The value equals ``Emin`` at the start of every cycle.
    """
    # Time elapsed inside the current cycle (whole cycles are truncated away;
    # drop this line to evaluate only the very first cycle)
    phase = t - int(t / Tc) * Tc
    tn = phase / (0.2 + 0.15 * Tc)
    rise = (tn / 0.7) ** 1.9
    decay = (tn / 1.17) ** 21.9
    shape = (Emax - Emin) * 1.55 * rise / (rise + 1.0)
    shape = shape / (decay + 1.0)
    return shape + Emin
def solve_ODE_for_volume(Rm, Ra, Emax, Emin, Vd, Tc, start_v, t):
    """Integrate the 5-state heart model (Simaan et al 2008) and return LV volume.

    Args:
        Rm, Ra: mitral / aortic valve circuit resistances.
        Emax, Emin: maximum / minimum elastance.
        Vd: theoretical zero-pressure volume.
        Tc: cycle duration.
        start_v: initial value of the first state variable.
        t: time points handed to ``odeint``.

    Returns:
        Array with one volume per entry of ``t`` (first state variable plus ``Vd``).
    """
    # Fixed circuit parameters
    Rs, Rc = 1.0, 0.0398
    Ca, Cs, Cr = 0.08, 1.33, 4.400
    Ls = 0.0005
    startp = 75.

    def r(u):
        # RELU for diodes
        return max(u, 0.)

    def heart_ode(y, t, Rs, Rm, Ra, Rc, Ca, Cs, Cr, Ls, Emax, Emin, Tc):
        # y holds the 5 state values at time t; the return value is dy/dt at t
        x1, x2, x3, x4, x5 = y
        P_lv = Plv(x1 + Vd, Emax, Emin, t, Tc, Vd)
        inflow = r(x2 - P_lv)
        outflow = r(P_lv - x4)
        return [
            inflow / Rm - outflow / Ra,
            (x3 - x2) / (Rs * Cr) - inflow / (Cr * Rm),
            (x2 - x3) / (Rs * Cs) + x5 / Cs,
            -x5 / Ca + outflow / (Ca * Ra),
            (x4 - x3 - Rc * x5) / Ls,
        ]

    # Initial conditions
    start_pla = float(start_v * Elastance(Emax, Emin, 0., Tc))
    start_pao = startp
    start_pa = start_pao
    start_qt = 0  # aortic flow is Q_T and is 0 at ED, also see Fig5 in simaan2008dynamical
    initial_state = [start_v, start_pla, start_pa, start_pao, start_qt]

    # Solve
    ode_args = (Rs, Rm, Ra, Rc, Ca, Cs, Cr, Ls, Emax, Emin, Tc)
    sol = odeint(heart_ode, initial_state, t, args=ode_args)

    # volume is the first state variable plus theoretical zero pressure volume
    return np.array(sol[:, 0]) + Vd
def pvloop_simulator(Rm, Ra, Emax, Emin, Vd, Tc, start_v, animate=True):
    """Simulate the heart model and draw the resulting pressure-volume loop.

    The model is integrated for ``N`` cycles and only the last two cycles are
    plotted. With ``animate=True`` the loop is additionally rendered
    progressively and saved to ``prediction.gif`` in the working directory.

    Args:
        Rm, Ra, Emax, Emin, Vd, Tc, start_v: model parameters, forwarded to
            ``solve_ODE_for_volume``.
        animate: when true, save the animated loop; otherwise draw the full
            loop as a static line.

    Returns:
        ``(plt, Rm, Ra, Emax, Emin, Vd, Tc, start_v)`` - the pyplot module
        (whose current figure holds the plot) followed by the inputs.
    """
    # SOLVE the ODE model for the VOLUME CURVE
    N = 100
    samples_per_cycle = 60000
    t = np.linspace(0, Tc * N, int(samples_per_cycle * N))
    volumes = solve_ODE_for_volume(Rm, Ra, Emax, Emin, Vd, Tc, start_v, t)

    # Only the last two cycles are ever drawn, so pressure is evaluated on
    # that window alone instead of on all N cycles (Plv is scalar-only and
    # has to be applied point by point).
    start = (N - 2) * samples_per_cycle
    end = N * samples_per_cycle
    loop_volumes = volumes[start:end]
    loop_pressures = np.vectorize(Plv)(loop_volumes, Emax, Emin, t[start:end], Tc, Vd)

    # Create the figure and the loop that we will manipulate
    # NOTE(review): a new figure is created on every call and never closed.
    fig, ax = plt.subplots()
    ax.set_ylim((0, 220))
    ax.set_xlim((0, 250))

    if animate:
        ax.plot(loop_volumes[:1], loop_pressures[:1], lw=1, color='b')
        ax.scatter(loop_volumes[:1], loop_pressures[:1], c="b", s=5)
    else:
        ax.plot(loop_volumes, loop_pressures, lw=1, color='b')

    ax.set_title('Predicted PI-SSL LV Pressure Volume Loop', fontsize=16)
    ax.set_xlabel('LV Volume (ml)')
    ax.set_ylabel('LV Pressure (mmHg)')

    def update(frame):
        # update to add more of the loop: 1000 further samples per frame
        shown = 1000 * frame
        ax.plot(loop_volumes[:shown], loop_pressures[:shown], lw=1, c='b')

    if animate:
        anim = animation.FuncAnimation(fig, update, frames=100, interval=30)
        anim.save("prediction.gif")

    return plt, Rm, Ra, Emax, Emin, Vd, Tc, start_v
def pvloop_simulator_plot_only(Rm, Ra, Emax, Emin, Vd, Tc, start_v):
    """Draw the static (non-animated) PV loop and return only the plot object."""
    plot, *_ = pvloop_simulator(Rm, Ra, Emax, Emin, Vd, Tc, start_v, animate=False)
    # Replace the "Predicted" title set by pvloop_simulator
    plt.title('Simulated PI-SSL LV Pressure Volume Loop', fontsize=16)
    return plot
#########################################
# LVAD functions
def r(u):
    """RELU used to model the diodes: ``u`` when it is not below zero, else ``0.``."""
    return 0. if 0. > u else u
def heart_ode0(y, t, Rs, Rm, Ra, Rc, Ca, Cs, Cr, Ls, Emax, Emin, Tc, Vd):
    """Right-hand side of the 5-state heart model without an LVAD.

    ``y`` holds the five state values at time ``t``; the returned list is
    dy/dt at that time, in the same order.
    """
    x1, x2, x3, x4, x5 = y
    P_lv = Plv(x1 + Vd, Emax, Emin, t, Tc, Vd)
    # Rectified pressure differences (diodes only conduct forwards)
    inflow = r(x2 - P_lv)
    outflow = r(P_lv - x4)
    return [
        inflow / Rm - outflow / Ra,
        (x3 - x2) / (Rs * Cr) - inflow / (Cr * Rm),
        (x2 - x3) / (Rs * Cs) + x5 / Cs,
        -x5 / Ca + outflow / (Ca * Ra),
        (x4 - x3 - Rc * x5) / Ls,
    ]
def getslope(y1, y2, y3, x1, x2, x3):
    """Slope of the least-squares line through (x1, y1), (x2, y2), (x3, y3).

    Raises ZeroDivisionError when the three x values coincide.
    """
    n_points = 3
    x_total = x1 + x2 + x3
    y_total = y1 + y2 + y3
    cross_total = x1*y1 + x2*y2 + x3*y3
    square_total = x1*x1 + x2*x2 + x3*x3
    # closed-form least-squares slope
    numerator = n_points*cross_total - x_total*y_total
    denominator = n_points*square_total - x_total*x_total
    return numerator / denominator
### ODE right-hand side: for a fixed t, gives dy/dt as a function of y(t), so it
### can be used to integrate the state vector y over time (called once per time step)
def lvad_ode(y, t, Rs, Rm, Ra, Rc, Ca, Cs, Cr, Ls, Emax, Emin, Tc, Vd, ratew):
    """Right-hand side of the 7-state heart model with an LVAD attached.

    ``y`` holds the seven state values at time ``t``: the five states of the
    model without LVAD followed by x6 and x7 (presumably pump flow and pump
    speed - confirm against simaan2008dynamical). The derivative of x7 is the
    constant ``ratew``. Returns dy/dt as a list of seven values.
    """
    # pump constants, from simaan2008dynamical
    Ri = R0 = 0.0677
    Li = L0 = 0.0127
    x1bar = 1.
    alpha = -3.5
    b0 = -0.296
    b1 = -0.027
    b2 = 9.9025e-7

    x1, x2, x3, x4, x5, x6, x7 = y
    P_lv = Plv(x1 + Vd, Emax, Emin, t, Tc, Vd)

    # Rk is zero unless the LV pressure is at or below the threshold x1bar
    if P_lv <= x1bar:
        Rk = alpha * (P_lv - x1bar)
    else:
        Rk = 0.0
    Lstar2 = -Li - L0 + b1

    # Rectified pressure differences (diodes only conduct forwards)
    inflow = r(x2 - P_lv)
    outflow = r(P_lv - x4)
    return [
        -x6 + inflow / Rm - outflow / Ra,
        (x3 - x2) / (Rs * Cr) - inflow / (Cr * Rm),
        (x2 - x3) / (Rs * Cs) + x5 / Cs,
        -x5 / Ca + outflow / (Ca * Ra) + x6 / Ca,
        (x4 - x3) / Ls - Rc * x5 / Ls,
        -P_lv / Lstar2 + x4 / Lstar2 + (Ri + R0 + Rk - b0) / Lstar2 * x6 - b2 / Lstar2 * x7**2,
        ratew,
    ]
def f_nolvad(Rs, Rm, Ra, Rc, Ca, Cs, Cr, Ls, Emin, Vd, Tc, start_v, Emax, showpvloop):
    """Simulate the heart model without LVAD and summarise the last cycles.

    The model is integrated for ``N`` cycles. The PV loop is taken from the
    last two cycles, the scalar summaries from the last one.

    Args:
        Rs, Rm, Ra, Rc, Ca, Cs, Cr, Ls, Emin, Vd, Tc, Emax: model parameters.
        start_v: initial value of the first state variable.
        showpvloop: unused; kept for call compatibility.

    Returns:
        Tuple ``(ef, pao_ed, pao_es, co, y0, y1, y2, y3, y4, Vlv, Plv)``:
        ejection fraction in percent, minimum and maximum of the fourth state
        over the last cycle, cardiac output, the five state values at the
        start of the last cycle, and the volume / pressure arrays of the last
        two cycles.
    """
    N = 20
    samples_per_cycle = 60000  # 60000 samples per cycle keep the index of Tmax integral
    last_cycle = (N - 1) * samples_per_cycle
    stop = N * samples_per_cycle
    loop = slice((N - 2) * samples_per_cycle, stop)

    start_pla = float(start_v * Elastance(Emax, Emin, 0.0, Tc))
    start_pao = 75.
    start_pa = start_pao
    start_qt = 0.0  # aortic flow is Q_T and is 0 at ED, also see Fig5 in simaan2008dynamical
    y0 = [start_v, start_pla, start_pa, start_pao, start_qt]

    t = np.linspace(0, Tc * N, int(samples_per_cycle * N))
    # obtain 5D vector solution:
    sol = odeint(heart_ode0, y0, t, args=(Rs, Rm, Ra, Rc, Ca, Cs, Cr, Ls, Emax, Emin, Tc, Vd))

    # Only the last two cycles are returned, so pressure is evaluated on that
    # window alone rather than on every sample (Plv is scalar-only).
    result_Vlv = sol[loop, 0] + Vd
    result_Plv = np.array([Plv(v, Emax, Emin, ti, Tc, Vd) for ti, v in zip(t[loop], result_Vlv)])

    ved = sol[last_cycle, 0] + Vd
    # sample offset of end systole inside the cycle
    ves = sol[200 * int(60 / Tc) + 9000 + last_cycle, 0] + Vd
    ef = (ved - ves) / ved * 100.
    cardiac_output = ((ved - ves) * 60 / Tc) / 1000

    last_pao = sol[last_cycle:stop - 1, 3]
    pao_ed = last_pao.min()
    pao_es = last_pao.max()

    state_at_ed = sol[last_cycle]
    return (ef, pao_ed, pao_es, cardiac_output,
            state_at_ed[0], state_at_ed[1], state_at_ed[2], state_at_ed[3], state_at_ed[4],
            result_Vlv, result_Plv)
def get_suctionw(Rs, Rm, Ra, Rc, Ca, Cs, Cr, Ls, Emin, Vd, Tc, start_v, Emax, y00, y01, y02, y03, y04, w0, x60, ratew):
    """Return the w at which suction occurs.

    The LVAD model is integrated with explicit Euler steps while the seventh
    state (w) grows at the constant rate ``ratew``. Once per cycle the minimum
    of x6 over that cycle is recorded; suction is detected when the
    least-squares slope through the last three minima becomes negative.

    Args:
        Rs, Rm, Ra, Rc, Ca, Cs, Cr, Ls, Emin, Vd, Tc, Emax: model parameters.
        start_v: unused; kept for call compatibility.
        y00..y04: initial values of the first five states.
        w0: initial value of the seventh state.
        x60: initial value of the sixth state.
        ratew: constant time derivative of the seventh state.

    Returns:
        The recorded w one envelope point before the first negative slope
        (the first two slopes are skipped), or 0 when no slope is negative.
    """
    N = 70
    ncycle = 20000
    n = N * ncycle
    t = np.linspace(0., Tc * N, n)

    # Only the current state is needed, so the full trajectory is not stored.
    state = np.array([y00, y01, y02, y03, y04, x60, w0], dtype=np.float64)

    envx6 = []
    timesenvx6 = []
    slopes = []
    ws = []
    minx6 = 99999
    tmin = 0

    # solve the ODE step by step by adding dydt*dt:
    for j in range(n - 1):
        dt = t[j + 1] - t[j]
        dydt = lvad_ode(state, t[j], Rs, Rm, Ra, Rc, Ca, Cs, Cr, Ls, Emax, Emin, Tc, Vd, ratew)

        # track the min of x6 in the current cycle and the time at which it occurs
        if minx6 > state[5]:
            minx6 = state[5]
            tmin = t[j]

        # once each cycle ends: store the minimum and reset it for the next cycle
        if j % ncycle == 0 and j > 1:
            envx6.append(minx6)
            timesenvx6.append(tmin)
            minx6 = 99999
            if len(envx6) >= 3:
                slopes.append(getslope(envx6[-1], envx6[-2], envx6[-3],
                                       timesenvx6[-1], timesenvx6[-2], timesenvx6[-3]))
                ws.append(state[6])

        state = state + [rate * dt for rate in dydt]

    suction_w = 0
    for i in range(2, len(slopes)):
        if slopes[i] < 0:
            suction_w = ws[i - 1]
            break
    return suction_w
def f_lvad(Rs, Rm, Ra, Rc, Ca, Cs, Cr, Ls, Emin, Vd, Tc, start_v, Emax, c, slope, w0, x60, y00, y01, y02, y03, y04): #slope is slope0 for w
    """Simulate the heart model with an LVAD and summarise the result.

    The 7-state model is integrated with explicit Euler steps for ``N``
    cycles. While the envelope slope stays positive, w (seventh state) is
    re-set every 0.005 s from the slope of the per-cycle minima of x6; as soon
    as the slope turns negative, w is no longer updated.

    Args:
        Rs, Rm, Ra, Rc, Ca, Cs, Cr, Ls, Emin, Vd, Tc, Emax: model parameters.
        start_v: unused; kept for call compatibility.
        c: gain applied to the envelope slope when updating w.
        slope: initial slope value used until three envelope points exist.
        w0: initial value of the seventh state.
        x60: initial value of the sixth state.
        y00..y04: initial values of the first five states.

    Returns:
        Tuple ``(ef, pao_ed, pao_es, CO, MAP, Vlv, Plv)``: ejection fraction
        in percent, min / max of the fourth state and cardiac output over
        cycles 50-52, mean of the third state over the last five cycles, and
        the volume / pressure lists of the last five cycles.
    """
    N = 70
    ncycle = 10000
    n = N * ncycle
    sol = np.zeros((n, 7))
    t = np.linspace(0., Tc * N, n)
    sol[0] = [y00, y01, y02, y03, y04, x60, w0]

    envx6 = []
    timesenvx6 = []
    minx6 = 99999
    tmin = 0
    tlastupdate = 0
    lastw = w0
    update = 1
    ratew = 0 #6000/60

    # solve the ODE step by step by adding dydt*dt:
    for j in range(n - 1):
        y = sol[j]
        dt = t[j + 1] - t[j]
        dydt = lvad_ode(y, t[j], Rs, Rm, Ra, Rc, Ca, Cs, Cr, Ls, Emax, Emin, Tc, Vd, ratew)
        sol[j + 1] = y + [rate * dt for rate in dydt]

        # track the min of x6 in the current cycle and the time at which it occurs
        if minx6 > y[5]:
            minx6 = y[5]
            tmin = t[j]

        # once each cycle ends: store the minimum and reset it for the next cycle
        if j % ncycle == 0 and j > 1:
            envx6.append(minx6)
            timesenvx6.append(tmin)
            minx6 = 99999

        # update w (if 0.005 s. have passed since the last update):
        if slope < 0:
            update = 0
        if t[j + 1] - tlastupdate > 0.005 and slope > 0 and update == 1:
            if len(envx6) >= 3:
                # enough points of envelope: calculate slope
                slope = getslope(envx6[-1], envx6[-2], envx6[-3],
                                 timesenvx6[-1], timesenvx6[-2], timesenvx6[-3])
                sol[j + 1][6] = lastw + c * slope
            else:
                # otherwise: take arbitrary rate (see Fig. 16a in simaan2008dynamical)
                sol[j + 1][6] = lastw + 0.005 * slope
            # save w(k) (see formula (8) simaan2008dynamical) and the time of this
            # update (the next update of w has to wait 0.005 s)
            tlastupdate = t[j + 1]
            lastw = sol[j + 1][6]

    volumes = sol[:, 0] + Vd
    steady = slice(50 * ncycle, 52 * ncycle)
    loop = slice((N - 5) * ncycle, N * ncycle)

    # get co and ef:
    ved = volumes[steady].max()
    ves = volumes[steady].min()
    ef = (ved - ves) / ved * 100
    CO = ((ved - ves) * 60 / Tc) / 1000

    # get MAP: mean of the third state over the last five cycles
    mean_arterial_pressure = sol[n - 5 * ncycle:, 2].sum() / (5 * ncycle)

    pao_ed = sol[steady, 3].min()
    pao_es = sol[steady, 3].max()

    # Only the last five cycles are returned, so pressure is evaluated on
    # that window alone (Plv is scalar-only).
    result_Vlv = list(volumes[loop])
    result_Plv = [Plv(v, Emax, Emin, ti, Tc, Vd) for v, ti in zip(volumes[loop], t[loop])]

    return ef, pao_ed, pao_es, CO, mean_arterial_pressure, result_Vlv, result_Plv
#############################
## Demo functions
def generate_example():
    """Predict model parameters for one random echocardiogram and animate its PV loop.

    A shuffled batch is drawn from the whole dataset, the network is run on
    it, and one clip of the batch is picked at random. ``pvloop_simulator``
    then writes ``prediction.gif`` for the predicted parameters.

    Returns:
        ``(video, animated, Rm, Ra, Emax, Emin, Vd, Tc, start_v)``: path of
        the ``.mp4`` version of the chosen video, an HTML ``<img>`` snippet
        pointing at the gif, and the (rounded) predicted parameters.
    """
    # get random input
    data_path = 'EchoNet-Dynamic'
    image_data = Echo(root=data_path, split='all', target_type=['Filename', 'LargeIndex', 'SmallIndex'])
    image_loaded_data = DataLoader(image_data, batch_size=30, shuffle=True)
    val_data = next(iter(image_loaded_data))
    val_seq = val_data[0]
    filenames = val_data[1][0]

    # The model lives on `device`, so the input has to be moved there as well;
    # gradients are not needed for a prediction.
    val_tensor = torch.as_tensor(val_seq, dtype=torch.float32).to(device)
    with torch.no_grad():
        predictions = model(val_tensor)

    # Pick within the actual batch size (the last/only batch may be short)
    n = random.randrange(len(filenames))
    results = predictions[n]
    filename = filenames[n]

    def rounded(position, digits=2):
        return round(results[position].item(), digits)

    # Network output order: [Tc, start_v, Emax, Emin, Rm, Ra, Vd].
    # The resistances can be far below 0.01, so they keep 4 decimals: rounding
    # them to 2 can yield 0.0, and the ODE divides by both of them.
    plot, Rm, Ra, Emax, Emin, Vd, Tc, start_v = pvloop_simulator(
        Rm=rounded(4, 4), Ra=rounded(5, 4),
        Emax=rounded(2), Emin=rounded(3),
        Vd=rounded(6), Tc=rounded(0), start_v=rounded(1))

    # Swap only the file extension for mp4
    video = os.path.splitext(f"EchoNet-Dynamic/Videos/{filename}")[0] + ".mp4"
    animated = "<img src='prediction.gif' alt='pv_loop'>"
    return video, animated, Rm, Ra, Emax, Emin, Vd, Tc, start_v
def lvad_plots(Rm, Ra, Emax, Emin, Vd, Tc, start_v, beta):
    """Plot the PV loop with and without an LVAD for the given parameters.

    The LVAD start speed is ``beta`` times the speed at which suction occurs.

    Returns:
        ``(plt, ef_before, ef_after, co_before, co_after)`` - the pyplot
        module (whose current figure holds both loops) and the ejection
        fraction / cardiac output without and with LVAD, rounded to 2 decimals.
    """
    # fixed circuit parameters
    Rs, Rc = 1., 0.0398
    Ca, Cs, Cr = 0.08, 1.33, 4.4
    Ls = 0.0005
    shared = (Rs, Rm, Ra, Rc, Ca, Cs, Cr, Ls, Emin, Vd, Tc, start_v, Emax)

    # get values for periodic loops:
    baseline = f_nolvad(*shared, 0.0)
    ef_nolvad, pao_ed, pao_es, co_nolvad, y00, y01, y02, y03, y04, Vlv0, Plv0 = baseline

    # get suction w: (make w go linearly from w0 to w0 + maxtime * 400, and find w at which suction occurs)
    x60 = 0.
    suctionw = get_suctionw(*shared, y00, y01, y02, y03, y04, 5000., x60, 400.)

    c = 0.065  # (in simaan2008dynamical: 0.67, but too fast -> 0.061 gives better shape)
    slope0 = 100.
    # beta = 1/gamma; if it doesn't work (x6 negative), change gamma down to 1.4 or up to 2.1
    w0 = suctionw * beta

    # compute new pv loops and ef with lvad added:
    with_lvad = f_lvad(*shared, c, slope0, w0, x60, y00, y01, y02, y03, y04)
    new_ef, pao_ed, pao_es, CO, MAP, Vlvs, Plvs = with_lvad

    fig, ax = plt.subplots()
    ax.plot(Vlv0, Plv0, color='blue', label='No LVAD')
    ax.plot(Vlvs, Plvs, color=(78/255, 192/255, 44/255), label=f"LVAD, ω(0)= {round(w0,2)}r/min")
    plt.xlabel('LV volume (ml)')
    plt.ylabel('LV pressure (mmHg)')
    plt.legend(loc='upper left', framealpha=1)
    plt.ylim((0, 220))
    plt.xlim((0, 250))

    summary = tuple(round(value, 2) for value in (ef_nolvad, new_ef, co_nolvad, CO))
    return (plt,) + summary
# Page heading (already wrapped in its own <h1> element).
title = "<h1 style='text-align: center; margin-bottom: 1rem'> Non-Invasive Medical Digital Twins using Physics-Informed Self-Supervised Learning </h1>"
# Author list and introduction shown under the heading.
description = """
<p style='text-align: center'> Keying Kuang, Frances Dean, Jack B. Jedlicki, David Ouyang, Anthony Philippakis, David Sontag, Ahmed Alaa <br></p>
<p> We develop methodology for predicting digital twins from non-invasive cardiac ultrasound images in <a href='https://arxiv.org/abs/2403.00177'>Non-Invasive Medical Digital Twins using Physics-Informed Self-Supervised Learning</a>. Check out our <a href='https://github.com/AlaaLab/CardioPINN' target='_blank'>code.</a> \n \n
We demonstrate the ability of our model to predict left ventricular pressure-volume loops using image data here. To run example predictions on samples from the <a href='https://echonet.github.io/dynamic/'>EchoNet</a> dataset, click the first button. \n \n
</p>
"""
# Heading and explanation of the manual simulation section.
title2 = "<h3 style='text-align: center'> Physics-based model simulation</h3>"
description2 = """
\n \n
Our model uses a hydraulic analogy model of cardiac function from <a href='https://ieeexplore.ieee.org/document/4729737/keywords#keywords'>Simaan et al 2008</a>. Below you can input values of predicted parameters and output a simulated pressure-volume loop predicted from the <a href='https://ieeexplore.ieee.org/document/4729737/keywords#keywords'>Simaan et al 2008</a> model, which is an ordinary differential equation. Tune parameters and press 'Run simulation.'
"""
# Explanation of the LVAD section.
description3 = """
\n\n
This model can incorporate a tunable left-ventricular assistance device (LVAD) for in-silico experimentation. Click to view the effect of adding an LVAD to the simulated PV loop.
"""
# Gradio user interface. All components are created inside the Blocks context
# (components created outside of it are never rendered).
with gr.Blocks() as demo:
    # text: `title` already carries its own <h1> element, so it is not wrapped again
    gr.Markdown(title)
    gr.Markdown(description)

    # Section 1: prediction from a sample echocardiogram
    with gr.Row():
        # `scale` is expected to be an integer; this is the only column in the row
        with gr.Column(scale=1, min_width=100):
            generate_button = gr.Button("Load sample echocardiogram and generate result")
            with gr.Row():
                video = gr.PlayableVideo(autoplay=True)
                plot = gr.HTML()
            with gr.Row():
                Rm = gr.Number(label="Mitral valve circuit resistance (Rm) mmHg*s/ml:")
                Ra = gr.Number(label="Aortic valve circuit resistance (Ra) mmHg*s/ml:")
                Emax = gr.Number(label="Maximum elastance (Emax) mmHg/ml:")
                Emin = gr.Number(label="Minimum elastance (Emin) mmHg/ml:")
                Vd = gr.Number(label="Theoretical zero pressure volume (Vd) ml:")
                Tc = gr.Number(label="Cycle duration (Tc) s:")
                start_v = gr.Number(label="Initial volume (start_v) ml:")

    # Section 2: manual simulation from slider values
    gr.Markdown(title2)
    gr.Markdown(description2)
    simulation_button = gr.Button("Run simulation")
    with gr.Row():
        sl1 = gr.Slider(0.005, 0.1, value=.005, label="Rm (mmHg*s/ml)")
        sl2 = gr.Slider(0.0001, 0.25, value=.0001, label="Ra (mmHg*s/ml)")
        sl3 = gr.Slider(0.5, 3.5, value=.5, label="Emax (mmHg/ml)")
        sl4 = gr.Slider(0.02, 0.1, value= .02, label="Emin (mmHg/ml)")
        sl5 = gr.Slider(4.0, 25.0, value= 4.0, label="Vd (ml)")
        sl6 = gr.Slider(0.4, 1.7, value= 0.4, label="Tc (s)")
        sl7 = gr.Slider(0.0, 280.0, value= 140., label="start_v (ml)")
    with gr.Row():
        simulation = gr.Plot()

    # Section 3: effect of adding an LVAD
    gr.Markdown(description3)
    LVAD_button = gr.Button("Add LVAD")
    with gr.Row():
        # The default has to lie inside the slider range. 1.4 was a leftover
        # from the former gamma parameter; beta = 1/gamma, so 1/1.4 ~= 0.7.
        beta = gr.Slider(.4, 1.0, value=0.7, label="Pump speed parameter:")
    with gr.Row():
        lvad = gr.Plot()
    with gr.Row():
        EF_o = gr.Number(label="Ejection fraction (EF) before LVAD:")
        EF_n = gr.Number(label="Ejection fraction (EF) after LVAD:")
        CO_o = gr.Number(label="Cardiac output before LVAD:")
        CO_n = gr.Number(label="Cardiac output after LVAD:")

    # Wire the buttons to their handlers
    generate_button.click(fn=generate_example, outputs = [video,plot,Rm,Ra,Emax,Emin,Vd,Tc,start_v])
    simulation_button.click(fn=pvloop_simulator_plot_only, inputs = [sl1,sl2,sl3,sl4,sl5,sl6,sl7], outputs = [simulation])
    LVAD_button.click(fn=lvad_plots, inputs = [sl1,sl2,sl3,sl4,sl5,sl6,sl7,beta], outputs = [lvad, EF_o, EF_n, CO_o, CO_n])

demo.launch()