I am working on an algorithm which needs gradient information.
I tried numdifftools.Gradient; it works well, but the time cost is unsustainable.
Briefly, what I'm doing is initializing a multi-dimensional (say, d-dimensional) vector t, then using t to parameterize a matrix A; the matrix, along with other information, gives an energy value, which is a scalar output.
I need the gradient of the energy with respect to t, element-wise, so that I can update the t parameters and continue the loop.
My code looks like this:
def initialize(d):
    ......
    return t

def A(t):
    ......
    return A, result_2, result_3, ...

def energy(t, A, para_2, para_3, ...):
    ......
    # some matrix calculation including kron etc.
    ......
    return e

grad = numdifftools.Gradient(energy)(t)
# this returns an array with the same shape as t,
# representing the element-wise gradient of the energy w.r.t. t
t -= grad * learning_rate
This does exactly what I want; however, as the dimension grows, the gradient calculation can take several minutes for a single iteration, while I need to perform thousands of iterations.
I tried Google's JAX, and it seems that JAX only works when the output is a single scalar, while here I need matrix results.
Actually, you don't need to know what exactly I am doing; this is just a question about optimizing the time cost of a gradient computation.
Is there any better way to do this?
Have you tried the cache decorator? It might help:

from functools import cache

@cache
def my_function():
    """do things"""
Related
I have the following three functions implemented in JAX.
def helper_1(params1):
    ...calculations...
    return z

def helper_2(z, params2):
    ...calculations...
    return y

def main(params1, params2):
    z = helper_1(params1)
    y = helper_2(z, params2)
    return z, y
I am interested in the partial derivatives of the output from main, i.e. z and y, with respect to both params1 and params2. As params1 and params2 are low dimensional and z and y are high dimensional, I am using the jax.jacfwd function.
When calling
jax.jacfwd(main, argnums=(0, 1))(params1, params2)
JAX computes the derivatives of z with respect to params1 (and params2, which in this case is just a bunch of zeros). My question is: does JAX recompute dz/d_params1 for the derivatives of y with respect to params1 and params2, or does it somehow figure out that this has already been computed?
I don't know if this is relevant, but the helper_1 function contains functions from the TensorFlow library for JAX. Thanks!
In general, in the situation you describe JAX's forward-mode autodiff approach will re-use the derivative of z when computing the derivative of y. If you wish, you can confirm this by looking at the jaxpr of your differentiated function:
print(jax.make_jaxpr(jax.jacfwd(main, (0, 1)))(params1, params2))
Though if your function is more than moderately complicated, the output might be hard to understand.
As a general note, though, JAX's autodiff implementation does tend to produce a small number of unnecessary or duplicated computations. As a simple example consider this:
import jax
print(jax.make_jaxpr(jax.grad(jax.lax.sin))(1.0))
# { lambda ; a:f32[]. let
# _:f32[] = sin a
# b:f32[] = cos a
# c:f32[] = mul 1.0 b
# in (c,) }
Here the primal value sin(a) is computed even though it is never used in computing the final output.
In practice this can be addressed by wrapping your computation in jit, in which case the XLA compiler takes care of optimization, including dead code elimination and de-duplication when applicable:
result = jax.jit(jax.jacfwd(main, (0, 1)))(params1, params2)
Question:
Is there any working method to calculate the gradient of a (non-scalar) tensor function?
Example
Given n-by-n symmetric matrices X and Y and the matrix function Z(X, Y) = torch.mm(X.mm(X), Y), calculate d(dZ/dX)/dY.
Expected answer
d(dZ/dX)/dY = d(2*XY)/dY = 2*X
Attempts
Because torch's .backward() works only for scalar variables, I tried to calculate the derivative by applying torch.autograd.grad() to each element of the tensor Z, but this approach is not correct, because it gives d(X^2)/dX = X + 2*D, where D is a diagonal matrix with the diagonal values of X. For me it's a bit weird that torch can build a computational graph, but can't track a tensor through it as a variable to get a tensor derivative.
Edit
The question was not very clear, so I decided to give more details.
My aim is to get the partial derivatives of a loss function which involves two matrices as variables. It looks like this:

loss = torch.linalg.norm(my_formula(X, Y), ord='fro')
And I need to find:

1. d^2(loss)/d(Y^2)
2. d/dX[d(loss)/dY]

Torch is capable of calculating 1. by using .backward() twice, but it's problematic to find 2., because torch.autograd.grad() expects a scalar input, not a tensor.
TL;DR
For a function f which takes a matrix and returns a scalar:

1. Find the first-order derivative, let's name it dX.
2. Take the trace: Tr(dX).
3. To get the mixed partial derivative, just differentiate the trace from above: d/dY[df/dX] = d/dY[Tr(df/dX)]
Intro
At the moment of posting the question I was not really that good at the theory of matrix derivatives, but now I know much more, all thanks to this Yandex ML book (unfortunately, I didn't find an English equivalent). This is an attempt to give a full answer to my question.
Basic Theory
Forgive me, Lord, for the ugly representation of LaTeX.
Let's say you have a function which takes matrix X and returns its squared Frobenius norm: f(X) = ||X||_F^2
It is a well-known fact that: ||X||_F^2 = Tr(X X^T)
Let's define the derivative as shown in the same book: D[f] at X_0 is the linear part of the increment, i.e. f(X_0 + H) - f(X_0) = D[f](X_0)[H] + o(||H||)
We are ready to find df(X)/dX:
df(X)/dX = dTr(X X^T)/dX =
(using the trace's linearity)
= Tr(d/dX[X X^T]) = Tr(dX/dX X^T + X d[X^T]/dX) =
(then we use the definition of the derivative from above)
= Tr(H X^T + X H^T) = Tr(H X^T) + Tr(X H^T) =
(now the main trick is to get all matrices H on the right side, ending with something like
Tr(g(X) H) or Tr(g(X) H^T), where g(X) is the derivative we are looking for)
= Tr(H X^T) + Tr(X H^T) = Tr(X H^T) + Tr(X H^T) = Tr(2 X H^T)
That means: df(X)/dX = 2X
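As a quick numerical sanity check (a sketch, using JAX since it appears later in this answer):

import jax
import jax.numpy as jnp

def f(X):
    # squared Frobenius norm, written as Tr(X X^T)
    return jnp.trace(X @ X.T)

X = jnp.array([[1.0, 2.0], [3.0, 4.0]])
print(jax.grad(f)(X))  # matches 2 * X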
Second order derivative
Now that we have found out how to get matrix derivatives, let's try to find the second-order derivative of the same function f(X):

d/dX[df(X)/dX] = d/dX[Tr(2 X H_1^T)] = Tr(d/dX[2 X H_1^T]) = Tr(2 I H_2 H_1^T)

We found out that d/dX[df(X)/dX] = 2I, where I stands for the identity matrix. But how does this help us find derivatives in PyTorch?
Trace is the trick
As we can see from the formulas, both the first- and second-order derivatives have a trace inside them, but when we take the first-order derivative we instantly get a matrix as a result. To get a higher-order derivative, we just need to take the derivative of the trace of the first-order derivative:
d/dY[df/dX] = d/dY[Tr(df/dX)]
The thing is, I was using the JAX autodiff library when this trick came to my mind, so the code for a function f(X, Y) will look like this:
import jax.numpy as jnp
from jax import grad

def scalarized_dy(X, Y):
    # f(X, Y) is assumed to be defined elsewhere and to return a scalar
    dY = grad(f, argnums=1)(X, Y)
    return jnp.trace(dY)

dYX = grad(scalarized_dy, argnums=0)(X, Y)
dYY = grad(scalarized_dy, argnums=1)(X, Y)
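A minimal check of this pattern, using a hypothetical f(X, Y) = Tr(XY): there df/dY = X^T, so Tr(df/dY) = Tr(X), and d/dX[Tr(X)] = I, which is exactly what the code returns:

import jax.numpy as jnp
from jax import grad

def f(X, Y):
    return jnp.trace(X @ Y)

X = jnp.array([[1.0, 2.0], [3.0, 4.0]])
Y = jnp.eye(2)

dYX = grad(lambda X, Y: jnp.trace(grad(f, argnums=1)(X, Y)), argnums=0)(X, Y)
print(dYX)  # the identity matrix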
In the case of PyTorch, I guess we will need to look after the tensors' gradients (letting f be a function with X and Y as arguments):
loss = f(X, Y)
loss.backward(create_graph=True)  # keep the graph so X.grad can be differentiated again
dX = torch.trace(X.grad)
dX.backward()
dXX = X.grad  # careful: .grad accumulates, so the first-order part is still in here
dXY = Y.grad
Epilogue
I thought the question itself was interesting in its own way. Also, it took me several months to figure things out, so I decided to give my current point of view on this problem. I will not mark my answer as correct yet, in the hope that I will get some kind of feedback or, perhaps, even better answers or ideas.
I am implementing a custom operation whose gradients must be calculated. The following is the function:
def difference(prod, box):
    result = tf.Variable(tf.zeros((prod.shape[0], box.shape[1]), dtype=tf.float16))
    for i in tf.range(0, prod.shape[0]):
        for j in tf.range(0, box.shape[1]):
            result[i, j].assign((tf.reduce_prod(box[:, j]) - tf.reduce_prod(prod[i, :])) / tf.reduce_prod(box[:, j]))
    return result
I am unable to calculate the gradients with respect to box; tape.gradient() is returning None. Here is the code I have written for calculating the gradients:
prod = tf.constant([[3, 4, 5], [4, 5, 6], [1, 3, 3]], dtype=tf.float16)
box = tf.Variable([[4, 5], [5, 6], [5, 7]], dtype=tf.float16)

with tf.GradientTape() as tape:
    tape.watch(box)
    loss = difference(prod, box)

print(tape.gradient(loss, box))
I am not able to find the reason for unconnected gradients. Is the result variable causing it? Kindly suggest an alternative implementation.
Yes, in order to calculate gradients we need a set of (differentiable) operations on your variables.

You should rewrite difference as a function of the two input tensors. I think (though happy to confess I am not 100% sure!) that it is the use of assign that makes the gradient tape fall over.
Perhaps something like this:
def difference(prod, box):
    box_red = tf.reduce_prod(box, axis=0)
    prod_red = tf.reduce_prod(prod, axis=1)
    return (tf.expand_dims(box_red, 0) - tf.expand_dims(prod_red, 1)) / tf.expand_dims(box_red, 0)
This should get you the desired result.
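Re-running the question's setup against this version should confirm the tape is now connected (a sketch; the values themselves are not checked here):

import tensorflow as tf

prod = tf.constant([[3, 4, 5], [4, 5, 6], [1, 3, 3]], dtype=tf.float16)
box = tf.Variable([[4, 5], [5, 6], [5, 7]], dtype=tf.float16)

with tf.GradientTape() as tape:
    loss = difference(prod, box)

print(tape.gradient(loss, box))  # a tensor of gradients, no longer None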
One of my operations needs integers, but the output of a convolution is float.
This means I need to use tf.floor, tf.ceil, tf.cast, etc. to handle it.
But these operations cause None gradients, since operations like tf.floor are not differentiable.
So I tried something like the two approaches below.

First: a detour.

out1 = tf.subtract(vif, tf.subtract(vif, tf.floor(vif)))

But the output of test.compute_gradient_error is 500 or 0; I don't think this is a reasonable gradient.
Second: override the gradient function of floor.

@ops.RegisterGradient("CustomFloor")  # 'ops' is tensorflow.python.framework.ops
def _custom_floor_grad(op, grads):
    return [grads]  # pass the incoming gradient through unchanged (identity)
A, B = 50, 7
shape = [A, B]
f = np.ones(shape, dtype=np.float32)
vif = tf.constant(f, dtype=tf.float32)
# out1 = tf.subtract(vif, tf.subtract(vif, tf.floor(vif)))
with tf.get_default_graph().gradient_override_map({"Floor": "CustomFloor"}):
    out1 = tf.floor(vif)

with tf.Session() as sess:
    err1 = tf.test.compute_gradient_error(vif, shape, out1, shape)
    print(err1)
The output of test.compute_gradient_error is 500 or 1, so this doesn't work either.

Question: is there a way to get integers while keeping backpropagation working fine (values like 2.0, 5.0 are OK)?
In general, it's not advisable to solve discrete problems with gradient descent. You should be able to express integer solvers in TF to some extent, but you're more or less on your own.

FWIW, the floor function looks like a saw. Its derivative is a constant function at 1 with little holes at every integer. At those positions you have a Dirac functional pointing downwards, like a rake if you wish. The Dirac functional has finite energy but no finite value.

The canonical way to tackle these problems is to relax the hard floor constraint into something that is (at least once) differentiable (smooth).

There are multiple ways to do this; perhaps the most popular are the following (a code sketch follows the list):

- Hack up a function that looks like what you want, for instance a piece-wise linear function that slopes down quickly, but not vertically.
- Replace step functions with sigmoids.
- Use a filter approximation, which is well understood if the signal is a time series.
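One more option worth sketching: a straight-through estimator, which uses the exact floor in the forward pass but treats it as the identity in the backward pass. This is essentially the TF 2 equivalent of the question's gradient_override_map attempt; the name floor_ste is mine:

import tensorflow as tf

def floor_ste(x):
    # forward pass: floor(x); backward pass: identity, because the
    # correction term is wrapped in stop_gradient and carries no gradient
    return x + tf.stop_gradient(tf.floor(x) - x)

x = tf.Variable([1.3, 2.7, 5.0])
with tf.GradientTape() as tape:
    y = floor_ste(x)

print(y)                    # [1., 2., 5.]
print(tape.gradient(y, x))  # [1., 1., 1.]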
I am playing around with logistic regression in Python. I have implemented a version where the minimization of the cost function is done via gradient descent, and now I'd like to use the BFGS algorithm from scipy (scipy.optimize.fmin_bfgs).
I have a set of data (features in matrix X, with one sample in every row of X, and corresponding labels in the vertical vector y). I am trying to find the parameters Theta that minimize the logistic-regression cost function J(Theta) (given as an image in the original post).
I have trouble understanding how fmin_bfgs works exactly. As far as I understand it, I have to pass the function to be minimized and a set of initial values for Theta.
I do the following:
initial_values = numpy.zeros((len(X[0]), 1))
myargs = (X, y)
theta = scipy.optimize.fmin_bfgs(computeCost, x0=initial_values, args=myargs)
where computeCost calculates J(Thetas) as illustrated above. But I get some index-related errors, so I think I am not supplying what fmin_bfgs expects.
Can anyone shed some light on this?
After wasting hours on it, solved again by the power of posting... I was defining computeCost(X, y, Thetas), but since Thetas is the parameter being optimized, it should have been the first argument in the signature. Fixed, and it works!
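For reference, a minimal sketch of the fixed setup. The body of computeCost here is a standard logistic-regression cost, not the original one from the question; the point is only the argument order, with theta first:

import numpy as np
from scipy.optimize import fmin_bfgs

def computeCost(theta, X, y):
    # hypothetical stand-in for the question's cost: standard logistic loss
    h = 1.0 / (1.0 + np.exp(-X.dot(theta)))
    return -np.mean(y * np.log(h) + (1 - y) * np.log(1 - h))

X = np.array([[1.0, 0.5], [1.0, -1.2], [1.0, 2.3], [1.0, -0.7]])
y = np.array([1.0, 0.0, 1.0, 0.0])

initial_values = np.zeros(len(X[0]))  # 1-D, as the answer below also suggests
theta = fmin_bfgs(computeCost, x0=initial_values, args=(X, y))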
I don't know your whole code, but have you tried

initial_values = numpy.zeros(len(X[0]))

? The x0 should be a 1-D vector, I think.