# How do I formulate a conditional function

## Question:

I am trying to formulate the following functions in Python, and I want to plot them:

```python
import matplotlib.pyplot as plt
import numpy as np

ss = np.linspace(300, 1000, 15)

def PT3000(ss):
    if ss < 318.842719019854:
        PT3 = 4.602 + 37440.0/ss
    else:
        PT3 = -0.3 + 3600.0/ss
    return PT3

def PT2000(ss):
    if ss < 318.842719019854:
        PT2 = 4.602 + 37440.0/ss
    elif ss > 945.33959:
        PT2 = -0.3 + 3600.0/ss
    else:
        PT2 = 6.87109574235995e-6*ss**0.5*(-1 + 96000.0/ss) + 62.144
    return PT2

fig = plt.figure()
plt.plot(ss, PT2000(ss))
plt.plot(ss, PT3000(ss))
plt.title('Productietijd [24x12]')
plt.xlabel('Verstijverafstand [mm]')
plt.ylabel('Productijd van een paneel [uur]')
plt.grid(visible=True)
plt.legend()
plt.show()
```

I run into an error, but I don’t understand what to do with it:

```
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
```

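## Answer:

You want to apply `PT2000` and `PT3000` to each element of `ss`. However, `PT2000(ss)` passes the entire array `ss` to the function at once, so every operation inside it, including the comparisons in the `if` conditions, is carried out by NumPy on the whole array.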

Comparing a NumPy array with a number returns an array, so an expression like `ss < 318.842719019854` evaluates to an array of Boolean values, one per element. The `if` statement therefore becomes something like this:

```python
if np.array([False, True, False, ...]):
    # do stuff
    ...
```

To run the `if`, Python has to reduce that array to a single `True` or `False`, and NumPy refuses: "The truth value of an array with more than one element is ambiguous", because the array can contain both `True` and `False` values.
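A minimal reproduction of the error using only NumPy, with `ss` and the threshold taken from the question (the names `mask` and `error_message` are just for illustration):

```python
import numpy as np

ss = np.linspace(300, 1000, 15)

# Comparing an array with a scalar gives one Boolean per element
mask = ss < 318.842719019854
print(mask.dtype, mask.shape)

# `if mask:` implicitly calls bool(mask), which raises for arrays with more than one element
error_message = None
try:
    bool(mask)
except ValueError as err:
    error_message = str(err)
    print("ValueError:", error_message)
```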

The fix the message suggests, "use a.any() or a.all()", checks whether any or all elements of the array are `True`. That is not what you need here: piecewise functions like `PT2000` and `PT3000` are defined for individual numbers, so each element has to go through its own branch.
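For illustration, this is what the two suggested methods return for the same comparison. Each collapses the array to a single Boolean, so an `if` built on either one sends every element down the same branch:

```python
import numpy as np

ss = np.linspace(300, 1000, 15)
mask = ss < 318.842719019854

# True: at least one element (300.0) is below the threshold
print(mask.any())
# False: most elements are above the threshold
print(mask.all())

# If PT3000 used `if np.all(ss < 318.842719019854):`, the condition would be
# False here, so the second formula would be applied to all of ss,
# including 300.0, which should have used the first formula.
```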

You can use `numpy.vectorize` to apply your functions elementwise:

```python
PT2000_vectorized = np.vectorize(PT2000)
PT3000_vectorized = np.vectorize(PT3000)
plt.plot(ss, PT2000_vectorized(ss))
plt.plot(ss, PT3000_vectorized(ss))
```

`np.vectorize` iterates over `ss` and passes its elements to the function one at a time, so the comparisons only ever involve individual `float`s.
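The NumPy documentation notes that `np.vectorize` is essentially a for loop, provided for convenience rather than performance. As an alternative not covered above, here is a minimal sketch using `np.piecewise`, which evaluates each formula only on the elements whose condition holds and uses the extra, last function as the "otherwise" branch. The names `PT3000_np`, `PT2000_np`, `LOW` and `HIGH` are hypothetical:

```python
import numpy as np

LOW = 318.842719019854   # lower breakpoint from the question
HIGH = 945.33959         # upper breakpoint used by PT2000

def PT3000_np(ss):
    ss = np.asarray(ss, dtype=float)
    # One condition, two functions: the second one is used where the condition is False
    return np.piecewise(
        ss,
        [ss < LOW],
        [lambda s: 4.602 + 37440.0/s,
         lambda s: -0.3 + 3600.0/s],
    )

def PT2000_np(ss):
    ss = np.asarray(ss, dtype=float)
    # Two conditions, three functions: the last one covers LOW <= ss <= HIGH
    return np.piecewise(
        ss,
        [ss < LOW, ss > HIGH],
        [lambda s: 4.602 + 37440.0/s,
         lambda s: -0.3 + 3600.0/s,
         lambda s: 6.87109574235995e-6*s**0.5*(-1 + 96000.0/s) + 62.144],
    )

ss = np.linspace(300, 1000, 15)
print(PT3000_np(ss))
print(PT2000_np(ss))
```

Because these functions accept the whole array, `plt.plot(ss, PT2000_np(ss))` works without `np.vectorize`.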

The complete script, with the scalar functions vectorized and a `label` on each line so that `plt.legend()` has something to show:

```python
import matplotlib.pyplot as plt
import numpy as np

ss = np.linspace(300, 1000, 15)

# Both functions are written for a single number; np.vectorize applies them elementwise
def PT3000(ss):
    if ss < 318.842719019854:
        PT3 = 4.602 + 37440.0/ss
    else:
        PT3 = -0.3 + 3600.0/ss
    return PT3

def PT2000(ss):
    if ss < 318.842719019854:
        PT2 = 4.602 + 37440.0/ss
    elif ss > 945.33959:
        PT2 = -0.3 + 3600.0/ss
    else:
        PT2 = 6.87109574235995e-6*ss**0.5*(-1 + 96000.0/ss) + 62.144
    return PT2

PT2000_vectorized = np.vectorize(PT2000)
PT3000_vectorized = np.vectorize(PT3000)

fig = plt.figure()
plt.plot(ss, PT2000_vectorized(ss), label='PT2000')
plt.plot(ss, PT3000_vectorized(ss), label='PT3000')
plt.title('Productietijd [24x12]')
plt.xlabel('Verstijverafstand [mm]')
plt.ylabel('Productijd van een paneel [uur]')
plt.grid(visible=True)
plt.legend(loc='upper left')
plt.show()
```