DiLoCo


It is difficult to co-locate and tightly synchronise a large number of accelerators. DiLoCo therefore lets each of k workers run H inner optimizer steps on its own data shard and synchronises them only once per outer step:
for outer_step in range(1, T + 1):
	for worker in range(k):  # runs in parallel, one replica per worker
		local_params[worker] = params[outer_step - 1]  # start from the shared parameters
		for inner_step in range(H):
			loss = model(shard[worker], local_params[worker])
			grad = backward(loss)
			local_params[worker] = inner_opt.step(local_params[worker], grad)  # e.g. AdamW

	# outer gradient: average over workers of how far each drifted from the shared parameters
	delta_param = avg(params[outer_step - 1] - local_params[worker] for worker in range(k))
	params[outer_step] = outer_opt.step(params[outer_step - 1], delta_param)  # e.g. Nesterov momentum SGD
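For concreteness, here is a rough, runnable sketch of the same loop in PyTorch. The optimizer split (AdamW inside, Nesterov momentum SGD outside) follows the DiLoCo setup, but make_model, worker_loaders and the hyperparameter values are placeholders, and the workers are simulated sequentially rather than run in parallel.

import copy
import torch

def diloco(make_model, worker_loaders, T=10, H=50, inner_lr=1e-3, outer_lr=0.7):
    global_model = make_model()
    # outer optimizer: Nesterov momentum SGD applied to the averaged parameter delta
    outer_opt = torch.optim.SGD(global_model.parameters(), lr=outer_lr,
                                momentum=0.9, nesterov=True)
    for outer_step in range(T):
        deltas = [torch.zeros_like(p) for p in global_model.parameters()]
        for loader in worker_loaders:  # conceptually in parallel, one worker per loader
            local_model = copy.deepcopy(global_model)
            inner_opt = torch.optim.AdamW(local_model.parameters(), lr=inner_lr)
            batches = iter(loader)
            for _ in range(H):  # H local steps with no communication
                x, y = next(batches)
                loss = torch.nn.functional.cross_entropy(local_model(x), y)
                inner_opt.zero_grad()
                loss.backward()
                inner_opt.step()
            # accumulate this worker's share of the outer gradient
            with torch.no_grad():
                for d, p_g, p_l in zip(deltas, global_model.parameters(),
                                       local_model.parameters()):
                    d += (p_g - p_l) / len(worker_loaders)
        # treat the averaged delta as a pseudo-gradient and take one outer step
        for p, d in zip(global_model.parameters(), deltas):
            p.grad = d
        outer_opt.step()
        outer_opt.zero_grad()
    return global_model

Only delta_param crosses the network, once every H inner steps, which is what makes it feasible to use workers that are not co-located or tightly synchronised.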
Since we decay the inner learning rate, and the outer (pseudo-)gradient naturally gets smaller over the course of training, we don't need to decay the outer learning rate.
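As an illustrative sketch of that split (the schedule shape and all numbers are placeholders, not values from the paper): the inner learning rate warms up and decays over all T * H inner steps, while the outer learning rate is simply held constant for the whole run.

import math

def inner_lr(step, total_inner_steps, peak_lr=4e-4, warmup_steps=1000):
    # linear warmup followed by cosine decay to zero across all inner steps
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_inner_steps - warmup_steps)
    return 0.5 * peak_lr * (1.0 + math.cos(math.pi * progress))

OUTER_LR = 0.7  # constant; no decay schedule on the outer optimizer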