Coverage for src/CSET/operators/ageofair.py: 99%

141 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-09-05 21:08 +0000

1# © Crown copyright, Met Office (2022-2024) and CSET contributors. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15""" 

16Age of air operator. 

17 

18The age of air diagnostic provides a qualitative view of how old air is within 

19the domain, by calculating a back trajectory at each grid point at each lead time 

20to determine when air entered through the lateral boundary. This is useful for 

21diagnosing how quickly air ventilates the domain, depending on its size and the 

22prevailing meteorology. 

23 

24The diagnostic uses the u, v and w components of wind, along with geopotential height to 

25perform the back trajectory. Input data is expected to have already been regridded to 0.5 degrees. 

26 

27Note: the code here does not consider sub-grid transport, and only uses the postprocessed 

28velocity fields and geopotential height. Its applicability is for large-scale flow O(1000 km), 

29and not small scale flow where mixing is likely to play a larger role. 

30""" 

31 

32import datetime 

33import logging 

34import multiprocessing 

35import os 

36import tempfile 

37from functools import partial 

38from math import asin, cos, radians, sin, sqrt 

39 

40import numpy as np 

41from iris.cube import Cube 

42from scipy.ndimage import gaussian_filter 

43 

44from CSET.operators._utils import get_cube_yxcoordname 

45 

46 

47def _calc_dist(coord_1, coord_2): 

48 """Calculate distance between two coordinate tuples. 

49 

50 Arguments 

51 ---------- 

52 coord_1: tuple 

53 A tuple containing (latitude, longitude) coordinate floats 

54 coord_2: tuple 

55 A tuple containing (latitude, longitude) coordinate floats 

56 

57 Returns 

58 ------- 

59 distance: float 

60 Distance between the two coordinate points in meters 

61 

62 Notes 

63 ----- 

64 The function uses the Haversine approximation to calculate distance in metres. 

65 

66 """ 

67 # Approximate radius of earth in m 

68 # Source: https://nssdc.gsfc.nasa.gov/planetary/factsheet/earthfact.html 

69 radius = 6378000 

70 

71 # Extract coordinates and convert to radians 

72 lat1 = radians(coord_1[0]) 

73 lon1 = radians(coord_1[1]) 

74 lat2 = radians(coord_2[0]) 

75 lon2 = radians(coord_2[1]) 

76 

77 # Find out delta latitude, longitude 

78 dlon = lon2 - lon1 

79 dlat = lat2 - lat1 

80 

81 # Compute distance 

82 a = sin(dlat / 2) ** 2 + cos(lat1) * cos(lat2) * sin(dlon / 2) ** 2 

83 c = 2 * asin(sqrt(a)) 

84 distance = radius * c 

85 

86 return distance 

87 

88 

def _aoa_core(
    x_arr: np.ndarray,
    y_arr: np.ndarray,
    z_arr: np.ndarray,
    g_arr: np.ndarray,
    lats: np.ndarray,
    lons: np.ndarray,
    dt: int,
    plev_idx: int,
    timeunit: str,
    cyclic: bool,
    tmpdir: str,
    lon_pnt: int,
):
    """AOA multiprocessing core.

    Runs the core age of air code on a specific longitude point (all latitudes) for
    parallelisation. The resulting (ntime, nlat) array of ages is saved to
    ``<tmpdir>/aoa_frag_<lon_pnt:04d>.npy``.

    Arguments
    ---------
    x_arr: np.ndarray
        A numpy array containing x wind data, indexed (time, pressure, lat, lon).
    y_arr: np.ndarray
        A numpy array containing y wind data, same layout as x_arr.
    z_arr: np.ndarray
        A numpy array containing w wind data, same layout as x_arr.
    g_arr: np.ndarray
        A numpy array containing geopotential height data, same layout as x_arr.
    lats: np.ndarray
        A numpy array containing latitude points.
    lons: np.ndarray
        A numpy array containing longitude points.
    dt: int
        Gap between time intervals
    plev_idx: int
        Index of pressure level requested to run back trajectories on.
    timeunit: str
        Units of time, currently only accepts 'hour'
    cyclic: bool
        Whether to wrap at east/west boundaries. See compute_ageofair for a fuller description.
    tmpdir: str
        Path to store intermediate data
    lon_pnt: int
        Longitude point to extract and run back trajectories on for parallelisation.

    Raises
    ------
    NotImplementedError
        If ``timeunit`` is not 'hour'.
    """
    if timeunit != "hour":
        raise NotImplementedError(f"Unsupported time unit {timeunit!r}")

    ntime, _, nlat, nlon = x_arr.shape

    # Age of air for this longitude strip, indexed (leadtime, latitude).
    ageofair_local = np.zeros((ntime, nlat))
    logging.debug("Working on %s", lon_pnt)

    # Grid spacing only depends on the latitude (for this fixed longitude), so
    # compute it once per latitude rather than once per (leadtime, latitude).
    # The next column/row is used, or the previous one at the final column/row.
    lon_nbr = lon_pnt - 1 if lon_pnt == len(lons) - 1 else lon_pnt + 1
    ew_spacings = []
    ns_spacings = []
    for lat_pnt in range(nlat):
        lat_nbr = lat_pnt - 1 if lat_pnt == len(lats) - 1 else lat_pnt + 1
        here = (lats[lat_pnt], lons[lon_pnt])
        ew_spacings.append(_calc_dist(here, (lats[lat_pnt], lons[lon_nbr])))
        ns_spacings.append(_calc_dist(here, (lats[lat_nbr], lons[lon_pnt])))

    # Ignore leadtime 0 as this is trivial.
    for leadtime in range(1, ntime):
        # Default: the air has stayed inside the domain for the whole forecast.
        ageofair_local[leadtime, :] = leadtime * dt

        for lat_pnt in range(nlat):
            ew_spacing = ew_spacings[lat_pnt]
            ns_spacing = ns_spacings[lat_pnt]

            # Parcel location in fractional array indices, starting at the grid point.
            x = lon_pnt
            y = lat_pnt
            z = plev_idx
            # Parcel geopotential height; carried between steps so that vertical
            # motions smaller than the level spacing still accumulate.
            parcel_g = None

            # Step back through the preceding time slices.
            for n in range(leadtime):
                t_idx = leadtime - n
                # Sample the fields at the grid point containing the parcel
                # (int() truncates the fractional indices).
                i, j, k = int(x), int(y), int(z)
                u = x_arr[t_idx, k, j, i]
                v = y_arr[t_idx, k, j, i]
                w = z_arr[t_idx, k, j, i]

                # Backwards displacement: m/s -> m per time step, then horizontal
                # metres -> model gridpoints.
                du = ((u * 60 * 60 * dt) / ew_spacing) * -1.0
                dv = ((v * 60 * 60 * dt) / ns_spacing) * -1.0
                dz = (w * 60 * 60 * dt) * -1.0

                # Accumulate vertical displacement, then snap to the closest level
                # in the local geopotential height column.
                if parcel_g is None:
                    parcel_g = g_arr[t_idx, k, j, i]
                parcel_g = parcel_g + dz
                g_col = g_arr[t_idx, :, j, i]
                z = np.argmin(np.abs(g_col - parcel_g))

                # Update x,y location based on displacement. Z already updated.
                x = x + du
                y = y + dv

                if cyclic:
                    # No east/west boundary (e.g. K-SCALE): wrap through it.
                    x = x % nlon
                    # Float rounding can give exactly nlon (e.g. -1e-17 % nlon).
                    if x >= nlon:
                        x = 0.0
                elif x < 0 or x >= nlon:
                    # Left through the east/west boundary at this step.
                    ageofair_local[leadtime, lat_pnt] = n * dt
                    break

                if y < 0 or y >= nlat:
                    # Left through the north/south boundary at this step.
                    ageofair_local[leadtime, lat_pnt] = n * dt
                    break

    # Save the (leadtime, latitude) age of air fragment for this longitude.
    np.save(os.path.join(tmpdir, f"aoa_frag_{lon_pnt:04d}.npy"), ageofair_local)

228 

229 

def _make_aoa_pool():
    """Create a "spawn" multiprocessing pool sized to the CPUs usable by this process."""
    # os.sched_getaffinity is only available on some Unix platforms; fall back
    # to the total CPU count elsewhere.
    # See https://docs.python.org/3/library/os.html#os.cpu_count
    try:
        num_usable_cores = len(os.sched_getaffinity(0))
    except AttributeError:
        num_usable_cores = os.cpu_count() or 1
    # Use "spawn" method to avoid warnings before the default is changed in
    # python 3.14. See:
    # https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods
    return multiprocessing.get_context("spawn").Pool(num_usable_cores)


def _run_aoa_member(
    x_arr, y_arr, z_arr, g_arr, lats, lons, dt, plev_idx, timeunit, cyclic, tmpdir_name, pool
):
    """Compute age of air for one (time, pressure, lat, lon) set of arrays.

    Each longitude is processed by ``_aoa_core`` (in ``pool`` if given, otherwise
    serially), which writes a per-longitude fragment into ``tmpdir_name``. The
    fragments are then collated into a (time, lat, lon) array, which is returned.
    """
    # np.copy (subok=False by default) hands the workers plain ndarray copies.
    func = partial(
        _aoa_core,
        np.copy(x_arr),
        np.copy(y_arr),
        np.copy(z_arr),
        np.copy(g_arr),
        lats,
        lons,
        dt,
        plev_idx,
        timeunit,
        cyclic,
        tmpdir_name,
    )
    ntime, _, nlat, nlon = x_arr.shape
    if pool is None:
        for lon_pnt in range(nlon):
            func(lon_pnt)
    else:
        pool.map(func, range(nlon))

    member_aoa = np.zeros((ntime, nlat, nlon))
    for lon_pnt in range(nlon):
        member_aoa[:, :, lon_pnt] = np.load(f"{tmpdir_name}/aoa_frag_{lon_pnt:04d}.npy")
    return member_aoa


def compute_ageofair(
    XWIND: Cube,
    YWIND: Cube,
    WWIND: Cube,
    GEOPOT: Cube,
    plev: int,
    cyclic: bool = False,
    multicore=True,
):
    """Compute back trajectories for a given forecast.

    This allows us to determine when air entered through the boundaries. This will run on all available
    lead-times, and thus return an age of air cube of shape ntime, nlat, nlon. It supports multiprocessing
    by iterating over longitude, or, if ``multicore`` is False, will run in a single process, which is
    easier for debugging. This function supports ensembles, where it will check if a realization dimension
    exists and if so, loop over this axis.

    Arguments
    ----------
    XWIND: Cube
        An iris cube containing the x component of wind on pressure levels, on a 0p5 degree grid.
        Requires 4 dimensions, ordered time, pressure, latitude and longitude (optionally preceded
        by realization). Must contain at least 2 time points to compute back trajectory.
    YWIND: Cube
        An iris cube in the same format as XWIND.
    WWIND: Cube
        An iris cube in the same format as XWIND.
    GEOPOT: Cube
        An iris cube in the same format as XWIND.
    plev: int
        The pressure level of which to compute the back trajectory on. The function will search to
        see if this exists and if not, will raise an exception.
    cyclic: bool
        If cyclic is True, then the code will assume no east/west boundary and if a back trajectory
        reaches the boundary, it will emerge out of the other side. This option is useful for large
        domains such as the K-SCALE tropical channel, where there are only north/south boundaries in
        the domain.
    multicore: bool
        If true, split up age of air diagnostic to use multiple cores (defaults to number of cores
        available to the process), otherwise run using a single process, which is easier to debug if
        developing the code.

    Returns
    -------
    ageofair_cube: Cube
        An iris cube of the age of air data, with 3 dimensions (time, latitude, longitude), or
        4 dimensions (realization, time, latitude, longitude) for ensembles.

    Raises
    ------
    ValueError
        If the cubes differ in shape, have fewer than 2 time points, or have an unexpected
        dimension order.
    NotImplementedError
        If the time base is not hourly or time intervals are not uniform.
    IndexError
        If ``plev`` is not one of the pressure levels.

    Notes
    -----
    The age of air diagnostic was used in Warner et al. (2023) [Warneretal2023]_ to identify the relative
    role of spin-up from initial conditions and lateral boundary conditions over tropical Africa to explore
    the impact of new data assimilation techniques. A further paper is currently in review ([Warneretal2024]_)
    which applies the diagnostic more widely to the Australian ACCESS convection-permitting models.

    References
    ----------
    .. [Warneretal2023] Warner, J.L., Petch, J., Short, C., Bain, C., 2023. Assessing the impact of an NWP warm-start
       system on model spin-up over tropical Africa. QJ, 149( 751), pp.621-636. doi:10.1002/qj.4429
    .. [Warneretal2024] Diagnosing lateral boundary spin-up in regional models using an age of air diagnostic
       James L. Warner, Charmaine N. Franklin, Belinda Roux, Shaun Cooper, Susan Rennie, Vinod
       Kumar.
       Submitted for Quarterly Journal of the Royal Meteorological Society.

    """
    # --- Cheap metadata validation, before any data is realised. ---

    # Check that all cubes are of same size (will catch different dimension orders too).
    if not XWIND.shape == YWIND.shape == WWIND.shape == GEOPOT.shape:
        raise ValueError("Cubes are not the same shape")

    # Get time units and assign for later.
    if str(XWIND.coord("time").units).startswith("hours since "):
        timeunit = "hour"
    else:
        raise NotImplementedError("Unsupported time base")

    time = XWIND.coord("time").points
    if len(time) < 2:
        raise ValueError("At least 2 time points are required to compute age of air")

    # The spacing in time must be the same throughout the cube.
    dt = np.diff(time)
    if not np.all(dt == dt[0]):
        raise NotImplementedError("Time intervals are not consistent")
    dt = dt[0]

    # Get coord points.
    lat_name, lon_name = get_cube_yxcoordname(XWIND)
    lats = XWIND.coord(lat_name).points
    lons = XWIND.coord(lon_name).points

    # Determine which index each axis is, and check for ensembles.
    dimension_mapping = {
        coord.name(): XWIND.coord_dims(coord.name())[0] for coord in XWIND.dim_coords
    }
    ensemble_mode = "realization" in dimension_mapping
    expected_order = ["time", "pressure", lat_name, lon_name]
    if ensemble_mode:
        expected_order.insert(0, "realization")
    if dimension_mapping != {name: idx for idx, name in enumerate(expected_order)}:
        raise ValueError(f"Dimension mapping not correct, ordered {dimension_mapping}")

    # Get array index for user specified pressure level.
    pressure_points = XWIND.coord("pressure").points
    if plev not in pressure_points:
        raise IndexError(f"Can't find plev {plev} in {pressure_points}")
    plev_idx = np.where(pressure_points == plev)[0][0]

    # --- Realise and prepare data. ---

    # Make data non-lazy to speed up code.
    logging.info("Making data non-lazy...")
    x_arr = XWIND.data
    y_arr = YWIND.data
    z_arr = WWIND.data
    g_arr = GEOPOT.data

    # Smooth vertical velocity to 2sigma (standard for 0.5 degree) over lat/lon only.
    logging.info("Smoothing vertical velocity...")
    spatial_axes = (3, 4) if ensemble_mode else (2, 3)
    z_arr = gaussian_filter(z_arr, 2, mode="nearest", axes=spatial_axes)

    # Initialise cube containing age of air.
    out_coords = [XWIND.coord("time"), XWIND.coord(lat_name), XWIND.coord(lon_name)]
    if ensemble_mode:
        out_coords.insert(0, XWIND.coord("realization"))
    ageofair_cube = Cube(
        np.zeros(tuple(len(coord.points) for coord in out_coords)),
        long_name="age_of_air",
        units="hours",
        dim_coords_and_dims=[(coord, idx) for idx, coord in enumerate(out_coords)],
    )

    # --- Main calculation; temp directory and pool are always cleaned up. ---
    with tempfile.TemporaryDirectory(dir=os.getenv("CYLC_TASK_WORK_DIR")) as tmpdir_name:
        logging.info("Made tmpdir %s", tmpdir_name)
        pool = _make_aoa_pool() if multicore else None
        try:
            logging.info("STARTING AOA DIAG...")
            start = datetime.datetime.now()

            if ensemble_mode:
                for e in range(x_arr.shape[0]):
                    logging.info("Working on member %s", e)
                    ageofair_cube.data[e] = _run_aoa_member(
                        x_arr[e],
                        y_arr[e],
                        z_arr[e],
                        g_arr[e],
                        lats,
                        lons,
                        dt,
                        plev_idx,
                        timeunit,
                        cyclic,
                        tmpdir_name,
                        pool,
                    )
            else:
                ageofair_cube.data[...] = _run_aoa_member(
                    x_arr,
                    y_arr,
                    z_arr,
                    g_arr,
                    lats,
                    lons,
                    dt,
                    plev_idx,
                    timeunit,
                    cyclic,
                    tmpdir_name,
                    pool,
                )
        finally:
            if pool is not None:
                # Clean up worker processes, even if a task failed.
                pool.terminate()
                pool.join()

        logging.info(
            "AOA DIAG DONE, took %s s",
            (datetime.datetime.now() - start).total_seconds(),
        )

    return ageofair_cube