# Finding the maximum of a function using JAX

## Question:

I have a function whose maximum I would like to find by optimizing two of its variables using JAX.

My current code, which does not work, is:

```
import jax.numpy as jnp
import jax
import scipy.optimize
import numpy as np

def temp_func(x, y, z):
    tmp = x + jnp.dot(jnp.power(y, 3), jnp.tanh(z))
    return -tmp

def obj_func(xy, z):
    # unpack the flat parameter vector into x (shape (2,)) and y (shape (2, 2))
    x, y = xy[:2], xy[2:].reshape(2, 2)
    return jnp.sum(temp_func(jnp.array(x), jnp.array(y), z))

grad_tmp = jax.grad(obj_func, argnums=0)  # gradient w.r.t. xy (x and y)

xy = jnp.concatenate([np.random.rand(2), np.random.rand(2 * 2)])
z = jnp.array(np.random.rand(2, 2))
print(obj_func(xy, z))
result = scipy.optimize.minimize(obj_func,
                                 xy,
                                 args=(z,),
                                 method='L-BFGS-B',
                                 jac=grad_tmp
                                 )
```

With this code, I get the error ``ValueError: failed in converting 7th argument `g' of _lbfgsb.setulb to C/Fortran array``.

Do you have any suggestions to resolve the issue?

## Answers:

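You might consider using JAX's version of `scipy.optimize.minimize`, `jax.scipy.optimize.minimize`, which computes and uses the gradient automatically, so you do not need to pass `jac`. Note that it supports `method='BFGS'` rather than `'L-BFGS-B'`: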

```
import jax.scipy.optimize
result = jax.scipy.optimize.minimize(obj_func, xy, args=(z,), method='BFGS')
```
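If you would rather keep SciPy's L-BFGS-B, the error most likely comes from the gradient: `jax.grad` returns a JAX array (float32 unless `jax_enable_x64` is set), whereas SciPy's L-BFGS-B wrapper expects a float64 NumPy array. Below is a minimal sketch, under stated assumptions, that converts both the objective value and the gradient before handing them to SciPy. The helper names `fun_np` and `jac_np`, the fixed `z` and starting point, and the `[0, 1]` box bounds are all hypothetical choices for illustration; the bounds are there only because the unconstrained objective decreases without limit, so without them there is no finite minimizer to find.

```python
import numpy as np
import jax
import jax.numpy as jnp
from scipy.optimize import minimize

def temp_func(x, y, z):
    tmp = x + jnp.dot(jnp.power(y, 3), jnp.tanh(z))
    return -tmp

def obj_func(xy, z):
    x, y = xy[:2], xy[2:].reshape(2, 2)
    return jnp.sum(temp_func(x, y, z))

grad_obj = jax.grad(obj_func, argnums=0)  # gradient w.r.t. xy

def fun_np(xy, z):
    # Hypothetical wrapper: SciPy gets a plain Python float for the objective.
    return float(obj_func(xy, z))

def jac_np(xy, z):
    # Hypothetical wrapper: jax.grad returns a JAX array (float32 by default);
    # convert it to a float64 NumPy array for L-BFGS-B.
    return np.asarray(grad_obj(xy, z), dtype=np.float64)

z = np.full((2, 2), 0.5)   # fixed z instead of random values (assumption)
xy0 = np.full(6, 0.5)      # starting point for (x, flattened y)
bounds = [(0.0, 1.0)] * 6  # hypothetical box; the unconstrained problem is unbounded

result = minimize(fun_np, xy0, args=(z,), method="L-BFGS-B",
                  jac=jac_np, bounds=bounds)
print(result.success, result.x, result.fun)
```

With these bounds every variable should end up at its upper bound of 1, because on `[0, 1]` the objective decreases in each of them.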

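That said, the result will not be very meaningful in either case, because your objective function is unbounded below: it decreases linearly in `x` and, since every entry of `tanh(z)` is positive for the random `z` in your example, cubically in `y`. It is therefore only "minimized" in the limit *x, y → ∞*, which means the maximum you are looking for does not exist.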
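To make this concrete: summing the 2×2 output of `temp_func` gives obj(x, y) = -(2(x₀ + x₁) + Σ_{i,k} y_{ik}³ Σ_j tanh(z_{kj})). The sketch below checks that closed form numerically against `obj_func` and evaluates the objective along the ray where every entry of `xy` equals t, for growing t. The helper `closed_form`, the seeded random `z`, and the particular values of t are assumptions chosen for illustration.

```python
import numpy as np
import jax.numpy as jnp

def temp_func(x, y, z):
    tmp = x + jnp.dot(jnp.power(y, 3), jnp.tanh(z))
    return -tmp

def obj_func(xy, z):
    x, y = xy[:2], xy[2:].reshape(2, 2)
    return jnp.sum(temp_func(x, y, z))

def closed_form(xy, z):
    # Hypothetical helper: the sum over temp_func's 2x2 output, i.e.
    #   -(2*(x0 + x1) + sum_{i,k} y[i,k]**3 * sum_j tanh(z[k,j]))
    x, y = xy[:2], xy[2:].reshape(2, 2)
    row_sums = np.tanh(z).sum(axis=1)  # s_k = sum_j tanh(z[k, j])
    return -(2.0 * x.sum() + (y ** 3 @ row_sums).sum())

rng = np.random.default_rng(0)
z = rng.random((2, 2))  # like np.random.rand: entries in [0, 1), so tanh(z) >= 0

ts = [1.0, 10.0, 100.0, 1000.0]
values = [float(obj_func(t * np.ones(6), z)) for t in ts]
expected = [closed_form(t * np.ones(6), z) for t in ts]

for t, v in zip(ts, values):
    print(f"t = {t:>6}: obj = {v:.6g}")
```

The printed values keep falling as t grows (the linear `x` term alone guarantees that), so any unconstrained minimizer will simply run off towards infinity.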