Finding the maximum of a function using JAX


I have a function whose maximum I would like to find by optimizing two of its variables using JAX.

My current code, which does not work, reads:

import jax.numpy as jnp
import jax 
import scipy.optimize
import numpy as np

def temp_func(x,y,z):
    tmp = x + jnp.dot(jnp.power(y, 3), jnp.tanh(z))
    return -tmp
def obj_func(xy, z):
    x,y = xy[:2], xy[2:].reshape(2,2)
    return jnp.sum(temp_func(jnp.array(x),jnp.array(y),z))

grad_tmp = jax.grad(obj_func, argnums=0)  # gradient w.r.t. xy, i.e. both x and y

xy = jnp.concatenate([np.random.rand(2), np.random.rand(2 * 2)])
z = jnp.array(np.random.rand(2, 2))

result = scipy.optimize.minimize(obj_func,

With this code, I get the error: ValueError: failed in converting 7th argument `g' of _lbfgsb.setulb to C/Fortran array
Do you have any suggestions to resolve the issue?

Asked By: Shasa
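A likely cause of this particular error (an assumption, not something confirmed in the thread): SciPy's L-BFGS-B wrapper hands the gradient to a Fortran routine that expects a float64 array, while JAX computes in float32 by default and returns jax.Array values. A minimal sketch of keeping the SciPy route is to wrap the objective and gradient so SciPy receives a Python float and a float64 NumPy array. The temp_func expression below is an assumed reconstruction of the garbled line, the wrapper names fun_np/jac_np are hypothetical, and the box bounds 0 ≤ x, y ≤ 1 are also hypothetical, added only so the problem has a finite solution (the upper corner of the box).

```python
import numpy as np
import jax
import jax.numpy as jnp
import scipy.optimize

def temp_func(x, y, z):
    # Assumed reconstruction of the question's garbled expression.
    tmp = x + jnp.dot(jnp.power(y, 3), jnp.tanh(z))
    return -tmp  # negated: minimizing -tmp maximizes tmp

def obj_func(xy, z):
    # Flat vector -> x (2 entries) and y (2x2 matrix).
    x, y = xy[:2], xy[2:].reshape(2, 2)
    return jnp.sum(temp_func(x, y, z))

grad_tmp = jax.grad(obj_func, argnums=0)

def fun_np(xy, z):
    # Hypothetical wrapper: return a plain Python float to SciPy.
    return float(obj_func(xy, z))

def jac_np(xy, z):
    # Hypothetical wrapper: JAX returns float32 by default, so cast the
    # gradient to the float64 NumPy array the Fortran routine works with.
    return np.asarray(grad_tmp(xy, z), dtype=np.float64)

rng = np.random.default_rng(0)
xy0 = np.full(6, 0.5)                # arbitrary interior starting point
z = jnp.asarray(rng.random((2, 2)))  # positive entries, as in the question

# Hypothetical bounds 0 <= x, y <= 1; without them the objective is unbounded below.
result = scipy.optimize.minimize(
    fun_np, xy0, args=(z,), jac=jac_np,
    method="L-BFGS-B", bounds=[(0.0, 1.0)] * 6,
)
print(result.x)
```

Alternatively, enabling 64-bit mode with jax.config.update("jax_enable_x64", True) at start-up makes JAX produce float64 arrays, though an explicit conversion to NumPy keeps the hand-off to SciPy independent of that global setting.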



You might consider using JAX's version of scipy.optimize.minimize, jax.scipy.optimize.minimize, which automatically computes and uses the derivative:

import jax.scipy.optimize
result = jax.scipy.optimize.minimize(obj_func, xy, args=(z,), method='BFGS')
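As a self-contained illustration of this call pattern (packing x and y into one flat vector, passing z through args, and minimizing the negated function to find a maximum), here is a minimal sketch. Because the question's own objective has no finite optimum, it swaps in a hypothetical bounded function neg_f whose maximum is known (x = 1, y = z), so the result can be checked; neg_f and the values of z are illustrative assumptions, not part of the original code.

```python
import jax.numpy as jnp
import jax.scipy.optimize

def neg_f(xy, z):
    # Unpack the flat parameter vector: x has 2 entries, y is a 2x2 matrix.
    x, y = xy[:2], xy[2:].reshape(2, 2)
    # Hypothetical bounded objective with its maximum at x = 1, y = z.
    f = -jnp.sum((x - 1.0) ** 2) - jnp.sum((y - z) ** 2)
    # Maximizing f is the same as minimizing -f.
    return -f

z = jnp.array([[0.5, -1.0], [2.0, 0.0]])
xy0 = jnp.zeros(6)  # x0 must be a 1-D array

result = jax.scipy.optimize.minimize(neg_f, xy0, args=(z,), method="BFGS")

# The known maximizer, packed the same way as xy0.
target = jnp.concatenate([jnp.ones(2), z.ravel()])
print(result.x, result.success)
```

The gradient never has to be supplied: the function differentiates neg_f itself, which is what makes it a convenient replacement for the SciPy call in the question.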

That said, the results in either case are not going to be very meaningful, because your objective function decreases without bound as x and y increase, so the optimizer will drive x, y → ∞.
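One quick way to see this numerically is to scale a starting point and watch the objective. The sketch below assumes the garbled line in temp_func was meant to be x + jnp.dot(jnp.power(y, 3), jnp.tanh(z)) (an assumption, since the original line is incomplete). With the positive random inputs used in the question, scaling xy by t grows the x term linearly and, under that reconstruction, the y term cubically, so obj_func keeps falling.

```python
import numpy as np
import jax.numpy as jnp

def temp_func(x, y, z):
    # Assumed reconstruction of the question's garbled expression.
    return -(x + jnp.dot(jnp.power(y, 3), jnp.tanh(z)))

def obj_func(xy, z):
    x, y = xy[:2], xy[2:].reshape(2, 2)
    return jnp.sum(temp_func(x, y, z))

rng = np.random.default_rng(0)
xy = jnp.asarray(rng.random(6))       # positive entries, as with np.random.rand
z = jnp.asarray(rng.random((2, 2)))   # positive, so tanh(z) > 0

# Evaluate the objective along the ray t * xy for growing t.
values = [float(obj_func(t * xy, z)) for t in (1.0, 10.0, 100.0)]
print(values)  # each value is lower than the previous one
```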

Answered By: jakevdp