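"""Autoregressive inference demo for an AFNO-based weather model.

Loads a backbone and a precipitation head, rolls the backbone forward in
6-hour steps from an ERA5 initial condition, and plots predicted vs. real
10 m wind speed, total precipitation and 2 m temperature.
"""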
import pickle
from datetime import datetime, timedelta
from functools import partial
from pathlib import Path

import torch
import torch.nn as nn
import pandas as pd
import numpy as np

import dask
import dask.array as da
import cartopy.crs as ccrs
import xarray as xr
import matplotlib.pyplot as plt

from model.afnonet import AFNONet
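
# The 21 ERA5 fields used by this demo; 'name@level' denotes a pressure-level
# field (level in hPa), the rest are single-level fields. The first 20 are the
# model inputs; 'total_precipitation' is only read for the ground-truth plots.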
DATANAMES = ['10m_u_component_of_wind', '10m_v_component_of_wind', '2m_temperature',
             'geopotential@1000', 'geopotential@50', 'geopotential@500', 'geopotential@850',
             'mean_sea_level_pressure', 'relative_humidity@500', 'relative_humidity@850',
             'surface_pressure', 'temperature@500', 'temperature@850', 'total_column_water_vapour',
             'u_component_of_wind@1000', 'u_component_of_wind@500', 'u_component_of_wind@850',
             'v_component_of_wind@1000', 'v_component_of_wind@500', 'v_component_of_wind@850',
             'total_precipitation']
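
# Long ERA5 variable names -> short variable names used inside the NetCDF files.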
DATAMAP = {
    'geopotential': 'z',
    'relative_humidity': 'r',
    'temperature': 't',
    'u_component_of_wind': 'u',
    'v_component_of_wind': 'v'
}
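

# Build the backbone and precipitation AFNONet models and load their
# checkpoints onto the CPU.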
def load_model():
    h, w = 720, 1440           # ERA5 0.25-degree grid with the first latitude row dropped
    x_c, y_c, p_c = 20, 20, 1  # input / backbone output / precipitation channels

    backbone_model = AFNONet(img_size=[h, w], in_chans=x_c, out_chans=y_c,
                             norm_layer=partial(nn.LayerNorm, eps=1e-6))
    ckpt = torch.load('./backbone.pt', map_location="cpu")
    backbone_model.load_state_dict(ckpt['model'])

    precip_model = AFNONet(img_size=[h, w], in_chans=x_c, out_chans=p_c,
                           norm_layer=partial(nn.LayerNorm, eps=1e-6))
    ckpt = torch.load('./precipitation.pt', map_location="cpu")
    precip_model.load_state_dict(ckpt['model'])

    return backbone_model, precip_model
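

# Render one 2D field on a PlateCarree map and save it as a JPEG.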
def imcol(data, img_path, img_name, **kwargs):
    fig = plt.figure(figsize=(20, 10))
    ax = plt.axes(projection=ccrs.PlateCarree())

    data.plot(ax=ax, transform=ccrs.PlateCarree(), add_colorbar=False,
              add_labels=False, rasterized=True, **kwargs)
    ax.coastlines(resolution='110m')

    filename = f'{img_path.absolute()}/{img_name}.jpg'

    plt.axis('off')
    plt.savefig(filename, bbox_inches='tight', pad_inches=0.)
    plt.close(fig)
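

# Plot wind speed, precipitation and 2 m temperature for every time step of
# the real and predicted datasets.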
def plot(real_data, pred_data, save_path):
    cmap_t = 'RdYlBu_r'

    # Shared colour limits over both datasets, so real and predicted frames
    # are directly comparable.
    wind = np.sqrt(real_data['u10'] ** 2 + real_data['v10'] ** 2)
    wmin, wmax = wind.values.min(), wind.values.max()
    wind = np.sqrt(pred_data['u10'] ** 2 + pred_data['v10'] ** 2)
    wmin, wmax = min(wind.values.min(), wmin), max(wind.values.max(), wmax)

    pmin, pmax = real_data['tp'].values.min(), real_data['tp'].values.max()
    pmin, pmax = min(pred_data['tp'].values.min(), pmin), max(pred_data['tp'].values.max(), pmax)

    tmin, tmax = real_data['t2m'].values.min(), real_data['t2m'].values.max()
    tmin, tmax = min(pred_data['t2m'].values.min(), tmin), max(pred_data['t2m'].values.max(), tmax)

    for i in range(len(real_data.time)):
        u = real_data['u10'].isel(time=i)
        v = real_data['v10'].isel(time=i)
        wind = np.sqrt(u ** 2 + v ** 2)
        precip = real_data['tp'].isel(time=i)
        temp = real_data['t2m'].isel(time=i)

        time_str = pd.to_datetime(str(wind['time'].values)).strftime('%Y-%m-%d %H:%M:%S')
        print(f'plot {time_str}')

        imcol(wind, save_path, img_name=f'wind_{time_str}_real', cmap=cmap_t, vmin=wmin, vmax=wmax)
        imcol(precip, save_path, img_name=f'precipitation_{time_str}_real', cmap=cmap_t, vmin=pmin, vmax=pmax)
        imcol(temp, save_path, img_name=f'temperature_{time_str}_real', cmap=cmap_t, vmin=tmin, vmax=tmax)

    for i in range(len(pred_data.time)):
        u = pred_data['u10'].isel(time=i)
        v = pred_data['v10'].isel(time=i)
        wind = np.sqrt(u ** 2 + v ** 2)
        precip = pred_data['tp'].isel(time=i)
        temp = pred_data['t2m'].isel(time=i)

        time_str = pd.to_datetime(str(wind['time'].values)).strftime('%Y-%m-%d %H:%M:%S')
        print(f'plot {time_str}')

        imcol(wind, save_path, img_name=f'wind_{time_str}_pred', cmap=cmap_t, vmin=wmin, vmax=wmax)
        imcol(precip, save_path, img_name=f'precipitation_{time_str}_pred', cmap=cmap_t, vmin=pmin, vmax=pmax)
        imcol(temp, save_path, img_name=f'temperature_{time_str}_pred', cmap=cmap_t, vmin=tmin, vmax=tmax)
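

# Roll the backbone forward autoregressively, one 6-hour step per entry in
# `times`, and collect (u10, v10, t2m, tp) into an xarray.Dataset.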
def get_pred(sample, scaler, times=None, latitude=None, longitude=None):
    backbone_model, precip_model = load_model()

    # Use the first time step of the normalized input as the initial condition.
    sample = torch.from_numpy(sample[0]).float()

    backbone_model.eval()
    precip_model.eval()
    pred = []
    x = sample.unsqueeze(0).transpose(3, 2).transpose(2, 1)  # HWC -> NCHW
    for i in range(len(times)):
        print(f"predict {times[i]}")

        # Inference only: without no_grad the autograd graph would keep
        # growing across the autoregressive loop.
        with torch.no_grad():
            x = backbone_model(x)                    # next state, fed back in on the next step
            tmp = x.transpose(1, 2).transpose(2, 3)  # NCHW -> NHWC
            p = precip_model(x)

        # De-normalize: channels 0-2 are u10, v10, t2m; the last scaler entry
        # is assumed to correspond to total precipitation.
        tmp = tmp.detach().numpy()[0, :, :, :3] * scaler['std'][:3] + scaler['mean'][:3]
        p = p.detach().numpy()[0, 0, :, :, np.newaxis] * scaler['std'][-1] + scaler['mean'][-1]
        tmp = np.concatenate([tmp, p], axis=-1)
        pred.append(tmp)

    pred = np.asarray(pred)
    pred_data = xr.Dataset(
        {
            'u10': (['time', 'latitude', 'longitude'], da.from_array(pred[:, :, :, 0], chunks=(7, 720, 1440))),
            'v10': (['time', 'latitude', 'longitude'], da.from_array(pred[:, :, :, 1], chunks=(7, 720, 1440))),
            't2m': (['time', 'latitude', 'longitude'], da.from_array(pred[:, :, :, 2], chunks=(7, 720, 1440))),
            'tp': (['time', 'latitude', 'longitude'], da.from_array(pred[:, :, :, 3], chunks=(7, 720, 1440))),
        },
        coords={
            'time': (['time'], times),
            'latitude': (['latitude'], latitude),
            'longitude': (['longitude'], longitude),
        },
    )

    return pred_data
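

# Load the raw ERA5 NetCDF files, merge them into one dataset, and build the
# normalized (time, lat, lon, channel) input array for the model.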
def get_data(start_time, end_time):
    times = slice(start_time, end_time)

    # Per-channel normalization statistics saved alongside the checkpoints.
    with open('./scaler.pkl', "rb") as f:
        scaler = pickle.load(f)

    datas = []
    for file in DATANAMES:
        tmp = xr.open_mfdataset(f'./ERA5_rawdata/{file}/*.nc', combine='by_coords').sel(time=times)
        if '@' in file:
            # Pressure-level field: rename e.g. 'z' to 'z@500' so levels don't collide.
            k, v = file.split('@')
            tmp = tmp.rename_vars({DATAMAP[k]: f'{DATAMAP[k]}@{v}'})
        datas.append(tmp)
    with dask.config.set(**{'array.slicing.split_large_chunks': False}):
        raw_data = xr.merge(datas, compat="identical", join="inner")

    # Stack the 20 input channels in the same order as DATANAMES (tp excluded).
    data = []
    for name in ['u10', 'v10', 't2m', 'z@1000', 'z@50', 'z@500', 'z@850', 'msl', 'r@500', 'r@850',
                 'sp', 't@500', 't@850', 'tcwv', 'u@1000', 'u@500', 'u@850', 'v@1000', 'v@500', 'v@850']:
        data.append(raw_data[name].values)

    data = np.stack(data, axis=-1)
    data = (data - scaler['mean']) / scaler['std']
    data = data[:, 1:, :, :]  # drop the first latitude row: 721 -> 720

    # expver=1 selects the final ERA5 product (expver=5 is the preliminary ERA5T).
    return raw_data[['u10', 'v10', 't2m', 'tp']].sel(expver=1), data, scaler
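

# Forecast the window start_time .. end_time in 6-hour steps and plot the
# predictions next to the ERA5 ground truth.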
if __name__ == '__main__':
    start_time = datetime(2023, 1, 1, 0, 0)
    end_time = datetime(2023, 1, 5, 18, 0)
    num = int((end_time - start_time) / timedelta(hours=6))

    print(f"start_time: {start_time}, end_time: {end_time}, pred_num: {num}")

    real_data, sample, scaler = get_data(start_time, end_time)
    print(sample.shape)

    # One prediction every 6 hours, from start_time + 6h up to end_time.
    pred_times = [start_time + timedelta(hours=6) * i for i in range(1, num + 1)]
    pred = get_pred(sample, scaler=scaler, times=pred_times,
                    latitude=real_data.latitude[1:], longitude=real_data.longitude)

    save_path = Path('./output')
    save_path.mkdir(parents=True, exist_ok=True)

    plot(real_data, pred, save_path)