jax automatic differentiation - python

I have the following three functions implemented in JAX.
def helper_1(params1):
    ...calculations...
    return z
def helper_2(z, params2):
    ...calculations...
    return y
def main(params1, params2):
    z = helper_1(params1)
    y = helper_2(z, params2)
    return z, y
I am interested in the partial derivatives of the output from main, i.e. z and y, with respect to both params1 and params2. As params1 and params2 are low dimensional and z and y are high dimensional, I am using the jax.jacfwd function.
When calling
jax.jacfwd(main,argnums=(0,1))(params1,params2)
JAX computes the derivatives of z with respect to params1 (and params2, which in this case is just a bunch of zeros). My question is: does JAX recompute dz/dparams1 when computing the derivatives of y with respect to params1 and params2, or does it somehow figure out that this has already been computed?
I don't know if this is relevant, but the 'helper_1' function contains functions from the TensorFlow library for Jax. Thanks!

In general, in the situation you describe JAX's forward-mode autodiff approach will re-use the derivative of z when computing the derivative of y. If you wish, you can confirm this by looking at the jaxpr of your differentiated function:
print(jax.make_jaxpr(jax.jacfwd(main, (0, 1)))(params1, params2))
Though if your function is more than moderately complicated, the output might be hard to understand.
As a general note, though, JAX's autodiff implementation does tend to produce a small number of unnecessary or duplicated computations. As a simple example consider this:
import jax
print(jax.make_jaxpr(jax.grad(jax.lax.sin))(1.0))
# { lambda ; a:f32[]. let
# _:f32[] = sin a
# b:f32[] = cos a
# c:f32[] = mul 1.0 b
# in (c,) }
Here the primal value sin(a) is computed even though it is never used in computing the final output.
In practice this can be addressed by wrapping your computation in jit, in which case the XLA compiler takes care of optimization, including dead code elimination and de-duplication when applicable:
result = jax.jit(jax.jacfwd(main, (0, 1)))(params1, params2)
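To make the structure of the result concrete, here is a minimal sketch with made-up helper functions. The bodies of helper_1 and helper_2 and the parameter values are assumptions for illustration only, since the question does not show the real calculations. It shows how the nested output of jacfwd with argnums=(0, 1) unpacks, and that dy/dparams1 is consistent with dz/dparams1 via the chain rule.

```python
import jax
import jax.numpy as jnp

def helper_1(params1):
    # Hypothetical stand-in: maps 2 parameters to a 5-dimensional z
    scales = jnp.arange(1.0, 6.0)
    return jnp.sin(scales[:, None] * params1).sum(axis=1)

def helper_2(z, params2):
    # Hypothetical stand-in: y depends on params1 only through z
    return params2[0] * z + params2[1]

def main(params1, params2):
    z = helper_1(params1)
    y = helper_2(z, params2)
    return z, y

params1 = jnp.array([0.3, -1.2])
params2 = jnp.array([2.0, 0.5])

# Outer tuple follows the outputs (z, y); inner tuple follows argnums (0, 1)
jac_fn = jax.jit(jax.jacfwd(main, argnums=(0, 1)))
(dz_dp1, dz_dp2), (dy_dp1, dy_dp2) = jac_fn(params1, params2)

print(dz_dp1.shape, dy_dp1.shape)
```

With these stand-ins, dz_dp2 is all zeros and dy_dp1 equals params2[0] * dz_dp1, which is what the chain rule predicts.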

Mixed partial derivative w.r.t. tensor in Pytorch

Question:
Is there any working method to calculate gradient of (non-scalar) tensor function?
Example
Given n by n symmetric matrices X, Y and matrix function Z(X, Y) = torch.mm(X.mm(X), Y) calculate d(dZ/dX)/dY.
Expected answer
d(dZ/dX)/dY = d(2*XY)/dY = 2*X
Attempts
Because torch's .backward() works only for scalar variables I've tried to calculate derivative by applying torch.autograd.grad() to each element of tensor Z, but this approach is not correct, because it gives d(X^2)/dX = X + 2*D where D is a diagonal matrix with diagonal values of X. For me it's a bit weird that torch has an ability to build a computational graph, but can't track tensor through it as a variable to get tensor derivative.
Edit
Question was not very clear, so I decided to give more details.
My aim is to get partial derivative of loss function, which involves two matrices as variables. It looks like that:
loss = torch.linalg.norm(my_formula(X, Y) , ord='fro')
And I need to find
1. d^2(loss)/d(Y^2)
2. d/dX[d(loss)/dY]
Torch is capable of calculating 1. by using .backward() two times, but it's problematic to find 2. because torch.autograd.grad() expects scalar input and not the tensor
TL;DR
For function f which takes a matrix and gives a scalar:
Find first order derivative, let's name it dX
Take trace: Tr(dX)
To get mixed partial derivative just use the trace from above: d/dY[df/dX] = d/dY[Tr(df/dX)]
Intro
At the time of posting the question I was not really that good at the theory of matrix derivatives, but now I know much more, all thanks to this Yandex ML book (unfortunately, I didn't find an English equivalent). This is an attempt to give a full answer to my question.
Basic Theory
Forgive me, Lord, for ugly representation of latex
Let's say you have a function which takes a matrix X and returns its squared Frobenius norm: f(X) = ||X||_F^2
It is a well-known fact that: ||X||_F^2 = Tr(X X^T)
Let's define the derivative as shown in the same book: D[f] at X_0, applied to an increment H, is the part of f(X_0 + H) - f(X_0) that is linear in H
We are ready to find df(X)/dX:
df(X)/dX = dTr(X X^T)/dX =
(using Trace's feature)
= Tr(d/dX[X X^T]) = Tr(dX/dX X^T + X d[X^T]/dX ) =
(then we should use the definition of derivative from above)
= Tr(HX^T + XH^T) = Tr(HX^T) + Tr(XH^T) =
(now the main trick is to get all matrices H on the right side and get something like
Tr(g(X) H) or Tr(g(X) H^T), where g(X) will be the derivative we are looking for)
= Tr(HX^T) + Tr(XH^T) = Tr(XH^T) + Tr(XH^T) = Tr(2*XH^T)
That means: df(X)/dX = 2X
Second order derivative
Now, after we found out how to get matrix derivatives, let's try to find second order derivative of the same function f(X):
d/dX[df(X)/dX] = d/dX[Tr(2XH_1^T)] = Tr(d/dX[2XH_1^T]) =
= Tr(2I H_2 H_1^T)
We found out that d/dX[df(X)/dX] = 2I where I stands for Identity matrix. But how will it help us to find derivatives in Pytorch?
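As a quick numerical sanity check of the two results above, here is a small sketch in JAX. The 2 by 2 matrix is an arbitrary example value, not something from the original question.

```python
import jax
import jax.numpy as jnp

def f(X):
    # Squared Frobenius norm, written as Tr(X X^T)
    return jnp.trace(X @ X.T)

X = jnp.array([[1.0, 2.0], [3.0, 4.0]])

# First-order derivative: expected to equal 2X
dX = jax.grad(f)(X)

# Second-order derivative: a (2, 2, 2, 2) tensor; flattened to (4, 4)
# it is expected to equal 2 times the identity
hess = jax.hessian(f)(X).reshape(4, 4)

print(dX)
print(hess)
```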
Trace is the trick
As we can see from the formulas, both first and second order derivatives have Trace inside them, but when we take first order derivative we just instantly get matrix as a result. To get a higher order derivative we just need to take the derivative of trace of first order derivative:
d/dY[df/dX] = d/dY[Tr(df/dX)]
The thing is I was using JAX autograd library when this trick came to my mind, so the code with a function f(X,Y) will look like this:
def scalarized_dy(X, Y):
    dY = grad(f, argnums=1)(X, Y)
    return jnp.trace(dY)
dYX = grad(scalarized_dy, argnums=0)(X, Y)
dYY = grad(scalarized_dy, argnums=1)(X, Y)
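For a self-contained version of the snippet above, here is a sketch with a concrete f. The choice f(X, Y) = ||XY||_F^2 and the matrix values are assumptions made only so the result can be checked by hand: df/dY = 2 X^T X Y, so Tr(df/dY) = 2 Tr(X^T X Y), whose gradient is 2 X^T X with respect to Y and 2 X (Y + Y^T) with respect to X.

```python
import jax.numpy as jnp
from jax import grad

def f(X, Y):
    # Hypothetical scalar loss of two matrices: ||X Y||_F^2
    return jnp.sum((X @ Y) ** 2)

def scalarized_dy(X, Y):
    dY = grad(f, argnums=1)(X, Y)  # df/dY = 2 X^T X Y
    return jnp.trace(dY)           # scalar, so grad can be applied again

X = jnp.array([[1.0, 2.0], [0.5, -1.0]])
Y = jnp.array([[0.0, 1.0], [2.0, 3.0]])

dYX = grad(scalarized_dy, argnums=0)(X, Y)  # gradient of Tr(df/dY) w.r.t. X
dYY = grad(scalarized_dy, argnums=1)(X, Y)  # gradient of Tr(df/dY) w.r.t. Y

print(dYX)
print(dYY)
```

Note that what comes out is the gradient of the trace, i.e. of the sum of the diagonal entries of df/dY. If the full four-index derivative is needed, something like jax.jacfwd applied to grad(f, argnums=1) would return it instead.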
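In the case of Pytorch I guess we need to request the gradients explicitly (let loss be a function with X and Y as arguments, both created with requires_grad=True). torch.autograd.grad is used here instead of .backward() so that the first-order gradient is not accumulated into X.grad together with the second-order one:
loss = f(X, Y)
# First-order gradient; create_graph=True keeps it differentiable
(dX,) = torch.autograd.grad(loss, X, create_graph=True)
# Differentiate the scalar Tr(dX) with respect to both inputs
dXX, dXY = torch.autograd.grad(torch.trace(dX), (X, Y))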
Epilogue
I thought that the question itself is in some way interesting. Also, it took me several months to figure things out, so I decided to give my current point of view on this problem. I will not mark my answer as correct yet in hope that I will get some kind of feedback or, perhaps, even better answers or ideas.

Sympy function derivatives and sets of equations

I'm working with nonlinear systems of equations. These systems are generally a nonlinear vector differential equation.
I now want to use functions and derive them with respect to time and to their time-derivatives, and find equilibrium points by solving the nonlinear equations 0=rhs(eqs).
Similar things are needed to calculate the Euler-Lagrange equations, where you need the derivative of L wrt. diff(x,t).
Now my question is, how do I implement this in Sympy?
My two main problems are as follows. First, when differentiating a Symbol f wrt. t with diff(f,t), I get 0. I can see that
x = Symbol('x',real=True);
diff(x.subs(x,x(t)),t) # because diff(x,t) => 0
and
diff(x**2, x)
does kind of work.
However, with
x = Function('x')(t);
diff(x,t);
I get this to work, but I cannot differentiate wrt. the function x itself, like
diff(x**2,x) -DOES NOT WORK.
Since I need these things, especially not only for scalars, but for vectors (using jacobian) all the time, I really want this to be a clean and functional workflow.
Which type should I use to define my mathematical functions in Sympy in order to avoid strange substitutions?
It only gets worse for matrices, where I cannot get
eqns = Matrix([f1-5, f2+1]);
variabs = Matrix([f1,f2]);
nonlinsolve(eqns,variabs);
to work as expected, since it only allows symbols as input. Is there an easy conversion here? Like eqns.tolist() - which doesn't work either?
EDIT:
I just found this question, which was answered towards using expressions and matrices. I want to be able to solve sets of nonlinear equations, build the jacobian of a vector wrt. another vector and differentiate wrt. functions as stated above. Can anyone point me in a direction to start a concise workflow for this purpose? I guess the most complex task is calculating the Lie derivative wrt. a vector or list of functions; the rest should be straightforward.
Edit 2:
def substi(expr, variables):
    # Replace each symbol w by a function w(t) of the same name
    return expr.subs({w: Function(w.name)(t) for w in variables})
would automate the substitution, such that substi(vector_expr, varlist_vector).diff(t) is not all 0.
Yes, one has to insert an argument in a function before taking its derivative. But after that, differentiation with respect to x(t) works for me in SymPy 1.1.1, and I can also differentiate with respect to its derivative. Example of Euler-Lagrange equation derivation:
from sympy import Symbol, Function, diff
t = Symbol("t")
x = Function("x")(t)
L = x**2 + diff(x, t)**2 # Lagrangian
EL = -diff(diff(L, diff(x, t)), t) + diff(L, x)
Now EL is 2*x(t) - 2*Derivative(x(t), t, t) as expected.
That said, there is a built-in method for Euler-Lagrange equations:
from sympy.calculus.euler import euler_equations
EL = euler_equations(L)
would yield the same result, except presented as a differential equation with right-hand side 0: [Eq(2*x(t) - 2*Derivative(x(t), t, t), 0)]
The following defines x to be a function of t
import sympy as s
t = s.Symbol('t')
x = s.Function('x')(t)
This should solve your problem of diff(x,t) being evaluated as 0. But I think you will still run into problems later on in your calculations.
I also work with calculus of variations and Euler-Lagrange equations. In these calculations, x' needs to be treated as independent of x. So, it is generally better to use two entirely different variables for x and x' so as not to confuse Sympy with the relationship between those two variables. After we are done with the calculations in Sympy and we go back to our pen and paper we can substitute x' for the second variable.
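The two-variable approach described above could look roughly like the following sketch. The placeholder symbols x and v and the reuse of the Lagrangian x^2 + x'^2 from the earlier answer are assumptions for illustration. The partial derivatives are taken with respect to plain symbols, and only afterwards are the symbols replaced by x(t) and its time derivative.

```python
import sympy as sp

t = sp.Symbol("t")
x, v = sp.symbols("x v")      # independent placeholders for x(t) and x'(t)
xf = sp.Function("x")(t)      # the actual function of time

# Lagrangian written in terms of the independent placeholders
L = x**2 + v**2

# Partial derivatives, treating x and v as unrelated variables
dL_dx = sp.diff(L, x)
dL_dv = sp.diff(L, v)

# Substitute the time-dependent quantities back in, then take d/dt
back = {x: xf, v: xf.diff(t)}
EL = dL_dx.subs(back) - sp.diff(dL_dv.subs(back), t)

print(EL)
```

This keeps the symbolic differentiation free of any implicit relationship between x and x', at the cost of one explicit substitution step at the end.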

Alternative plan of tf.floor

One of my operations needs an integer, but the output of a convolution is a float.
It means I need to use tf.floor, tf.ceil, tf.cast, etc. to handle it.
But these operations cause None gradients, since operations like tf.floor are not differentiable.
So, I tried something like below
First. detour
out1 = tf.subtract(vif, tf.subtract(vif, tf.floor(vif)))
But output of test.compute_gradient_error is 500 or 0, I don't think this is a reasonable gradient.
Second. override gradient function of floor
import numpy as np
import tensorflow as tf
from tensorflow.python.framework import ops

@ops.RegisterGradient("CustomFloor")
def _custom_floor_grad(op, grads):
    # Pass the incoming gradient straight through
    return [grads]

A, B = 50, 7
shape = [A, B]
f = np.ones(shape, dtype=np.float32)
vif = tf.constant(f, dtype=tf.float32)
# out1 = tf.subtract(vif, tf.subtract(vif, tf.floor(vif)))
with tf.get_default_graph().gradient_override_map({"Floor": "CustomFloor"}):
    out1 = tf.floor(vif)
with tf.Session() as sess:
    err1 = tf.test.compute_gradient_error(vif, shape, out1, shape)
    print(err1)
The output of test.compute_gradient_error is 500 or 1, so this doesn't work either.
Question: A way to get integer and keep back propagation work fine (value like 2.0, 5.0 is ok)
In general, it's inadvisable to solve a discrete problem with gradient descent. You should be able to express integer solvers in TF to some extent, but you're more or less on your own.
FWIW, the floor function looks like a staircase. Its derivative is zero everywhere except at the integers, where it has a Dirac impulse pointing upwards, like a rake if you wish. (It is the remainder x - floor(x) that looks like a saw: its derivative is a constant function at 1 with a downward Dirac impulse at every integer.) A Dirac impulse has finite area but no finite value.
The canonical way to tackle these problems is to relax the problem by replacing the hard floor constraint with something that is (at least once) differentiable (smooth).
There are multiple ways to do this. Perhaps the most popular are:
Hack up a function that looks like what you want. For instance a piece-wise linear function that slopes down quickly, but not vertically.
Replace step functions by sigmoids
Use a filter approximation which is well understood if it's a time series
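As an illustration of the sigmoid option, here is a minimal sketch written in JAX rather than TensorFlow. The function name soft_floor and its parameters sharpness and max_value are hypothetical. It relies on the fact that, for non-negative inputs, floor(x) is a sum of unit steps located at 1, 2, 3, ..., and replaces each step by a sigmoid so that the gradient is nonzero.

```python
import jax
import jax.numpy as jnp

def soft_floor(x, sharpness=50.0, max_value=10):
    # floor(x) for 0 <= x <= max_value is a sum of unit steps at 1, 2, ...;
    # each hard step is replaced by a sigmoid of the given sharpness
    steps = jnp.arange(1.0, max_value + 1.0)
    return jnp.sum(jax.nn.sigmoid(sharpness * (x[..., None] - steps)), axis=-1)

x = jnp.array([0.5, 2.5, 2.9])
values = soft_floor(x)

# Gradient of the summed output with respect to each input element
grads = jax.grad(lambda v: soft_floor(v).sum())(x)

print(values)
print(grads)
```

A larger sharpness gives values closer to the true floor but concentrates the gradient in a narrower band around each integer, so the choice is a trade-off that depends on the problem.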

scipy integrate over array with variable bounds

I am trying to integrate a function over a list of points and pass the whole array to an integration function in order to vectorize the thing. For starters, calling scipy.integrate.quad is way too slow since I have something like 10 000 000 points to integrate. Using scipy.integrate.romberg does the trick much faster, almost instantaneously, while quad is slow since you must loop over it or vectorize it.
My function is quite complicated, but for demonstration purposes, let's say I want to integrate x^2 + y^2 over x from a to b, where y is an array of scalars at which to evaluate the integrand. For example
import numpy as np
from scipy.integrate import quad, romberg
def integrand(x, y):
    return x**2 + y**2
quad(integrand, 0, 10, args=(10)) # this fails since y is not a scalar
romberg(integrand, 0, 10) # y works here, giving the integral over
                          # the entire range
But this only work for fixed bounds. Is there a way to do something like
z = np.arange(20,30)
romberg(integrand, 0, z) # Fails since the function doesn't seem to
# support variable bounds
The only way I see is to re-implement the algorithm itself in numpy and use that instead so I can have variable bounds. Is there any function that supports something like this? There is also romb, where you must supply the values of the integrand directly and a dx interval, but that will be too imprecise for my complicated function (the Marcum Q function; I couldn't find any implementation, which could be another way to do it).
The best approach when trying to evaluate a special function is to write a function that uses the properties of the function to quickly and accurately evaluate it in all parameter regimes. It is quite unlikely that a single approach will give accurate (or even stable) results for all ranges of parameters. Direct evaluation of an integral, as in this case, will almost certainly break down in many cases.
That being said, the general problem of evaluating an integral over many ranges can be solved by turning the integral into a differential equation and solving that. Roughly, the steps would be
Given an integral I(t) which I will assume is an integral of a function f(x) from 0 to t [this can be generalized to an arbitrary lower limit], write it as the differential equation dI/dt = f(t).
Solve this differential equation using scipy.integrate.odeint() for some initial conditions (here I(0)) over some range of times from 0 to t. This range should contain all limits of interest. How finely this is sampled depends on the function and how accurately it needs to be evaluated.
The result will be the value of the integral from 0 to t for the set of t we input. We can turn this into a "continuous" function using interpolation. For example, using a spline we can define i = scipy.interpolate.InterpolatedUnivariateSpline(t,I).
Given a set of upper and lower limits in arrays b and a, respectively, then we can evaluate them all at once as res=i(b)-i(a).
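The steps above can be sketched as follows. The integrand f(x) = x^2, the grid, and the limits are placeholder assumptions chosen so the result is easy to check against the exact value b^3/3 - a^3/3; a real integrand would need its own grid and tolerance study.

```python
import numpy as np
from scipy.integrate import odeint
from scipy.interpolate import InterpolatedUnivariateSpline

def f(x):
    # Hypothetical integrand standing in for the real, more complicated one
    return x**2

# Grid of upper limits; it must cover every limit of interest
t = np.linspace(0.0, 30.0, 301)

# Solve dI/dt = f(t) with I(0) = 0; odeint expects func(y, t) and a 1-D y0
I = odeint(lambda y, s: [f(s)], [0.0], t)[:, 0]

# "Continuous" version of the cumulative integral via a spline
i = InterpolatedUnivariateSpline(t, I)

# Evaluate many integrals from a to b at once
a = np.zeros(10)
b = np.arange(20.0, 30.0)
res = i(b) - i(a)

print(res)
```

The cost is one ODE solve plus cheap spline evaluations, regardless of how many pairs of limits are requested.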
Whether this approach will work in your case will require you to carefully study it over your range of parameters. Also note that the Marcum Q function involves a semi-infinite integral. In principle this is not a problem, just transform the integral to one over a finite range. For example, consider the transformation x->1/x. There is no guarantee this approach will be numerically stable for your problem.

Derivative of an array in python?

Currently I have two numpy arrays: x and y of the same size.
I would like to write a function (possibly calling numpy/scipy... functions if they exist):
def derivative(x, y, n=1):
    # something
    return result
where result is a numpy array of the same size as x, containing the value of the n-th derivative of y with respect to x (I would like the derivative to be evaluated using several values of y in order to avoid non-smooth results).
This is not a simple problem, but there are a lot of methods that have been devised to handle it. One simple solution is to use finite difference methods. The command numpy.diff() uses finite differencing where you can specify the order of the derivative.
If the numpy function doesn't do what you want, Wikipedia also has a page that lists the needed finite differencing coefficients for different derivatives of different accuracies.
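As a rough sketch of the finite-difference route, here is one way to fill in the derivative(x, y, n) signature from the question. It uses numpy.gradient rather than numpy.diff, which is a substitution on my part: numpy.diff returns n fewer elements and does not divide by the spacing, whereas numpy.gradient keeps the array size and accepts the x coordinates. Applying it n times is a simple but not especially accurate way to get higher derivatives.

```python
import numpy as np

def derivative(x, y, n=1):
    # Apply second-order finite differences n times;
    # the output has the same size as x
    result = np.asarray(y, dtype=float)
    for _ in range(n):
        result = np.gradient(result, x, edge_order=2)
    return result

x = np.linspace(0.0, 2.0 * np.pi, 200)
y = np.sin(x)

dy = derivative(x, y)         # approximates cos(x)
d2y = derivative(x, y, n=2)   # approximates -sin(x), less accurate near the edges

print(np.max(np.abs(dy - np.cos(x))))
```

For noisy data, repeated differencing amplifies the noise, so some smoothing or a wider stencil is usually needed.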
Depending on your application you can also use scipy.fftpack.diff, which uses a completely different technique to do the same thing, though your function needs a well-defined Fourier transform (the input is assumed to be periodic).
There are lots and lots and lots of variants (e.g. summation by parts, finite differencing operators, or operators designed to preserve known evolution constants in your system of equations) on both of the two ideas above. What you should do will depend a great deal on what the problem is that you are trying to solve.
The good thing is that a lot of work has been done in this field. The Wikipedia page for Numerical Differentiation has some resources (though it is focused on finite differencing techniques).
The findiff project is a Python package that can do derivatives of arrays of any dimension with any desired accuracy order (of course depending on your hardware restrictions). It can handle arrays on uniform as well as non-uniform grids and also create generalizations of derivatives, i.e. general linear combinations of partial derivatives with constant and variable coefficients.
Would something like this solve your problem?
import numpy as np

def get_inflection_points(arr, n=1):
    """
    returns inflection points from array
    arr: array
    n: n-th discrete difference
    """
    inflections = []
    dx = 0
    for i, x in enumerate(np.diff(arr, n)):
        # Record the index when the difference is at least the previous one
        if x >= dx and i > 0:
            inflections.append(i*n)
        dx = x
    return inflections
