Issue
This content is from Stack Overflow. Question asked by Yrogirg.
I can't find reference documentation for the gradient function of Julia Flux; there are only several tutorial examples.
I understand how gradient is used to compute gradients of functions; for example, the syntax
f(x, y) = x^2 + y^2
df(x, y) = gradient(f, x, y)
will essentially yield df(x, y) = (2x, 2y). However, later the tutorial uses the following syntax without any explanation:
gs = gradient(() -> loss(x, y), Flux.params(W, b))
I think there are ways to interpret a gradient of () -> loss(x, y) from a mathematical point of view, but I am not sure that this is what is going on here. So what is this anonymous function, and why was gradient designed that way? A link to the full documentation of gradient would be appreciated.
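For reference, the explicit call described in the question can be checked with a minimal sketch like the one below. It assumes Zygote.jl is installed (the solution links to Zygote's documentation for gradient, so this calls Zygote's gradient directly); the evaluation point (3.0, 4.0) is arbitrary. In explicit style, gradient returns a tuple with one entry per argument.

```julia
using Zygote  # assumes Zygote.jl is installed

# The function from the question.
f(x, y) = x^2 + y^2

# Explicit style: differentiate with respect to the arguments that are passed in.
# The result is a tuple (∂f/∂x, ∂f/∂y), i.e. (2x, 2y).
df(x, y) = gradient(f, x, y)

println(df(3.0, 4.0))  # (6.0, 8.0)
```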
Solution
First, here is the link to the gradient docs:
https://fluxml.ai/Zygote.jl/dev/#Zygote.gradient
gradient can also return derivatives with respect to variables that are referenced inside a function but not passed to it as arguments. This is the case in the example in the question, and the manual calls it implicit style. The parameters to differentiate with respect to are passed as a value of type Params; in the example, that Params value is created by the helper function Flux.params. When using implicit style, the function is passed as a zero-argument function (see the manual). The result is a Grads object, which is indexed by the parameters themselves, e.g. gs[W].
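A minimal sketch of implicit style, assuming Zygote.jl is installed. The model (W, b, x, y and a squared-error loss) is hypothetical, chosen only so the result can be checked by hand; Zygote.Params stands in for the Flux.params call from the question. Note that recent Flux releases steer users toward the explicit style, so newer tutorials may not use this form.

```julia
using Zygote  # assumes Zygote.jl is installed

# Hypothetical parameters and data, for illustration only.
W = [1.0 2.0; 3.0 4.0]
b = [0.5, -0.5]
x = [1.0, 1.0]
y = [0.0, 0.0]

# loss reads W and b from the enclosing (global) scope; they are not arguments.
loss(x, y) = sum((W * x .+ b .- y) .^ 2)

# Implicit style: a zero-argument anonymous function plus the parameters to track.
gs = gradient(() -> loss(x, y), Zygote.Params([W, b]))

# The residual is r = W*x + b - y = [3.5, 6.5], so dloss/db = 2r and dloss/dW = 2r * x'.
println(gs[b])  # [7.0, 13.0]
println(gs[W])  # [7.0 7.0; 13.0 13.0]
```

Because the closure takes no arguments, gradient has nothing to differentiate "by position"; the Params collection is what tells it which arrays to track.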
Regarding the syntax itself (irrespective of its use with gradient):
() -> loss(x, y)
is an anonymous function (sometimes called a lambda function in functional programming contexts). It works like a regular function but has no name, which makes it convenient for one-off use. This one takes no arguments; x and y are simply read from the surrounding scope.
The Julia manual section on anonymous functions is at https://docs.julialang.org/en/v1/manual/functions/#man-anonymous-functions
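To see why a zero-argument function is enough, it helps to watch an anonymous function read variables from its surrounding scope instead of taking them as arguments. This small sketch needs only Base Julia; the names a, g, make_adder and add5 are hypothetical.

```julia
# A hypothetical global variable.
a = 10

# Zero-argument anonymous function: it takes nothing, but its body reads `a`.
g = () -> a + 1
println(g())  # 11

# Globals used in the body are read each time the function runs.
a = 20
println(g())  # 21

# An anonymous function can also capture a local variable of another function.
make_adder(n) = x -> x + n
add5 = make_adder(5)
println(add5(3))  # 8
```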
Here are some more examples:
With 1 parameter: map(x -> x^2 + 2x - 1, [1, 3, -1])
With 2 parameters: (x,y) -> 2*x + y
With no parameters: () -> time()
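The examples above can be run as-is in a Julia REPL. This sketch binds the two- and zero-argument functions to hypothetical names (h, t) so they can be called, and needs only Base Julia.

```julia
# With one parameter: apply the anonymous function to each element via map.
ys = map(x -> x^2 + 2x - 1, [1, 3, -1])
println(ys)  # [2, 14, -2]

# With two parameters: bind the anonymous function to a name, then call it.
h = (x, y) -> 2*x + y
println(h(3, 4))  # 10

# With no parameters: call it with empty parentheses.
t = () -> time()
println(t() > 0)  # true
```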
This question was asked on Stack Overflow by Yrogirg and answered by Dan Getz. It is licensed under the terms of CC BY-SA 2.5 / CC BY-SA 3.0 / CC BY-SA 4.0.