# Finding the maximum of a function using JAX

## Question:

I have a function whose maximum I would like to find by optimizing two of its variables with JAX.

My current code, which does not work, is:

```python
import jax.numpy as jnp
import jax
import scipy.optimize
import numpy as np

def temp_func(x, y, z):
    tmp = x + jnp.dot(jnp.power(y, 3), jnp.tanh(z))
    return -tmp

def obj_func(xy, z):
    # The first 2 entries of xy are x; the remaining 4 form the 2x2 matrix y.
    x, y = xy[:2], xy[2:].reshape(2, 2)
    return jnp.sum(temp_func(jnp.array(x), jnp.array(y), z))

xy = jnp.concatenate([np.random.rand(2), np.random.rand(2 * 2)])
z = jnp.array(np.random.rand(2, 2))
print(obj_func(xy, z))

result = scipy.optimize.minimize(obj_func,
                                 xy,
                                 args=(z,),
                                 method='L-BFGS-B',
                                 )
```

With this code, I get the error `` ValueError: failed in converting 7th argument `g' of _lbfgsb.setulb to C/Fortran array ``.
Do you have any suggestions to resolve the issue?

## Answer:

You might consider using JAX's version of `scipy.optimize.minimize`, `jax.scipy.optimize.minimize`, which computes and uses the gradient automatically. Note that it supports `method="BFGS"` rather than `'L-BFGS-B'`:
```python
import jax.scipy.optimize
result = jax.scipy.optimize.minimize(obj_func, xy, args=(z,), method="BFGS")
```
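That said, the result will not be very meaningful whichever optimizer you use, because your objective function is unbounded below: it is linear in `x` and cubic in `y`, and because `tanh(z) > 0` for the positive `z` that `np.random.rand` produces, it keeps decreasing as x, y → ∞. In other words, the maximum you are looking for does not exist unless you constrain the variables.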