@@ -56,7 +56,7 @@ create_model(expt_config, args) = get_model(expt_config; device=gpu, warmup=true
 
 function get_loss_function(args)
     if args["model-type"] == "VANILLA"
-        function loss_function_closure_vanilla(x, y, model, ps, st)
+        function loss_function_closure_vanilla(x, y, model, ps, st, w_skip=args["w-skip"])
             (ŷ, soln), st_ = model(x, ps, st)
             celoss = logitcrossentropy(ŷ, y)
             skiploss = FastDEQExperiments.mae(soln.u₀, soln.z_star)
@@ -65,11 +65,11 @@ function get_loss_function(args)
         end
         return loss_function_closure_vanilla
     else
-        function loss_function_closure_skip(x, y, model, ps, st)
+        function loss_function_closure_skip(x, y, model, ps, st, w_skip=args["w-skip"])
             (ŷ, soln), st_ = model(x, ps, st)
             celoss = logitcrossentropy(ŷ, y)
             skiploss = FastDEQExperiments.mae(soln.u₀, soln.z_star)
-            loss = celoss + args["w-skip"] * skiploss
+            loss = celoss + w_skip * skiploss
             return loss, st_, (ŷ, soln.nfe, celoss, skiploss, soln.residual)
         end
         return loss_function_closure_skip
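A minimal usage sketch of the closures above (not part of the diff), assuming x, y, model, ps, and st are already set up as elsewhere in this script; the extra trailing argument overrides the default args["w-skip"] weight:

loss_fn = get_loss_function(args)                       # returns one of the two closures above
loss, st_, aux = loss_fn(x, y, model, ps, st)           # uses the default weight, args["w-skip"]
loss, st_, aux = loss_fn(x, y, model, ps, st, 0.5f0)    # explicit (e.g. scheduled) weight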
@@ -185,7 +185,7 @@ function validate(val_loader, model, ps, st, loss_function, args)
 end
 
 # Training
-function train_one_epoch(train_loader, model, ps, st, optimiser_state, epoch, loss_function, args)
+function train_one_epoch(train_loader, model, ps, st, optimiser_state, epoch, loss_function, w_skip, args)
     batch_time = AverageMeter("Batch Time", "6.3f")
     data_time = AverageMeter("Data Time", "6.3f")
     forward_pass_time = AverageMeter("Forward Pass Time", "6.3f")
@@ -212,7 +212,7 @@ function train_one_epoch(train_loader, model, ps, st, optimiser_state, epoch, lo
         # Gradients and Update
         _t = time()
         (loss, st, (ŷ, nfe_, celoss, skiploss, resi)), back = Zygote.pullback(
-            p -> loss_function(x, y, model, p, st), ps
+            p -> loss_function(x, y, model, p, st, w_skip), ps
         )
         forward_pass_time(time() - _t, B)
         _t = time()
@@ -353,10 +353,12 @@ function main(args)
 
     st = hasproperty(expt_config, :pretrain_epochs) && getproperty(expt_config, :pretrain_epochs) > 0 ? Lux.update_state(st, :fixed_depth, Val(getproperty(expt_config, :num_layers))) : st
 
+    wskip_sched = ParameterSchedulers.Exp(args["w-skip"], 0.92f0)
+
     for epoch in args["start-epoch"]:(expt_config.nepochs)
         # Train for 1 epoch
         ps, st, optimiser_state, train_stats = train_one_epoch(
-            train_loader, model, ps, st, optimiser_state, epoch, loss_function, args
+            train_loader, model, ps, st, optimiser_state, epoch, loss_function, wskip_sched(epoch), args
         )
         train_stats = get_loggable_stats(train_stats)