How do I formulate a conditional function

Question:

I am trying to formulate the following functions in Python, and I want to plot them:

import matplotlib.pyplot as plt
import numpy as np

ss = np.linspace(300, 1000, 15)

def PT3000(ss):
    if ss < 318.842719019854:
        PT3 = 4.602 + 37440.0/ss
    else :
        PT3 =-0.3 + 3600.0/ss
    return PT3

def PT2000(ss):
    if ss < 318.842719019854:
        PT2 = 4.602 + 37440.0/ss
    elif ss > 945.33959:
        PT2 =-0.3 + 3600.0/ss
    else:
        PT2 = 6.87109574235995e-6*ss**0.5*(-1 + 96000.0/ss) + 62.144
    return PT2

fig= plt.figure()
plt.plot(ss,PT2000(ss))
plt.plot(ss,PT3000(ss))
plt.title('Productietijd [24x12]')
plt.xlabel('Verstijverafstand [mm]')
plt.ylabel('Productijd van een paneel [uur]')
plt.grid(visible=True)
plt.legend()
plt.show()

I run into the following error, but I don’t understand what to do about it:

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
Asked By: Aileen Obianyor

||

Answers:

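PT2000 and PT3000 are written for a single number, so they should be applied to each element of ss. However, plt.plot(ss, PT2000(ss)) passes the entire array ss at once, and every operation inside the function is then carried out by NumPy on the whole array.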

Comparison of NumPy arrays returns an array, so code like ss < 318.842719019854 results in an array of Boolean values. Thus, the if statement becomes something like this:

if np.array([False, True, False, ...]):
    # do stuff
    ...
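
The behaviour described above can be reproduced in a few lines. This is a minimal sketch using the same `ss` and threshold as the question; the names `mask` and `error_message` are only illustrative.

```python
import numpy as np

ss = np.linspace(300, 1000, 15)

# Comparing an array with a number is done elementwise and yields a boolean array
mask = ss < 318.842719019854
print(mask.dtype, mask.shape)

# Using that array as an `if` condition asks NumPy for a single True/False,
# which it refuses to give for an array with more than one element
error_message = None
try:
    if mask:
        pass
except ValueError as exc:
    error_message = str(exc)

print(error_message)
```

The `try`/`except` is only there to capture the message; in the original code the same exception is raised by the `if` inside `PT3000` and `PT2000`.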

"The truth value of an array with more than one element is ambiguous", because the array can contain both True values and False values.

The solution that’s usually suggested is to "use a.any() or a.all()", which check whether any or all elements of the array are True. This is not what you need here, since piecewise functions like PT2000 and PT3000 act on individual numbers, not entire arrays.
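
To see why `any()` and `all()` do not help here, the sketch below reduces the comparison to a single Boolean and then uses it in the `if`. `PT3000_with_all` and `PT3000_scalar` are hypothetical names introduced only for this illustration.

```python
import numpy as np

ss = np.linspace(300, 1000, 15)
below = ss < 318.842719019854

# Both reductions collapse the whole array into ONE Boolean
print(below.any())  # is at least one element below the threshold?
print(below.all())  # are all elements below the threshold?

def PT3000_with_all(ss):
    # Illustration only: one branch is chosen for the entire array
    if (ss < 318.842719019854).all():
        return 4.602 + 37440.0 / ss
    return -0.3 + 3600.0 / ss

def PT3000_scalar(s):
    # The piecewise definition from the question, for a single number
    if s < 318.842719019854:
        return 4.602 + 37440.0 / s
    return -0.3 + 3600.0 / s

# ss[0] is 300.0, which belongs to the first branch
wrong = PT3000_with_all(ss)[0]   # computed with the second formula
right = PT3000_scalar(ss[0])     # computed with the first formula
print(wrong, right)
```

The code runs without an error, but the value at 300 comes from the wrong formula, because the branch was selected once for the whole array rather than per element.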

You can use numpy.vectorize to apply your functions elementwise:

PT2000_vectorized = np.vectorize(PT2000)
PT3000_vectorized = np.vectorize(PT3000)
plt.plot(ss, PT2000_vectorized(ss))
plt.plot(ss, PT3000_vectorized(ss))

This will iterate over ss and pass its individual elements to the function, so comparisons will simply involve floats.
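
As a rough check of that description, the sketch below compares `np.vectorize` with an explicit Python loop over `ss`. It also shows `np.piecewise`, which is not part of this answer but is a NumPy function that evaluates each formula only on the elements matching its condition. `PT2000_piecewise` is a hypothetical name used for illustration.

```python
import numpy as np

ss = np.linspace(300, 1000, 15)

def PT2000(ss):
    # Scalar piecewise function from the question
    if ss < 318.842719019854:
        PT2 = 4.602 + 37440.0/ss
    elif ss > 945.33959:
        PT2 = -0.3 + 3600.0/ss
    else:
        PT2 = 6.87109574235995e-6*ss**0.5*(-1 + 96000.0/ss) + 62.144
    return PT2

# np.vectorize calls PT2000 once per element
PT2000_vectorized = np.vectorize(PT2000)
vectorized = PT2000_vectorized(ss)

# An explicit loop over the elements gives the same values
looped = np.array([PT2000(s) for s in ss])

def PT2000_piecewise(ss):
    # Array-based alternative (an assumption, not from the answer):
    # two conditions plus a third function used for all remaining elements
    ss = np.asarray(ss, dtype=float)
    return np.piecewise(
        ss,
        [ss < 318.842719019854, ss > 945.33959],
        [
            lambda s: 4.602 + 37440.0/s,
            lambda s: -0.3 + 3600.0/s,
            lambda s: 6.87109574235995e-6*s**0.5*(-1 + 96000.0/s) + 62.144,
        ],
    )

piecewise = PT2000_piecewise(ss)
print(np.allclose(vectorized, looped), np.allclose(vectorized, piecewise))
```

The NumPy documentation describes `np.vectorize` as a convenience whose implementation is essentially a for loop, so for 15 points either approach is fine; an array-based version mainly matters for large inputs.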

Answered By: ForceBru
import matplotlib.pyplot as plt
import numpy as np

ss = np.linspace(300, 1000, 15)

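def PT3000(ss):
    # np.where picks a formula per element; np.all() would reduce the
    # comparison to one Boolean and send the whole array down a single branch
    return np.where(ss < 318.842719019854,
                    4.602 + 37440.0/ss,
                    -0.3 + 3600.0/ss)

def PT2000(ss):
    # Nested np.where: lower range first, then upper range, otherwise the middle formula
    middle = 6.87109574235995e-6*ss**0.5*(-1 + 96000.0/ss) + 62.144
    return np.where(ss < 318.842719019854,
                    4.602 + 37440.0/ss,
                    np.where(ss > 945.33959, -0.3 + 3600.0/ss, middle))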

fig= plt.figure()
plt.plot(ss,PT2000(ss), label='PT2000')
plt.plot(ss,PT3000(ss), label='PT3000')
plt.title('Productietijd [24x12]')
plt.xlabel('Verstijverafstand [mm]')
plt.ylabel('Productijd van een paneel [uur]')
plt.grid(visible=True)
plt.legend(loc='upper left')
plt.show()
Answered By: Bilal Belli