Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 25a508d

Browse files
committedMar 5, 2025·
Running with single year batch, downsampling outputvars
1 parent 04365bd commit 25a508d

File tree

8 files changed

+568
-1
lines changed

8 files changed

+568
-1
lines changed
 

‎experiments/calibration/Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,5 +30,7 @@ Thermodynamics = "b60c26fb-14c3-4610-9d3e-2d17fe7ff00c"
3030
YAML = "ddb6d928-2868-570f-bddf-ab3f9cf99eb6"
3131

3232
[compat]
33-
ClimaCalibrate = "0.0.10"
33+
ClimaCalibrate = "0.0.11"
3434
ClimaTimeSteppers = "0.8.2"
35+
ClimaCore = "0.14.26"
36+
EnsembleKalmanProcesses = "2.0"
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Generate and save experiment observations to disk
2+
using ClimaAnalysis, JLD2
3+
include("experiments/calibration/coarse_amip/observation_utils.jl")
4+
5+
const obs_dir = "/home/ext_nefrathe_caltech_edu/calibration_obs"
6+
const diagnostic_dir = "experiments/calibration/output/iteration_000/member_001/model_config/output_0000/clima_atmos/"
7+
8+
diagnostic_var2d = OutputVar(joinpath(diagnostic_dir, "rsdt_1M_average.nc"));
9+
pressure = OutputVar(joinpath(diagnostic_dir, "pfull_1M_average.nc"));
10+
diagnostic_var3d = OutputVar(joinpath(diagnostic_dir, "ta_1M_average.nc"));
11+
diagnostic_var3d = ClimaAnalysis.Atmos.to_pressure_coordinates(diagnostic_var3d, pressure)
12+
13+
nt = get_all_output_vars(obs_dir, diagnostic_var2d, diagnostic_var3d)
14+
nyears = 18
15+
observation_vec = create_observation_vector(nt, nyears)
16+
JLD2.save_object("experiments/calibration/coarse_amip/observations.jld2", observation_vec)
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
FLOAT_TYPE: "Float32"
2+
albedo_model: "CouplerAlbedo"
3+
atmos_config_file: "config/longrun_configs/amip_target_edonly.yml"
4+
checkpoint_dt: "99999days"
5+
coupler_toml: ["toml/amip.toml"]
6+
print_config_dict: true
7+
coupler_output_dir: "experiments/calibration/output"
8+
dt: "800secs"
9+
dt_cpl: "800secs"
10+
dt_save_state_to_disk: "99999days"
11+
dt_save_to_sol: "99999days"
12+
dz_bottom: 100.0
13+
energy_check: false
14+
h_elem: 3
15+
land_albedo_type: "map_temporal"
16+
mode_name: "amip"
17+
mono_surface: false
18+
netcdf_output_at_levels: true
19+
output_default_diagnostics: false
20+
use_coupler_diagnostics: false
21+
radiation_reset_rng_seed: true
22+
rayleigh_sponge: true
23+
start_date: "20000901"
24+
surface_setup: "PrescribedSurface"
25+
t_end: "457days"
26+
topography: "Earth"
27+
topo_smoothing: true
28+
turb_flux_partition: "CombinedStateFluxesMOST"
29+
viscous_sponge: true
30+
z_elem: 39
31+
z_max: 60000.0
32+
extra_atmos_diagnostics:
33+
- reduction_time: average
34+
short_name: rsut
35+
period: 1months
36+
writer: nc
37+
- reduction_time: average
38+
short_name: rlut
39+
period: 1months
40+
writer: nc
41+
- reduction_time: average
42+
short_name: rsdt
43+
period: 1months
44+
writer: nc
45+
- reduction_time: average
46+
short_name: rsutcs
47+
period: 1months
48+
writer: nc
49+
- reduction_time: average
50+
short_name: rlutcs
51+
period: 1months
52+
writer: nc
53+
- reduction_time: average
54+
short_name: pr
55+
period: 1months
56+
writer: nc
57+
- reduction_time: average
58+
short_name: ts
59+
period: 1months
60+
writer: nc
61+
- reduction_time: average
62+
short_name: ta
63+
period: 1months
64+
writer: nc
65+
- reduction_time: average
66+
short_name: hur
67+
period: 1months
68+
writer: nc
69+
- reduction_time: average
70+
short_name: hus
71+
period: 1months
72+
writer: nc
73+
- reduction_time: average
74+
short_name: clw
75+
period: 1months
76+
writer: nc
77+
- reduction_time: average
78+
short_name: cli
79+
period: 1months
80+
writer: nc
81+
- reduction_time: average
82+
short_name: pfull
83+
period: 1months
84+
writer: nc
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import ClimaCoupler
2+
import ClimaCalibrate
3+
import CUDA
4+
import EnsembleKalmanProcesses as EKP
5+
ENV["CLIMACOMMS_DEVICE"] = "CUDA"
6+
ENV["CLIMACOMMS_CONTEXT"] = "SINGLETON"
7+
import JLD2
8+
include(joinpath(pkgdir(ClimaCoupler), "experiments", "ClimaEarth", "setup_run.jl"))
9+
10+
"""
    ClimaCalibrate.forward_model(iter, member)

Run the coupled model for ensemble member `member` of iteration `iter`.

The coarse AMIP `model_config.yml` is loaded, its start date is derived from the
current minibatch of the saved EKP object, and the member-specific parameter file
and output directory are written into the config before calling `setup_and_run`.
"""
function ClimaCalibrate.forward_model(iter, member)
    experiment_config = joinpath(
        pkgdir(ClimaCoupler),
        "experiments",
        "calibration",
        "coarse_amip",
        "model_config.yml",
    )
    cfg = get_coupler_config_dict(experiment_config)
    # Root output directory shared by the whole calibration run.
    root_dir = cfg["coupler_output_dir"]

    # The minibatch stored in this iteration's EKP object selects the start date.
    ekp_state = JLD2.load_object(ClimaCalibrate.ekp_path(root_dir, iter))
    cfg["start_date"] = minibatch_to_start_date(EKP.get_current_minibatch(ekp_state))

    # Point the run at this member's sampled parameters and output directory.
    cfg["calibration_toml"] = ClimaCalibrate.parameter_path(root_dir, iter, member)
    cfg["coupler_output_dir"] = ClimaCalibrate.path_to_ensemble_member(root_dir, iter, member)

    return setup_and_run(cfg)
end
29+
30+
31+
"""
    minibatch_to_start_date(batch)

Return the simulation start date for a minibatch as a `"yyyy0901"` string.

The smallest index in `batch` is converted to a year by adding 1999, so index 1
maps to `"20000901"`.

# Throws
- `ArgumentError` if the resulting year is earlier than 2000 (i.e. the smallest
  index is below 1).
"""
function minibatch_to_start_date(batch)
    first_index = minimum(batch)
    start_year = first_index + 1999
    # Explicit check instead of `@assert`, which may be disabled at some optimization levels.
    start_year >= 2000 || throw(
        ArgumentError("minibatch indices must be >= 1, got minimum index $first_index"),
    )
    return "$(start_year)0901"
end
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
using ClimaAnalysis, Dates
2+
import ClimaCalibrate
3+
import ClimaCoupler
4+
import JLD2
5+
import EnsembleKalmanProcesses as EKP
6+
obs = JLD2.load_object("experiments/calibration/coarse_amip/observations.jld2")
7+
const single_member_dims = length(EKP.get_obs(first(obs)))
8+
include(joinpath(pkgdir(ClimaCoupler), "experiments/calibration/coarse_amip/observation_utils.jl"))
9+
10+
"""
    ClimaCalibrate.observation_map(iteration)

Build the `single_member_dims × ensemble_size` matrix of model outputs for
`iteration`. Each column holds `process_member_data` for one member; members whose
`model_config/output_active` directory does not exist get a column of `NaN`.
"""
function ClimaCalibrate.observation_map(iteration)
    G = Array{Float64}(undef, single_member_dims, ensemble_size)
    for m in 1:ensemble_size
        @show m
        member_dir = ClimaCalibrate.path_to_ensemble_member(output_dir, iteration, m)
        active_output = joinpath(member_dir, "model_config/output_active")
        if !isdir(active_output)
            # No output directory for this member: mark the whole column as missing.
            @info "No data found for member $m."
            G[:, m] .= NaN
            continue
        end
        G[:, m] .= process_member_data(SimDir(active_output))
    end
    return G
end
26+
27+
"""
    process_outputvar(simdir, name)

Load the monthly averages of `name` from `simdir` and return them as one flat
vector of downsampled seasonal averages.

Variables with an altitude dimension are first moved to pressure coordinates (using
the `pfull` monthly averages) and limited to the ERA5 pressure range. `missing` and
`NaN` values are replaced by `0.0` before averaging.
"""
function process_outputvar(simdir, name)
    seconds_per_day = 86_400

    var = get_monthly_averages(simdir, name)
    # Preprocess to match observations
    if has_altitude(var)
        pfull = get_monthly_averages(simdir, "pfull")
        var = limit_pressure_dim_to_era5_range(
            ClimaAnalysis.Atmos.to_pressure_coordinates(var, pfull),
        )
    end
    var = ClimaAnalysis.replace(var, missing => 0.0, NaN => 0.0)
    var = ClimaAnalysis.shift_to_start_of_previous_month(var)

    # Cut off the first 3 months by keeping times from 92 days onward.
    after_spinup = window(var, "time"; left = 92 * seconds_per_day)
    seasons = split_by_season_across_time(after_spinup)
    # Ensure we are splitting evenly across seasons
    @assert all(season -> length(times(season)) == 3, seasons)

    # Average each season in time, keep every third point, then flatten and concatenate.
    flattened = [vec(downsample(average_time(season), 3)) for season in seasons]
    return vcat(flattened...)
end
51+
52+
"""
    process_member_data(simdir::SimDir)

Return the flat model-output vector for one ensemble member.

The vector is the concatenation of the time-, latitude- and longitude-averaged net
radiation (`rlut + rsut - rsdt`) followed by the processed seasonal fields `rsut`,
`rlut`, cloud radiative effect (`rsut + rlut - rsutcs - rlutcs`), `pr` and `ts`.
An empty `simdir` yields a vector of `NaN` of length `single_member_dims`.
"""
function process_member_data(simdir::SimDir)
    # `single_member_dims` is the constant defined from the observations; the name
    # previously used here was not defined anywhere.
    isempty(simdir) && return fill(NaN, single_member_dims)

    rsdt_full = get_monthly_averages(simdir, "rsdt")
    rsut_full = get_monthly_averages(simdir, "rsut")
    rlut_full = get_monthly_averages(simdir, "rlut")
    # NOTE(review): this averages over every month in the output, including the months
    # that `process_outputvar` drops with its 92-day window — confirm this is intended.
    year_net_radiation =
        (rlut_full + rsut_full - rsdt_full) |> average_lat |> average_lon |> average_time

    rsut = process_outputvar(simdir, "rsut")
    rlut = process_outputvar(simdir, "rlut")
    rsutcs = process_outputvar(simdir, "rsutcs")
    rlutcs = process_outputvar(simdir, "rlutcs")
    cre = rsut + rlut - rsutcs - rlutcs

    pr = process_outputvar(simdir, "pr")
    ts = process_outputvar(simdir, "ts")

    # The 3D fields (ta, hur, hus) are not part of the returned vector, so they are
    # not processed here.
    return vcat(year_net_radiation.data, rsut, rlut, cre, pr, ts)
end
Lines changed: 245 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,245 @@
1+
# Utilities for processing observations/OutputVars
2+
# Used to generate observations and compute the observation map
3+
using ClimaAnalysis, ClimaCoupler
4+
using Dates, LinearAlgebra, Statistics
5+
import EnsembleKalmanProcesses as EKP
6+
7+
# Workaround to read ql, qi from file nicely
8+
push!(ClimaAnalysis.Var.TIME_NAMES, "valid_time")
9+
10+
# Constants
11+
const days_in_seconds = 86_400
12+
const months = 31days_in_seconds
13+
const years = 365days_in_seconds
14+
const start_date = DateTime(2000, 3, 1)
15+
const first_year_start_date = DateTime(2000, 12, 1)
16+
17+
include(joinpath(pkgdir(ClimaCoupler), "experiments/ClimaEarth/leaderboard/data_sources.jl"))
18+
19+
# The ERA5 pressure range is not as large as the ClimaAtmos default pressure levels,
20+
# so we need to limit outputvars to the ERA5 dims
21+
"""
    limit_pressure_dim_to_era5_range(diagnostic_var3d)

Return `diagnostic_var3d` windowed along its pressure dimension to the levels that
lie within the ERA5 range of 100 Pa to 100 000 Pa.

# Throws
- `ArgumentError` if the variable has no pressure dimension, or if none of its
  pressure levels fall inside the ERA5 range.
"""
function limit_pressure_dim_to_era5_range(diagnostic_var3d)
    has_pressure(diagnostic_var3d) ||
        throw(ArgumentError("expected an OutputVar with a pressure dimension"))
    era5_pressure_min = 100.0 # Pa
    era5_pressure_max = 100_000.0 # Pa

    # Use the variable's own pressure dimension name for both the lookup and the window.
    pname = pressure_name(diagnostic_var3d)
    levels = diagnostic_var3d.dims[pname]
    in_range = filter(p -> era5_pressure_min <= p <= era5_pressure_max, levels)
    isempty(in_range) && throw(
        ArgumentError(
            "no pressure levels within [$era5_pressure_min, $era5_pressure_max] Pa",
        ),
    )

    # Window the diagnostic var to use the era5 pressure bounds
    left, right = extrema(in_range)
    return window(diagnostic_var3d, pname; left, right)
end
31+
32+
"""
33+
get_all_output_vars(obs_dir, diagnostic_var2d, diagnostic_var3d)
34+
35+
Return a NamedTuple of OutputVars containing all initial coarse AMIP observations.
36+
Start date is set to `DateTime(2000, 3, 1)`. All OutputVars are resampled to the model diagnostic grid.
37+
"""
38+
function get_all_output_vars(obs_dir, diagnostic_var2d, diagnostic_var3d)
    # 3D resampling target, restricted to the ERA5 pressure range.
    target_3d = limit_pressure_dim_to_era5_range(diagnostic_var3d)

    # Resample onto the diagnostic grid: variables with a pressure dimension use the
    # 3D target, all others the 2D target.
    function resample(ov)
        if has_pressure(ov)
            return resampled_as(
                ov,
                target_3d,
                dim_names = ["longitude", "latitude", "pressure_level"],
            )
        else
            return resampled_as(ov, diagnostic_var2d, dim_names = ["longitude", "latitude"])
        end
    end

    obs_loaders = get_obs_var_dict()
    # Observations provided by the `get_obs_var_dict` loaders, keyed by short name.
    load_from_dict(short_name) = resample(obs_loaders[short_name](start_date))
    # Observations read from ERA5 files in `obs_dir`.
    load_era5(filename) = resample(
        OutputVar(
            joinpath(obs_dir, filename);
            new_start_date = start_date,
            shift_by = Dates.firstdayofmonth,
        ),
    )
    # Variables built by arithmetic need the start date attribute set again.
    function with_start_date(var)
        var.attributes["start_date"] = string(start_date)
        return var
    end

    # 2D fields
    # TOA incoming shortwave radiation
    rsdt = load_from_dict("rsdt")
    # TOA outgoing long, shortwave radiation
    rlut = load_from_dict("rlut")
    rsut = load_from_dict("rsut")
    # TOA clearsky outgoing long, shortwave radiation
    rsutcs = load_from_dict("rsutcs")
    rlutcs = load_from_dict("rlutcs")

    # TOA net radiative flux
    net_rad = with_start_date(rlut + rsut - rsdt)
    # Cloud radiative effect
    cre = with_start_date(rsut + rlut - rsutcs - rlutcs)

    # Precipitation
    pr = load_from_dict("pr")

    # Latent and sensible heat flux; the returned `shf` entry is their sum.
    latent = load_era5("era5_monthly_averages_surface_single_level_mslhf.nc")
    sensible = load_era5("era5_monthly_averages_surface_single_level_msshf.nc")
    shf = with_start_date(latent + sensible)

    # Surface temperature
    ts = load_era5("era5_monthly_avg_ts.nc")

    # 3D fields: air temperature, relative humidity, specific humidity
    ta = load_era5("era5_monthly_avg_pressure_level_t.nc")
    hur = load_era5("era5_monthly_avg_pressure_level_r.nc")
    hus = load_era5("era5_monthly_avg_pressure_level_q.nc")

    return (; rlut, rsut, rsutcs, rlutcs, pr, net_rad, cre, shf, ts, ta, hur, hus)
end
105+
106+
#####
107+
# Processing to create EKP.ObservationSeries
108+
#####
109+
110+
"""
    to_datetime(start_date, time)

Return the `DateTime` that lies `time` seconds after `start_date`.
"""
function to_datetime(start_date, time)
    origin = DateTime(start_date)
    return origin + Second(time)
end
111+
112+
"""
    get_monthly_averages(simdir, var_name)

Fetch the variable with short name `var_name` from `simdir`, requesting the
`"average"` reduction with period `"1M"`.
"""
function get_monthly_averages(simdir, var_name)
    return get(simdir; short_name = var_name, reduction = "average", period = "1M")
end
113+
114+
"""
    get_seasonal_averages(var)
    get_seasonal_averages(simdir, var_name)

Split `var` by season across time and return the time average of each season.
The two-argument method first loads the `"1M"` averages of `var_name` from `simdir`.
"""
function get_seasonal_averages(var)
    seasons = split_by_season_across_time(var)
    return average_time.(seasons)
end

function get_seasonal_averages(simdir, var_name)
    monthly = get(simdir; short_name = var_name, reduction = "average", period = "1M")
    return get_seasonal_averages(monthly)
end
120+
121+
"""
    get_yearly_averages(var)

Return one data array per complete year, each the mean of four consecutive seasonal
averages of `var`. Seasons left over after the last complete group of four are ignored.
"""
function get_yearly_averages(var)
    season_data = getproperty.(get_seasonal_averages(var), :data)
    n_complete_years = fld(length(season_data), 4)
    return map(1:n_complete_years) do year
        # Seasons 4(year - 1) + 1 through 4year belong to this year.
        last_season = 4 * year
        first_season = last_season - 3
        sum(season_data[first_season:last_season]) / 4
    end
end
136+
137+
"""
    get_seasonal_stdev(output_var)

Return the standard deviation of the downsampled seasonal averages of `output_var`.

The seasonal averages are downsampled by a factor of 3, concatenated along the third
dimension, and the standard deviation is taken along that dimension, which is then
dropped.
"""
function get_seasonal_stdev(output_var)
    coarse_seasons = [downsample(avg, 3) for avg in get_seasonal_averages(output_var)]
    stacked = cat(coarse_seasons...; dims = 3)
    # TODO: Add intraseasonal stdev?
    spread = std(stacked, dims = 3)
    return dropdims(spread; dims = 3)
end
146+
147+
"""
    get_seasonal_covariance(output_var)

Return a `Diagonal` covariance matrix whose entries are the squared values of
`get_seasonal_stdev(output_var)`, flattened to a vector.
"""
function get_seasonal_covariance(output_var)
    variances = vec(get_seasonal_stdev(output_var)) .^ 2
    return Diagonal(variances)
end
152+
153+
"""
    get_year_indices(year)

Return the index range of `year` within an array of seasons, assuming four seasons
per year with year 1 occupying indices `1:4`.
"""
get_year_indices(year) = ((year * 4) - 3):(year * 4)
160+
161+
"""
    vectorize_nyears_of_seasonal_outputvars(vec_of_vars, year_range)

Return one flat vector holding the data of every seasonal OutputVar in `vec_of_vars`
that belongs to `year_range`, which may be a single year or a range of years.
"""
function vectorize_nyears_of_seasonal_outputvars(vec_of_vars, year_range)
    # Season indices of all requested years, in order.
    season_idxs = vcat(get_year_indices.(year_range)...)
    flattened = [vec(var.data) for var in vec_of_vars[season_idxs]]
    return vcat(flattened...)
end
169+
170+
"""
    make_single_year_of_seasonal_observations(output_var, yr)

Build an `EKP.Observation` from the downsampled seasonal averages of `output_var`
for the year(s) `yr` (a single year or a range of years).

The observation vector concatenates the flattened seasons of the selected years. The
covariance is the diagonal of `get_seasonal_covariance(output_var)` repeated once per
selected season, so it always matches the length of the observation vector. The
observation is named `"\$(yr)_<name>"`, where `<name>` is the `CF_name` attribute,
falling back to `long_name`, then to an empty string.
"""
function make_single_year_of_seasonal_observations(output_var, yr)
    seasonal_avgs = get_seasonal_averages(output_var)
    downsampled_seasonal_avg_arrays = downsample.(seasonal_avgs, 3)
    season_indices = vcat(get_year_indices.(yr)...)
    obs_vec = vcat(vec.(downsampled_seasonal_avg_arrays[season_indices])...)

    name = get(output_var.attributes, "CF_name", get(output_var.attributes, "long_name", ""))
    cov = get_seasonal_covariance(output_var)
    # One copy of the per-point variances per selected season (4 for a single year),
    # instead of a hard-coded 4 that only fits a single year.
    full_cov = Diagonal(repeat(cov.diag, length(season_indices)))
    return EKP.Observation(obs_vec, full_cov, "$(yr)_$name")
end
181+
182+
"""
183+
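    create_observation_vector(nt, yrs = 19)

Return a vector with one combined `EKP.Observation` per year in `1:yrs`, built from
the OutputVars in the NamedTuple `nt`. The result is a plain vector, not an
`EKP.ObservationSeries`.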
184+
185+
"""
186+
function create_observation_vector(nt, yrs = 19)
    # Starting year is 2000-12 to 2001-11
    t_start = Second(first_year_start_date - start_date).value
    from_first_year(var) = window(var, "time"; left = t_start)

    # Net radiation uses yearly averages of the lat/lon mean, so it is treated differently.
    net_rad = from_first_year(nt.net_rad) |> average_lat |> average_lon
    net_rad = get_yearly_averages(net_rad)
    # Fail early with a clear message instead of a BoundsError inside the loop below.
    yrs <= length(net_rad) || throw(
        ArgumentError("requested $yrs years but only $(length(net_rad)) complete years are available"),
    )
    net_rad_stdev = std(cat(net_rad..., dims = 3), dims = 3)
    net_rad_covariance = Diagonal(vec(net_rad_stdev) .^ 2)

    # Only the fields that end up in the combined observation are windowed; the other
    # entries of `nt` (rsutcs, rlutcs, shf, ta, hur, hus) are not used here.
    rsut = from_first_year(nt.rsut)
    rlut = from_first_year(nt.rlut)
    cre = from_first_year(nt.cre)
    pr = from_first_year(nt.pr)
    ts = from_first_year(nt.ts)

    all_observations = map(1:yrs) do yr
        net_rad_obs = EKP.Observation(vec(net_rad[yr]), net_rad_covariance, "$(yr)_net_rad")

        rsut_obs = make_single_year_of_seasonal_observations(rsut, yr)
        rlut_obs = make_single_year_of_seasonal_observations(rlut, yr)
        cre_obs = make_single_year_of_seasonal_observations(cre, yr)
        pr_obs = make_single_year_of_seasonal_observations(pr, yr)
        ts_obs = make_single_year_of_seasonal_observations(ts, yr)

        EKP.combine_observations([net_rad_obs, rsut_obs, rlut_obs, cre_obs, pr_obs, ts_obs])
    end

    return all_observations # NOT an EKP.ObservationSeries
end
230+
231+
# Downsample the data array of an OutputVar; returns a plain array, not an OutputVar.
function downsample(var::ClimaAnalysis.OutputVar, n)
    return downsample(var.data, n)
end
232+
233+
# Keep every `n`-th element along the first two dimensions of a 2D or 3D array.
# The third dimension of a 3D array is kept in full. Any other rank is an error.
function downsample(arr::AbstractArray, n)
    n < 1 && error("Downsampling factor n must be at least 1.")
    rank = ndims(arr)
    rank == 2 && return arr[1:n:end, 1:n:end]
    rank == 3 && return arr[1:n:end, 1:n:end, :]
    error("Only 2D and 3D arrays are supported.")
end
245+
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
using Distributed
2+
import ClimaCalibrate as CAL
3+
using ClimaCalibrate
4+
using ClimaAnalysis
5+
import ClimaAnalysis: SimDir, get, slice, average_xy
6+
import ClimaComms
7+
import ClimaCoupler
8+
using EnsembleKalmanProcesses.ParameterDistributions
9+
import EnsembleKalmanProcesses as EKP
10+
11+
include(joinpath(pkgdir(ClimaCoupler), "experiments/calibration/coarse_amip/observation_map.jl"))
12+
13+
ENV["JULIA_WORKER_TIMEOUT"] = "300.0"
14+
# addprocs(CAL.SlurmManager())
15+
# Make variables and the forward model available on the worker sessions
16+
@everywhere begin
17+
import ClimaCoupler
18+
experiment_dir = joinpath(pkgdir(ClimaCoupler), "experiments/calibration/coarse_amip/")
19+
include(joinpath(experiment_dir, "model_interface.jl"))
20+
end
21+
22+
# Experiment Configuration
23+
output_dir = "experiments/calibration/output"
24+
ensemble_size = 20
25+
n_iterations = 18 # Cycle through all data
26+
priors = [constrained_gaussian("liquid_cloud_effective_radius", 14e-6, 6e-6, 2.5e-6, 21.5e-6)]
27+
prior = combine_distributions(priors)
28+
observation_vec = JLD2.load_object(joinpath(experiment_dir, "observations.jld2"))
29+
30+
batch_size = 1
31+
num_batches = cld(length(observation_vec), batch_size)
32+
batches = map(1:num_batches) do i
33+
start_idx = (i - 1) * batch_size + 1
34+
end_idx = min(i * batch_size, length(observation_vec))
35+
start_idx:end_idx
36+
end
37+
minibatcher = EKP.FixedMinibatcher(batches)
38+
39+
series_names = string.(1:length(observation_vec))
40+
observation_series = EKP.ObservationSeries(observation_vec, minibatcher, series_names)
41+
42+
eki = EKP.EnsembleKalmanProcess(
43+
EKP.construct_initial_ensemble(prior, ensemble_size),
44+
observation_series,
45+
EKP.TransformInversion(),
46+
)
47+
48+
eki = CAL.calibrate(CAL.WorkerBackend, eki, ensemble_size, n_iterations, prior, output_dir)
49+
50+
# Postprocessing
51+
import Statistics: var, mean
52+
using Test
53+
import CairoMakie
54+
55+
"""
    scatter_plot(eki::EKP.EnsembleKalmanProcess)

Scatter the model outputs stored in `eki` against the constrained parameter values
for each iteration, mark the observations with dashed vertical lines, save the
figure as `scatter.png` in `output_dir`, and return the path of the saved file.
"""
function scatter_plot(eki::EKP.EnsembleKalmanProcess)
    f = CairoMakie.Figure(resolution = (800, 600))
    ax = CairoMakie.Axis(f[1, 1], ylabel = "Parameter Value", xlabel = "Top of atmosphere radiative SW flux")

    g = vec.(EKP.get_g(eki; return_array = true))
    params = vec.((EKP.get_ϕ(prior, eki)))

    # NOTE(review): `scatter!` needs `gg` and `uu` to have equal length; with a
    # multi-element output vector per member this may not hold — confirm before relying
    # on this plot.
    for (gg, uu) in zip(g, params)
        CairoMakie.scatter!(ax, gg, uu)
    end

    # Take the observations from the EKP object; no `observations` variable is defined
    # in this script. NOTE(review): presumably the current minibatch — confirm.
    CairoMakie.vlines!(ax, EKP.get_obs(eki), linestyle = :dash)

    output = joinpath(output_dir, "scatter.png")
    CairoMakie.save(output, f)
    return output
end
72+
73+
"""
    param_versus_iter_plot(eki::EKP.EnsembleKalmanProcess)

Plot the constrained parameter values of every ensemble member against the iteration
number, save the figure as `param_vs_iter.png` in `output_dir`, and return the path
of the saved file.
"""
function param_versus_iter_plot(eki::EKP.EnsembleKalmanProcess)
    fig = CairoMakie.Figure(resolution = (800, 600))
    axis = CairoMakie.Axis(fig[1, 1], ylabel = "Parameter Value", xlabel = "Iteration")

    # One column of points per iteration, all sharing the iteration number as x value.
    for (iter, ϕ) in enumerate(EKP.get_ϕ(prior, eki))
        xs = fill(iter, length(ϕ))
        CairoMakie.scatter!(axis, xs, vec(ϕ))
    end

    png_path = joinpath(output_dir, "param_vs_iter.png")
    CairoMakie.save(png_path, fig)
    return png_path
end
85+
86+
scatter_plot(eki)
87+
param_versus_iter_plot(eki)
88+
89+
params = EKP.get_ϕ(prior, eki)
90+
spread = map(var, params)
91+
92+
# Spread should be heavily decreased as particles have converged
93+
@test last(spread) / first(spread) < 0.1
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
#!/bin/bash
2+
3+
#SBATCH --partition=a3
4+
#SBATCH --output="run_calibration.txt"
5+
#SBATCH --time=24:00:00
6+
#SBATCH --ntasks=20
7+
#SBATCH --gpus-per-task=1
8+
#SBATCH --cpus-per-task=4
9+
10+
julia --project=experiments/calibration -e 'using Pkg; Pkg.develop(;path="."); Pkg.instantiate(;verbose=true)'
11+
12+
julia --project=experiments/calibration experiments/calibration/coarse_amip/run_calibration.jl
13+

0 commit comments

Comments
 (0)
Please sign in to comment.