How to use values from previous Keras layer in convert_to_tensor_fn for TensorFlow Probability DistributionLambda

Question:

I have a Keras/TensorFlow Probability model where I would like to include values from the previous layer in the convert_to_tensor_fn parameter of the following DistributionLambda layer. Ideally, I'd be able to do something like this:

from functools import partial
import tensorflow as tf
from tensorflow.keras import layers, Model
import tensorflow_probability as tfp
from typing import Union
tfd = tfp.distributions

zero_buffer = 1e-5


def quantile(s: tfd.Distribution, q: Union[tf.Tensor, float]) -> Union[tf.Tensor, float]:
    return s.quantile(q)


# 4 records (1st value represents CDF value, 
#            2nd represents location, 
#            3rd represents scale)
sample_input = tf.constant([[0.25, 0.0, 1.0], 
                            [0.5, 1.0, 0.5], 
                            [0.75, -1.0, 2.0], 
                            [0.95, 3.0, 2.5]], dtype=tf.float32)

# Build toy model for demonstration
input_layer = layers.Input(3)
dist = tfp.layers.DistributionLambda(
    make_distribution_fn=lambda t: tfd.Normal(loc=t[..., 1],
                                              scale=zero_buffer + tf.nn.softplus(t[..., 2])),
    convert_to_tensor_fn=lambda t, s: partial(quantile, q=t[..., 0])(s)
)(input_layer)
model = Model(input_layer, dist)

However, according to the documentation, convert_to_tensor_fn must be a callable that takes only a tfd.Distribution instance as input, so the convert_to_tensor_fn=lambda t, s: ... approach above doesn't work.
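
For reference, a version that does satisfy the documented signature looks something like this (a minimal sketch; the callable receives only the distribution instance d, so there is no way to reach the input tensor t from inside it):

# Works, but the callable only ever sees the finished distribution
dist = tfp.layers.DistributionLambda(
    make_distribution_fn=lambda t: tfd.Normal(loc=t[..., 1],
                                              scale=zero_buffer + tf.nn.softplus(t[..., 2])),
    convert_to_tensor_fn=lambda d: d.mean()  # no access to t here
)(input_layer)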

How can I access data from the previous layer inside convert_to_tensor_fn? I'm assuming there's a clever way to create a partial function, or something similar, to get this to work.

Outside of the Keras model framework, this is fairly easy to do using code similar to the example below:

# input data in Tensor Constant form
cdf_data = tf.constant([0.25, 0.5, 0.75, 0.95], dtype=tf.float32)
norm_mu = tf.constant([0.0, 1.0, -1.0, 3.0], dtype=tf.float32)
norm_scale = tf.constant([1.0, 0.5, 2.0, 2.5], dtype=tf.float32)

quant = partial(quantile, q=cdf_data)
norm = tfd.Normal(loc=norm_mu, scale=norm_scale)
quant(norm)

Output:

<tf.Tensor: shape=(4,), dtype=float32, numpy=array([-0.6744898,  1.       ,  0.3489796,  7.112134 ], dtype=float32)>
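
(As a sanity check, the first value is the 25th percentile of a standard normal distribution, roughly -0.6745.)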
Asked By: Jed


Answers:

A quantile function is useful for summarizing results; it affects how the data is presented and interpreted. For example, model outputs can be precise after training yet still vary across instruments or students, and quantiles are a convenient way to characterize that spread.

Example: quantiles and normal distributions

Create a DistributionLambda layer inside the model, then take quantiles of its results for the presentation layer.

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

from typing import Union

"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
: Functions
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""   
def normal_sp(params):
    return tfd.Normal(loc=params,
                      scale=1e-5 + 0.00001*tf.keras.backend.exp(params))# both parameters are learnable
                      
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
: Model
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""   
# Build the Dense layer's weights by calling it once on sample data
layer_0 = tf.keras.layers.Dense(32, activation='relu')
result_0 = layer_0(tf.constant([0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.], shape=(1, 11)))

# Wrap normal_sp in a DistributionLambda layer
layer_1 = tfp.layers.DistributionLambda(normal_sp)

# Input data; quantiles of the model's predictions are computed below
x = tf.constant([0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.], shape=(1, 11))
                      
model = tf.keras.Sequential([
    tf.keras.Input(shape=(11,)),
    layer_0,
    layer_1,
])

model.summary()

result = model.predict(x)
print( tfp.stats.quantiles(result, num_quantiles=4, interpolation='nearest') )

Output:

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 dense (Dense)               (None, 32)                384

 distribution_lambda (Distri  ((None, 32),             0
 butionLambda)                (None, 32))

=================================================================
Total params: 384
Trainable params: 384
Non-trainable params: 0
_________________________________________________________________
1/1 [==============================] - 0s 110ms/step
tf.Tensor(
[-3.3217777e-05  1.1406353e-05  5.9894159e-05  3.1857936e+00
  9.6232319e+00], shape=(5,), dtype=float32)
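
Note that tfp.stats.quantiles with num_quantiles=4 returns the five cut points at the 0th, 25th, 50th, 75th, and 100th percentiles of the sampled outputs (i.e., the minimum, the quartiles, and the maximum), which is why the result has shape (5,).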
Answered By: Jirayu Kaewprateep

I found a solution to this problem on my own, and decided to post it here.

You can create a wrapper class around tfd.Normal that takes the CDF values as an extra constructor argument, and then override a couple of methods to do what you want. In particular, you need to override the _sample_n method so that it evaluates the quantile function at those CDF values instead of drawing randomly from the distribution. The class looks something like this:

import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.internal import dtype_util, tensor_util, reparameterization, samplers
from tensorflow_probability.python.internal import prefer_static as ps
tfd = tfp.distributions


class NormalWrapper(tfp.distributions.Normal):
    def __init__(self,
                 loc,
                 scale,
                 cdf_vals,
                 validate_args=False,
                 allow_nan_stats=True,
                 name='NormalCDF'):
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            dtype = dtype_util.common_dtype([loc, scale], dtype_hint=tf.float32)
            self._cdf_vals = tensor_util.convert_nonref_to_tensor(
                cdf_vals, dtype=dtype, name='cdf_vals')
        super(NormalWrapper, self).__init__(loc=loc,
                                            scale=scale,
                                            validate_args=validate_args,
                                            allow_nan_stats=allow_nan_stats,
                                            name=name)
        self._parameters = parameters

    def _parameter_properties(self, dtype=tf.float32, num_classes=None):
        return dict(
            loc=tfp.util.ParameterProperties(),
            scale=tfp.util.ParameterProperties(
                default_constraining_bijector_fn=(
                    # A Softplus bijector keeps the scale strictly positive
                    lambda: tfp.bijectors.Softplus(low=dtype_util.eps(dtype)))),
            cdf_vals=tfp.util.ParameterProperties(),
        )

    @property
    def cdf_vals(self):
        return self._cdf_vals

    def _sample_n(self, n, seed=None):
        loc = tf.convert_to_tensor(self.loc)
        scale = tf.convert_to_tensor(self.scale)
        cdf_vals = tf.convert_to_tensor(self.cdf_vals)
        shape = ps.concat([[n], self._batch_shape_tensor(loc=loc, scale=scale, cdf_vals=cdf_vals)], axis=0)
        # Deterministic "sample": return the quantiles at cdf_vals instead of a
        # random draw. The reshape assumes n == 1, which is what sample() and
        # DistributionLambda's default convert_to_tensor_fn use.
        return tf.reshape(self.quantile(cdf_vals), shape=shape)

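A quick standalone sanity check, mirroring the non-Keras example from the question (a sketch; sample() is now deterministic and returns the quantiles at the stored CDF values):

# Same data as the non-Keras example above
cdf_data = tf.constant([0.25, 0.5, 0.75, 0.95], dtype=tf.float32)
norm_mu = tf.constant([0.0, 1.0, -1.0, 3.0], dtype=tf.float32)
norm_scale = tf.constant([1.0, 0.5, 2.0, 2.5], dtype=tf.float32)

norm = NormalWrapper(loc=norm_mu, scale=norm_scale, cdf_vals=cdf_data)
print(norm.sample())  # [-0.6744898, 1.0, 0.3489796, 7.112134], same as before
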
Once you have that class, you can create your DistributionLambda layer like this:

dist = tfp.layers.DistributionLambda(
    make_distribution_fn=lambda t: NormalWrapper(loc=t[..., 1],
                                                 scale=zero_buffer + tf.nn.softplus(t[..., 2]),
                                                 cdf_vals=t[..., 0]),
)(input_layer)
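
To verify the wrapper inside Keras, build the model and run the question's sample_input through it; converting the output to a tensor calls the default convert_to_tensor_fn (Distribution.sample), which the _sample_n override makes deterministic (a sketch; the values differ from the non-Keras example because this model passes the third input column through a softplus):

model = Model(input_layer, dist)

# Returns the quantile at each row's CDF value rather than a random draw
print(tf.convert_to_tensor(model(sample_input)))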
Answered By: Jed