Source code for torchium.optimizers.second_order.lbfgs

import torch
from torch.optim.optimizer import Optimizer


class LBFGS(Optimizer):
    """Limited-memory Broyden-Fletcher-Goldfarb-Shanno optimizer"""
    def __init__(
        self,
        params,
        lr=1,
        max_iter=20,
        max_eval=None,
        tolerance_grad=1e-7,
        tolerance_change=1e-9,
        history_size=100,
        line_search_fn=None,
    ):
        if max_eval is None:
            max_eval = max_iter * 5 // 4
        defaults = dict(
            lr=lr,
            max_iter=max_iter,
            max_eval=max_eval,
            tolerance_grad=tolerance_grad,
            tolerance_change=tolerance_change,
            history_size=history_size,
            line_search_fn=line_search_fn,
        )
        super().__init__(params, defaults)

        if len(self.param_groups) != 1:
            raise ValueError("LBFGS doesn't support per-parameter options")

        self._params = self.param_groups[0]["params"]
        self._numel_cache = None
    def _numel(self):
        if self._numel_cache is None:
            self._numel_cache = sum(p.numel() for p in self._params)
        return self._numel_cache

    def _gather_flat_grad(self):
        views = []
        for p in self._params:
            if p.grad is None:
                view = p.new(p.numel()).zero_()
            elif p.grad.is_sparse:
                view = p.grad.to_dense().view(-1)
            else:
                view = p.grad.view(-1)
            views.append(view)
        return torch.cat(views, 0)

    def _add_grad(self, step_size, update):
        offset = 0
        for p in self._params:
            numel = p.numel()
            p.add_(update[offset : offset + numel].view_as(p), alpha=step_size)
            offset += numel
    # parameters are updated in-place, so the step itself must run under no_grad
    @torch.no_grad()
    def step(self, closure):
        """Perform a single optimization step"""
        assert len(self.param_groups) == 1

        # the closure re-builds the graph, so it must run with grad enabled
        closure = torch.enable_grad()(closure)

        group = self.param_groups[0]
        lr = group["lr"]
        max_iter = group["max_iter"]
        max_eval = group["max_eval"]
        tolerance_grad = group["tolerance_grad"]
        tolerance_change = group["tolerance_change"]
        line_search_fn = group["line_search_fn"]
        history_size = group["history_size"]

        state = self.state[self._params[0]]

        # initialize state
        if len(state) == 0:
            state["func_evals"] = 0
            state["n_iter"] = 0

        # evaluate initial f(x) and df/dx
        orig_loss = closure()
        loss = float(orig_loss)
        current_evals = 1
        state["func_evals"] += 1

        flat_grad = self._gather_flat_grad()
        opt_cond = flat_grad.abs().max() <= tolerance_grad

        # check optimality condition
        if opt_cond:
            return orig_loss

        # tensors cached in state (for tracing)
        d = state.get("d")
        t = state.get("t")
        old_dirs = state.get("old_dirs")
        old_stps = state.get("old_stps")
        ro = state.get("ro")
        H_diag = state.get("H_diag")
        prev_flat_grad = state.get("prev_flat_grad")
        prev_loss = state.get("prev_loss")

        n_iter = 0
        # optimize for a max of max_iter iterations
        while n_iter < max_iter:
            # keep track of nb of iterations
            n_iter += 1
            state["n_iter"] += 1

            ############################################################
            # compute gradient descent direction
            ############################################################
            if state["n_iter"] == 1:
                d = flat_grad.neg()
                old_dirs = []
                old_stps = []
                ro = []
                H_diag = 1
            else:
                # do lbfgs update (update memory)
                y = flat_grad.sub(prev_flat_grad)
                s = d.mul(t)
                ys = y.dot(s)  # y*s
                if ys > 1e-10:
                    # updating memory
                    if len(old_dirs) == history_size:
                        # shift history by one (limited-memory)
                        old_dirs.pop(0)
                        old_stps.pop(0)
                        ro.pop(0)

                    # store new direction/step
                    old_dirs.append(y)
                    old_stps.append(s)
                    ro.append(1.0 / ys)

                    # update scale of initial Hessian approximation
                    H_diag = ys / y.dot(y)  # (y*s)/(y*y)

                # compute the approximate (L-BFGS) inverse Hessian
                # multiplied by the gradient
                num_old = len(old_dirs)

                if "al" not in state:
                    state["al"] = [None] * history_size
                al = state["al"]

                # iteration in L-BFGS loop collapsed to use just one buffer
                q = flat_grad.neg()
                for i in range(num_old - 1, -1, -1):
                    al[i] = old_stps[i].dot(q) * ro[i]
                    q.add_(old_dirs[i], alpha=-al[i])

                # multiply by initial Hessian
                # r/d is the final direction
                d = r = torch.mul(q, H_diag)
                for i in range(num_old):
                    be_i = old_dirs[i].dot(r) * ro[i]
                    r.add_(old_stps[i], alpha=al[i] - be_i)

            if prev_flat_grad is None:
                prev_flat_grad = flat_grad.clone(memory_format=torch.contiguous_format)
            else:
                prev_flat_grad.copy_(flat_grad)
            prev_loss = loss

            ############################################################
            # compute step length
            ############################################################
            # reset initial guess for step size
            if state["n_iter"] == 1:
                t = min(1.0, 1.0 / flat_grad.abs().sum()) * lr
            else:
                t = lr

            # directional derivative
            gtd = flat_grad.dot(d)  # g * d

            # directional derivative is below tolerance
            if gtd > -tolerance_change:
                break

            # optional line search: user function
            ls_func_evals = 0
            if line_search_fn is not None:
                # perform line search, using user function
                if line_search_fn != "strong_wolfe":
                    raise RuntimeError("only 'strong_wolfe' is supported")
                else:
                    x_init = self._clone_param()

                    def obj_func(x, t, d):
                        return self._directional_evaluate(closure, x, t, d)

                    loss, flat_grad, t, ls_func_evals = _strong_wolfe(
                        obj_func, x_init, t, d, loss, flat_grad, gtd
                    )
                self._add_grad(t, d)
                opt_cond = flat_grad.abs().max() <= tolerance_grad
            else:
                # no line search, simply move with fixed-step
                self._add_grad(t, d)
                if n_iter != max_iter:
                    # re-evaluate function only if not in last iteration
                    # the reason we do this: in a stochastic setting,
                    # no use to re-evaluate that function here
                    with torch.enable_grad():
                        loss = float(closure())
                    flat_grad = self._gather_flat_grad()
                    opt_cond = flat_grad.abs().max() <= tolerance_grad
                    ls_func_evals = 1

            # update func eval
            current_evals += ls_func_evals
            state["func_evals"] += ls_func_evals

            ############################################################
            # check conditions
            ############################################################
            if n_iter == max_iter:
                break

            if current_evals >= max_eval:
                break

            # optimal condition
            if opt_cond:
                break

            # lack of progress
            if d.mul(t).abs().max() <= tolerance_change:
                break

            if abs(loss - prev_loss) < tolerance_change:
                break

        state["d"] = d
        state["t"] = t
        state["old_dirs"] = old_dirs
        state["old_stps"] = old_stps
        state["ro"] = ro
        state["H_diag"] = H_diag
        state["prev_flat_grad"] = prev_flat_grad
        state["prev_loss"] = prev_loss

        return orig_loss
    def _clone_param(self):
        return [p.clone(memory_format=torch.contiguous_format) for p in self._params]

    def _set_param(self, params_data):
        for p, pdata in zip(self._params, params_data):
            p.copy_(pdata)

    def _directional_evaluate(self, closure, x, t, d):
        # step to x + t*d, evaluate loss and gradient there, then restore x
        self._add_grad(t, d)
        with torch.enable_grad():
            loss = float(closure())
        flat_grad = self._gather_flat_grad()
        self._set_param(x)
        return loss, flat_grad
def _strong_wolfe(obj_func, x, t, d, f, g, gtd, c1=1e-4, c2=0.9, tolerance_change=1e-9, max_ls=25):
    """Simplified strong Wolfe line search.

    A full implementation would bracket and zoom to find a step satisfying
    both Wolfe conditions; this version evaluates a single trial step and
    returns it.
    """
    # evaluate objective and gradient at the initial step length
    f_new, g_new = obj_func(x, t, d)
    ls_func_evals = 1
    gtd_new = g_new.dot(d)

    # first Wolfe (sufficient decrease) condition violated, or eval budget spent:
    # a full line search would backtrack here; the simplified version returns
    if f_new > (f + c1 * t * gtd) or ls_func_evals >= max_ls:
        return f_new, g_new, t, ls_func_evals

    # second Wolfe (curvature) condition satisfied: accept the step
    if abs(gtd_new) <= -c2 * gtd:
        return f_new, g_new, t, ls_func_evals

    # if we get here, return the current point
    return f_new, g_new, t, ls_func_evals
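
Because this optimizer re-evaluates the objective several times per step, `step` must be passed a closure that zeroes the gradients, recomputes the loss, and calls backward(). Below is a minimal usage sketch, assuming the module path shown in this page's title; the toy least-squares objective is purely illustrative and not part of torchium.

import torch
from torchium.optimizers.second_order.lbfgs import LBFGS

# toy problem: fit w to minimize mean squared error of X @ w against y
X = torch.randn(64, 8)
y = torch.randn(64)
w = torch.zeros(8, requires_grad=True)

optimizer = LBFGS([w], lr=1.0, history_size=10, line_search_fn="strong_wolfe")

def closure():
    # called multiple times per step, so it must rebuild the graph each time
    optimizer.zero_grad()
    loss = (X @ w - y).pow(2).mean()
    loss.backward()
    return loss

for _ in range(5):
    loss = optimizer.step(closure)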