import torch
import torch.nn as nn
from torch.optim.optimizer import Optimizer
from typing import Any, Dict, List, Optional, Tuple
import math
from collections import defaultdict, deque
# [docs]  (leftover Sphinx "view source" link marker from HTML extraction; not code)
class LBFGS(Optimizer):
    """Limited-memory Broyden-Fletcher-Goldfarb-Shanno optimizer.

    All parameters are treated as one flat vector, so only a single parameter
    group is supported. ``step`` must be given a closure that re-evaluates the
    model, calls ``backward()`` and returns the loss, because the algorithm
    may evaluate the objective several times per step.

    Args:
        params: iterable of parameters to optimize (a single group).
        lr: step length used when no line search is performed, and the
            initial trial step when a line search is used.
        max_iter: maximum number of inner iterations per ``step`` call.
        max_eval: maximum number of closure evaluations per ``step`` call;
            defaults to ``max_iter * 5 // 4``.
        tolerance_grad: stop when the largest absolute gradient entry is at
            or below this value.
        tolerance_change: stop when the step or the change in loss falls
            below this value.
        history_size: number of (y, s) correction pairs kept in memory.
        line_search_fn: ``None`` (fixed step) or ``"strong_wolfe"``.

    Raises:
        ValueError: if more than one parameter group is supplied.
    """

    def __init__(
        self,
        params,
        lr=1,
        max_iter=20,
        max_eval=None,
        tolerance_grad=1e-7,
        tolerance_change=1e-9,
        history_size=100,
        line_search_fn=None,
    ):
        if max_eval is None:
            max_eval = max_iter * 5 // 4
        defaults = dict(
            lr=lr,
            max_iter=max_iter,
            max_eval=max_eval,
            tolerance_grad=tolerance_grad,
            tolerance_change=tolerance_change,
            history_size=history_size,
            line_search_fn=line_search_fn,
        )
        super().__init__(params, defaults)
        if len(self.param_groups) != 1:
            raise ValueError("LBFGS doesn't support per-parameter options")
        self._params = self.param_groups[0]["params"]
        self._numel_cache = None

    def _numel(self):
        """Return the total number of scalar entries across all parameters (cached)."""
        if self._numel_cache is None:
            self._numel_cache = sum(p.numel() for p in self._params)
        return self._numel_cache

    def _gather_flat_grad(self):
        """Concatenate every parameter's gradient into one 1-D tensor.

        Parameters without a gradient contribute zeros; sparse gradients are
        densified first.
        """
        views = []
        for p in self._params:
            if p.grad is None:
                view = p.new(p.numel()).zero_()
            elif p.grad.is_sparse:
                view = p.grad.to_dense().reshape(-1)
            else:
                # reshape (rather than view) also accepts non-contiguous grads
                view = p.grad.reshape(-1)
            views.append(view)
        return torch.cat(views, 0)

    def _add_grad(self, step_size, update):
        """Apply ``p += step_size * update_slice`` in place to every parameter.

        ``update`` is a flat tensor laid out in the same order as
        ``_gather_flat_grad``. Must run with autograd disabled when the
        parameters require grad (``step`` guarantees this).
        """
        offset = 0
        for p in self._params:
            numel = p.numel()
            p.add_(update[offset : offset + numel].view_as(p), alpha=step_size)
            offset += numel

    @torch.no_grad()
    def step(self, closure):
        """Perform a single optimization step.

        Args:
            closure: callable that clears gradients, recomputes the loss,
                calls ``backward()`` and returns the loss.

        Returns:
            The loss returned by the first closure evaluation of this call.

        Raises:
            RuntimeError: if ``line_search_fn`` is neither ``None`` nor
                ``"strong_wolfe"``.
        """
        assert len(self.param_groups) == 1

        # The whole step runs under no_grad so that the in-place parameter
        # updates are legal on leaf tensors; the closure, however, must build
        # a graph, so it is always invoked with grad enabled.
        closure = torch.enable_grad()(closure)

        group = self.param_groups[0]
        lr = group["lr"]
        max_iter = group["max_iter"]
        max_eval = group["max_eval"]
        tolerance_grad = group["tolerance_grad"]
        tolerance_change = group["tolerance_change"]
        line_search_fn = group["line_search_fn"]
        history_size = group["history_size"]

        # All optimizer state lives under the first parameter's entry.
        state = self.state[self._params[0]]
        state.setdefault("func_evals", 0)
        state.setdefault("n_iter", 0)

        # evaluate initial f(x) and df/dx
        orig_loss = closure()
        loss = float(orig_loss)
        current_evals = 1
        state["func_evals"] += 1

        flat_grad = self._gather_flat_grad()
        opt_cond = flat_grad.abs().max() <= tolerance_grad

        # already optimal: nothing to do
        if opt_cond:
            return orig_loss

        # quantities carried over from previous step() calls
        d = state.get("d")
        t = state.get("t")
        old_dirs = state.get("old_dirs")
        old_stps = state.get("old_stps")
        ro = state.get("ro")
        H_diag = state.get("H_diag")
        prev_flat_grad = state.get("prev_flat_grad")
        prev_loss = state.get("prev_loss")

        n_iter = 0
        # optimize for a max of max_iter iterations
        while n_iter < max_iter:
            # n_iter counts this call; state["n_iter"] counts across calls
            n_iter += 1
            state["n_iter"] += 1

            ############################################################
            # compute gradient descent direction
            ############################################################
            if state["n_iter"] == 1:
                # very first iteration: plain steepest descent, empty memory
                d = flat_grad.neg()
                old_dirs = []
                old_stps = []
                ro = []
                H_diag = 1
            else:
                # do lbfgs update (update memory)
                y = flat_grad.sub(prev_flat_grad)
                s = d.mul(t)
                ys = y.dot(s)  # y*s
                # only store pairs with positive curvature
                if ys > 1e-10:
                    if len(old_dirs) == history_size:
                        # shift history by one (limited-memory)
                        old_dirs.pop(0)
                        old_stps.pop(0)
                        ro.pop(0)
                    # store new direction/step
                    old_dirs.append(y)
                    old_stps.append(s)
                    ro.append(1.0 / ys)
                    # update scale of initial Hessian approximation
                    H_diag = ys / y.dot(y)  # (y*s)/(y*y)

                # compute the approximate (L-BFGS) inverse Hessian
                # multiplied by the gradient (two-loop recursion)
                num_old = len(old_dirs)
                if "al" not in state:
                    state["al"] = [None] * history_size
                al = state["al"]

                # both loops reuse a single buffer
                q = flat_grad.neg()
                for i in range(num_old - 1, -1, -1):
                    al[i] = old_stps[i].dot(q) * ro[i]
                    q.add_(old_dirs[i], alpha=-al[i])

                # multiply by initial Hessian; r and d alias the same tensor,
                # so the in-place updates below produce the final direction
                d = r = torch.mul(q, H_diag)
                for i in range(num_old):
                    be_i = old_dirs[i].dot(r) * ro[i]
                    r.add_(old_stps[i], alpha=al[i] - be_i)

            if prev_flat_grad is None:
                prev_flat_grad = flat_grad.clone(memory_format=torch.contiguous_format)
            else:
                prev_flat_grad.copy_(flat_grad)
            prev_loss = loss

            ############################################################
            # compute step length
            ############################################################
            # reset initial guess for step size
            if state["n_iter"] == 1:
                # first ever step is scaled by the inverse L1 norm of the gradient
                t = min(1.0, 1.0 / flat_grad.abs().sum()) * lr
            else:
                t = lr

            # directional derivative
            gtd = flat_grad.dot(d)  # g * d

            # directional derivative is below tolerance
            if gtd > -tolerance_change:
                break

            ls_func_evals = 0
            if line_search_fn is not None:
                # perform line search, using user function
                if line_search_fn != "strong_wolfe":
                    raise RuntimeError("only 'strong_wolfe' is supported")
                else:
                    x_init = self._clone_param()

                    def obj_func(x, t, d):
                        return self._directional_evaluate(closure, x, t, d)

                    loss, flat_grad, t, ls_func_evals = _strong_wolfe(
                        obj_func, x_init, t, d, loss, flat_grad, gtd
                    )
                self._add_grad(t, d)
                opt_cond = flat_grad.abs().max() <= tolerance_grad
            else:
                # no line search, simply move with fixed-step
                self._add_grad(t, d)
                if n_iter != max_iter:
                    # re-evaluate function only if not in last iteration
                    # the reason we do this: in a stochastic setting,
                    # no use to re-evaluate that function here
                    loss = float(closure())
                    flat_grad = self._gather_flat_grad()
                    opt_cond = flat_grad.abs().max() <= tolerance_grad
                    ls_func_evals = 1

            # update func eval
            current_evals += ls_func_evals
            state["func_evals"] += ls_func_evals

            ############################################################
            # check conditions
            ############################################################
            if n_iter == max_iter:
                break

            if current_evals >= max_eval:
                break

            # optimal condition
            if opt_cond:
                break

            # lack of progress
            if d.mul(t).abs().max() <= tolerance_change:
                break

            if abs(loss - prev_loss) < tolerance_change:
                break

        # persist everything the next step() call needs
        state["d"] = d
        state["t"] = t
        state["old_dirs"] = old_dirs
        state["old_stps"] = old_stps
        state["ro"] = ro
        state["H_diag"] = H_diag
        state["prev_flat_grad"] = prev_flat_grad
        state["prev_loss"] = prev_loss

        return orig_loss

    def _clone_param(self):
        """Return contiguous copies of all parameters (used to restore them later)."""
        return [p.clone(memory_format=torch.contiguous_format) for p in self._params]

    def _set_param(self, params_data):
        """Overwrite every parameter in place with the matching tensor in ``params_data``."""
        for p, pdata in zip(self._params, params_data):
            p.copy_(pdata)

    def _directional_evaluate(self, closure, x, t, d):
        """Evaluate loss and flat gradient at ``x + t * d``, then restore ``x``.

        The parameters are moved, the closure is evaluated with grad enabled,
        and the parameters are reset to the saved copies in ``x``.
        """
        self._add_grad(t, d)
        with torch.enable_grad():
            loss = float(closure())
        flat_grad = self._gather_flat_grad()
        self._set_param(x)
        return loss, flat_grad
def _cubic_interpolate(x1, f1, g1, x2, f2, g2, bounds=None):
    """Return the minimizer of the cubic fitted to two (x, f, f') samples, clamped to bounds."""
    if bounds is not None:
        xmin_bound, xmax_bound = bounds
    else:
        xmin_bound, xmax_bound = (x1, x2) if x1 <= x2 else (x2, x1)

    # Closed-form minimizer of the cubic matching value and slope at both points.
    d1 = g1 + g2 - 3 * (f1 - f2) / (x1 - x2)
    d2_square = d1 ** 2 - g1 * g2
    if d2_square >= 0:
        # ** 0.5 works for both Python floats and 0-dim tensors
        d2 = d2_square ** 0.5
        if x1 <= x2:
            min_pos = x2 - (x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2 * d2))
        else:
            min_pos = x1 - (x1 - x2) * ((g1 + d2 - d1) / (g1 - g2 + 2 * d2))
        return min(max(min_pos, xmin_bound), xmax_bound)
    # no real minimizer: fall back to bisection
    return (xmin_bound + xmax_bound) / 2.0


def _strong_wolfe(obj_func, x, t, d, f, g, gtd, c1=1e-4, c2=0.9, tolerance_change=1e-9, max_ls=25):
    """Strong Wolfe line search along direction ``d``.

    Searches for a step length ``t`` satisfying both the sufficient-decrease
    condition ``f(t) <= f + c1 * t * gtd`` and the curvature condition
    ``|g(t).d| <= -c2 * gtd``. A bracketing phase first finds an interval
    containing such a point (or the point itself); a zoom phase then shrinks
    the interval using cubic interpolation.

    Args:
        obj_func: callable ``(x, t, d) -> (loss, flat_grad)`` evaluating the
            objective at ``x + t * d``.
        x: starting point, passed through to ``obj_func`` unchanged.
        t: initial trial step length.
        d: flat search-direction tensor.
        f: loss at the starting point.
        g: flat gradient at the starting point.
        gtd: directional derivative ``g . d`` at the starting point.
        c1: sufficient-decrease constant.
        c2: curvature constant.
        tolerance_change: the zoom phase stops once the bracket width times
            the largest ``|d|`` entry drops below this value.
        max_ls: maximum number of line-search iterations.

    Returns:
        Tuple ``(f_new, g_new, t, ls_func_evals)``: loss, flat gradient and
        step length of the selected point, and the number of ``obj_func``
        calls made.
    """
    d_norm = d.abs().max()
    g = g.clone(memory_format=torch.contiguous_format)

    # evaluate objective and gradient at the initial step length
    f_new, g_new = obj_func(x, t, d)
    ls_func_evals = 1
    gtd_new = g_new.dot(d)

    # ---- bracketing phase -------------------------------------------------
    t_prev, f_prev, g_prev, gtd_prev = 0, f, g, gtd
    done = False
    ls_iter = 0
    while ls_iter < max_ls:
        # sufficient decrease violated (or no longer decreasing): minimum is bracketed
        if f_new > (f + c1 * t * gtd) or (ls_iter > 1 and f_new >= f_prev):
            bracket = [t_prev, t]
            bracket_f = [f_prev, f_new]
            bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)]
            bracket_gtd = [gtd_prev, gtd_new]
            break

        # both Wolfe conditions hold: accept the point as is
        if abs(gtd_new) <= -c2 * gtd:
            bracket = [t]
            bracket_f = [f_new]
            bracket_g = [g_new]
            done = True
            break

        # slope turned non-negative: minimum lies between t_prev and t
        if gtd_new >= 0:
            bracket = [t_prev, t]
            bracket_f = [f_prev, f_new]
            bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)]
            bracket_gtd = [gtd_prev, gtd_new]
            break

        # still descending: extrapolate to a larger step
        min_step = t + 0.01 * (t - t_prev)
        max_step = t * 10
        tmp = t
        t = _cubic_interpolate(
            t_prev, f_prev, gtd_prev, t, f_new, gtd_new, bounds=(min_step, max_step)
        )

        t_prev = tmp
        f_prev = f_new
        g_prev = g_new.clone(memory_format=torch.contiguous_format)
        gtd_prev = gtd_new
        f_new, g_new = obj_func(x, t, d)
        ls_func_evals += 1
        gtd_new = g_new.dot(d)
        ls_iter += 1

    # iteration budget exhausted without a bracket: fall back to [0, t]
    if ls_iter == max_ls:
        bracket = [0, t]
        bracket_f = [f, f_new]
        bracket_g = [g, g_new]
        bracket_gtd = [gtd, gtd_new]

    # ---- zoom phase -------------------------------------------------------
    insuf_progress = False
    # indices of the lower-loss and higher-loss ends of the bracket
    low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[-1] else (1, 0)
    while not done and ls_iter < max_ls:
        # bracket has become too small to matter
        if abs(bracket[1] - bracket[0]) * d_norm < tolerance_change:
            break

        # new trial value from cubic interpolation inside the bracket
        t = _cubic_interpolate(
            bracket[0], bracket_f[0], bracket_gtd[0],
            bracket[1], bracket_f[1], bracket_gtd[1],
        )

        # If the trial point hugs a boundary twice in a row (or sits on it),
        # move it 10% of the bracket width away from the nearest boundary so
        # the interval keeps shrinking.
        b_max = max(bracket)
        b_min = min(bracket)
        eps = 0.1 * (b_max - b_min)
        if min(b_max - t, t - b_min) < eps:
            if insuf_progress or t >= b_max or t <= b_min:
                if abs(t - b_max) < abs(t - b_min):
                    t = b_max - eps
                else:
                    t = b_min + eps
                insuf_progress = False
            else:
                insuf_progress = True
        else:
            insuf_progress = False

        # evaluate the new point
        f_new, g_new = obj_func(x, t, d)
        ls_func_evals += 1
        gtd_new = g_new.dot(d)
        ls_iter += 1

        if f_new > (f + c1 * t * gtd) or f_new >= bracket_f[low_pos]:
            # sufficient decrease fails or not better than best: new high end
            bracket[high_pos] = t
            bracket_f[high_pos] = f_new
            bracket_g[high_pos] = g_new.clone(memory_format=torch.contiguous_format)
            bracket_gtd[high_pos] = gtd_new
            low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[1] else (1, 0)
        else:
            if abs(gtd_new) <= -c2 * gtd:
                # Wolfe conditions satisfied
                done = True
            elif gtd_new * (bracket[high_pos] - bracket[low_pos]) >= 0:
                # old low becomes the new high end
                bracket[high_pos] = bracket[low_pos]
                bracket_f[high_pos] = bracket_f[low_pos]
                bracket_g[high_pos] = bracket_g[low_pos]
                bracket_gtd[high_pos] = bracket_gtd[low_pos]

            # new point becomes the new low end
            bracket[low_pos] = t
            bracket_f[low_pos] = f_new
            bracket_g[low_pos] = g_new.clone(memory_format=torch.contiguous_format)
            bracket_gtd[low_pos] = gtd_new

    # return the best point found
    t = bracket[low_pos]
    f_new = bracket_f[low_pos]
    g_new = bracket_g[low_pos]
    return f_new, g_new, t, ls_func_evals