Weird behaviour in TensorFlow metric
Question:
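I have created a TensorFlow metric as seen below:
import tensorflow as tf

def AttackAcc(y_true, y_pred):
    # r is a random integer in the range 0..10 (maxval is exclusive)
    r = tf.random.uniform(shape=(), minval=0, maxval=11, dtype=tf.int32)
    if tf.math.greater(r, tf.constant(5)):
        # 0.6 == 0.2 is False, which should count as 0
        return tf.math.equal(tf.constant(0.6), tf.constant(0.2))
    else:
        # 0.6 == 0.6 is True, which should count as 1
        return tf.math.equal(tf.constant(0.6), tf.constant(0.6))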
The metric is passed to model.compile as:
metrics=[AttackAcc]
This should return 0 half of the time and 1 the other half, so while training my model I should see a value of around 0.5 for this metric.
However it is always 0.
Any ideas about why?
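For reference, here is a plain-Python model of the metric's intended logic, with no TensorFlow involved. It assumes, as the tf.random.uniform documentation states for integer dtypes, that maxval is exclusive, so r is drawn from 0..10 and r > 5 holds for 5 of the 11 values. If the metric behaved as intended, the averaged value would therefore be about 6/11 (roughly 0.545) rather than exactly 0.5. The function names attack_acc_once and simulate_attack_acc are hypothetical, and averaging per-batch values is only a rough stand-in for the running mean Keras reports over an epoch.

```python
import random


def attack_acc_once(rng: random.Random) -> float:
    """Mirror one call of AttackAcc: 0.0 when r > 5, otherwise 1.0."""
    # randrange(0, 11) yields integers 0..10, matching
    # tf.random.uniform(minval=0, maxval=11, dtype=tf.int32).
    r = rng.randrange(0, 11)
    # equal(0.6, 0.2) is False -> 0.0; equal(0.6, 0.6) is True -> 1.0
    return 0.0 if r > 5 else 1.0


def simulate_attack_acc(num_batches: int, seed: int = 0) -> float:
    """Average the per-batch values over many simulated batches."""
    rng = random.Random(seed)
    total = sum(attack_acc_once(rng) for _ in range(num_batches))
    return total / num_batches


if __name__ == "__main__":
    print(f"simulated: {simulate_attack_acc(100_000):.3f}")
    print(f"expected:  {6 / 11:.3f}")
```

A value stuck at exactly 0 is therefore not explained by the branch probabilities alone; with this range the intended average sits slightly above 0.5.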
Answers:
It looks like you are only comparing constants, so the result never depends on y_true or y_pred. Try BinaryAccuracy and use your input tensors to update its state.
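import tensorflow as tf

def AttackAcc(y_true, y_pred):
    r = tf.random.uniform(shape=(), minval=0, maxval=11, dtype=tf.int32)
    # Feed this batch's labels and predictions into a binary accuracy metric
    acc_metric = tf.keras.metrics.BinaryAccuracy()
    acc_metric.update_state(y_true, y_pred)
    if tf.math.greater(r, tf.constant(5)):
        return acc_metric.result()
    else:
        # Otherwise report the complement of the accuracy
        return 1 - acc_metric.result()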
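To make the answer's logic concrete, the following plain-Python sketch mirrors what binary accuracy computes with a threshold of 0.5 (a prediction counts as class 1 when it is strictly greater than the threshold, and accuracy is the fraction of matching labels), then applies the same random flip between acc and 1 - acc. The helpers binary_accuracy_py and attack_acc_py are hypothetical stand-ins, not TensorFlow APIs, and the Python lists stand in for the y_true and y_pred tensors.

```python
import random
from typing import Sequence


def binary_accuracy_py(y_true: Sequence[float], y_pred: Sequence[float],
                       threshold: float = 0.5) -> float:
    """Fraction of samples whose thresholded prediction matches the label."""
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("y_true and y_pred must be non-empty and equal length")
    # A prediction counts as class 1 when it is strictly greater than threshold.
    hits = sum(1 for t, p in zip(y_true, y_pred)
               if float(p > threshold) == t)
    return hits / len(y_true)


def attack_acc_py(y_true: Sequence[float], y_pred: Sequence[float],
                  rng: random.Random) -> float:
    """Return acc or 1 - acc depending on a random integer r in 0..10."""
    acc = binary_accuracy_py(y_true, y_pred)
    r = rng.randrange(0, 11)  # same range as the tf.random.uniform call
    return acc if r > 5 else 1 - acc


if __name__ == "__main__":
    labels = [1.0, 0.0, 1.0, 1.0]
    preds = [0.9, 0.2, 0.4, 0.7]  # the third sample is misclassified
    print(binary_accuracy_py(labels, preds))  # 0.75
    print(attack_acc_py(labels, preds, random.Random(0)))  # either 0.75 or 0.25
```

In TensorFlow itself, tf.keras.metrics.binary_accuracy(y_true, y_pred, threshold=0.5) is the stateless function form of the same computation, which avoids constructing a new BinaryAccuracy object on every metric call.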