~cgeoga/StandaloneKNITRO.jl

StandaloneKNITRO.jl/src/forwrapper.jl -rw-r--r-- 2.2 KiB
954584dd — Chris Geoga Small tweaks with hvp options. 1 year, 3 months ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
# Callable wrapper around a function `f` whose call methods evaluate the
# gradient of `f` with ReverseDiff, using compiled gradient tapes that are
# built on first use and then cached.
struct GradWithTape{F} <: Function
  # The function being differentiated.
  f::F
  # Cache of compiled tapes. The call methods key this on the type of the
  # input the tape was recorded with.
  tapes::Dict{Type, ReverseDiff.CompiledTape}
end

# Bundles a function with the buffers and autodiff objects that the callable
# interface on this type uses to return its value, gradient, Hessian, or a
# Hessian-vector product.
struct FoRWrapper{T,F,C}
  # Length that every `arg` passed to the call interface must have.
  arglen::Int64 
  # Buffer the :gradient case writes into and returns. It is also handed to
  # ForwardDiff.jacobian! as the output buffer in the :hessian case.
  buf_g::Vector{T}
  # Buffer the :hessian case writes into and returns.
  buf_h::Matrix{T}
  # The wrapped function; called directly in the :value case.
  fun::F
  # Tape-caching gradient object used by the :gradient, :hessian and :hvp cases.
  grad_withtape::GradWithTape{F}
  # Config object passed to ForwardDiff.jacobian! (a ForwardDiff.JacobianConfig
  # when the wrapper is built by the outer constructor).
  cfg::C
end

# Return the type parameter `F` of a `FoRWrapper`, i.e. the type of the value
# stored in its `fun` field.
function getF(fw::FoRWrapper{T,F,C}) where {T,F,C}
  return F
end

# Convenience constructor: wrap `f` together with an initially empty tape cache.
function GradWithTape(f)
  emptycache = Dict{Type,ReverseDiff.CompiledTape}()
  return GradWithTape(f, emptycache)
end

# In-place gradient: write the gradient of `gt.f` at `p` into `g` and return
# the result of `ReverseDiff.gradient!`.
#
# A compiled tape is recorded the first time an input of type `P` is seen and
# is reused for every later call with that type.
# NOTE(review): the cache key is the type of `p` only, so a tape recorded for
# one input is replayed for every other input of the same type -- confirm that
# callers always pass inputs of the same length.
function (gt::GradWithTape{F})(g, p::P) where {F,P}
  tape = get!(gt.tapes, P) do
    # Only reached on a cache miss: record and compile a tape for this input.
    ReverseDiff.compile(ReverseDiff.GradientTape(gt.f, p))
  end
  return ReverseDiff.gradient!(g, tape, p)
end

# Allocating gradient: return a freshly allocated vector holding the gradient
# of `gt.f` at `p`.
#
# Note that this _does_ allocate. Still need to think about the HVP without
# allocation. For now, it's definitely an edge case.
function (gt::GradWithTape{F})(p::P) where{F,P}
  out = zeros(eltype(P), length(p))
  # Delegate to the in-place method so the tape construction and caching logic
  # lives in exactly one place instead of being duplicated here.
  gt(out, p)
  return out
end

# Build a `FoRWrapper` for `f`, sizing all buffers from `sample_arg`.
#
# Arguments:
#   f          -- the function to wrap.
#   sample_arg -- a vector whose length and element type fix the sizes and
#                 element type of the internal buffers.
#   chunksize  -- chunk size given to ForwardDiff; defaults to the argument
#                 length capped at 16.
function FoRWrapper(f::F, sample_arg::Vector{T},
                    chunksize=min(length(sample_arg),16)) where {F,T}
  dim      = length(sample_arg)
  # Buffers are left uninitialized here; the call interface fills them.
  gradbuf  = Vector{T}(undef, dim)
  hessbuf  = Matrix{T}(undef, dim, dim)
  tapegrad = GradWithTape(f)
  # Pre-build the Jacobian config so the final calling interface allocates less.
  chunk    = ForwardDiff.Chunk{chunksize}()
  jaccfg   = ForwardDiff.JacobianConfig(tapegrad, gradbuf, sample_arg, chunk)
  return FoRWrapper(dim, gradbuf, hessbuf, f, tapegrad, jaccfg)
end

# Evaluate the wrapped function or one of its derivatives at `arg`.
#
# Arguments:
#   case -- one of :value, :gradient, :hessian, or :hvp.
#   arg  -- point of evaluation; must have length `M.arglen`.
#   v    -- direction vector, required only for :hvp.
#
# Returns:
#   :value    -- `M.fun(arg)`.
#   :gradient -- `M.buf_g`, filled in place. This is the internal buffer, not a
#                copy, so later :gradient or :hessian calls overwrite it.
#   :hessian  -- `M.buf_h`, filled in place as the Jacobian of the gradient.
#                Also an internal buffer; `M.buf_g` is overwritten as well.
#   :hvp      -- a newly allocated vector: the derivative at t = 0 of the
#                gradient evaluated at `arg + t*v`.
#
# Throws an ErrorException for an unknown `case`, for an `arg` of the wrong
# length, or when :hvp is requested without `v`.
function (M::FoRWrapper{T,F,C})(case::Symbol, arg::Vector{T}, v=nothing) where{T,F,C}
  # `error` already throws, so no surrounding `throw` is needed.
  in(case, (:value, :gradient, :hvp, :hessian)) ||
    error("Case options are :value, :gradient, :hvp, or :hessian.")
  length(arg) == M.arglen ||
    error("Provided argument is not the correct length.")
  if case == :value
    return M.fun(arg)
  elseif case == :gradient
    M.grad_withtape(M.buf_g, arg)
    return M.buf_g
  elseif case == :hessian
    ForwardDiff.jacobian!(M.buf_h, M.grad_withtape, M.buf_g, arg, M.cfg)
    return M.buf_h
  else
    # Only :hvp can reach this branch because `case` was validated above.
    # Fail with a clear message instead of an opaque MethodError from `t*v`.
    v === nothing && error("The :hvp case requires a direction vector v.")
    # Seed with zero(T) rather than 0.0 so the result has the same element
    # type as the :gradient and :hessian outputs.
    return ForwardDiff.derivative(t->M.grad_withtape(arg + t*v), zero(T))
  end
end