Pass a custom scaling operation in Python

Question:

I am following an example from https://github.com/google/lightweight_mmm, but instead of the example's default setting for the scalers, which divides by the mean:

media_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)

I need to use this lambda function instead:

lambda x: jnp.mean(x[x > 0])
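To see what this lambda changes, here is a small sketch comparing it with the plain mean on a column that contains zeros (for example, periods with no media spend). It uses NumPy as a stand-in for jax.numpy, assuming np.mean and boolean-mask indexing behave the same way for this simple case.

```python
import numpy as np

# Default behaviour in the example: divide by the mean of all values, zeros included.
default_op = np.mean

# Requested behaviour: divide by the mean of the strictly positive values only.
positive_mean = lambda x: np.mean(x[x > 0])

column = np.array([0.0, 2.0, 0.0, 4.0])

print(default_op(column))     # 1.5  (6 / 4)
print(positive_mean(column))  # 3.0  (6 / 2)
```

Note that a column with no positive values leaves an empty selection, so the mean is NaN (NumPy also emits a RuntimeWarning).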

How can this be done? I have tried a couple of things, but since I am a complete beginner, I feel lost.

So far, I have tried:

lambda x: jnp.mean(x[x > 0])
media_scaler = preprocessing.CustomScaler(divide_operation=x)

and

lambda x: jnp.mean(x[x > 0])
media_scaler = preprocessing.CustomScaler(divide_operation=lambda)

Neither of these works.

Asked By: Nneka


Answers:

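In your first attempt, the lambda is created and immediately discarded, so x is undefined when you write divide_operation=x. In the second, lambda on its own is an incomplete expression, which is a SyntaxError. Assign the lambda to a name, then pass that name as divide_operation (passing the lambda inline, divide_operation=lambda x: jnp.mean(x[x > 0]), works too):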

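import jax.numpy as jnp
from lightweight_mmm import preprocessing
div = lambda x: jnp.mean(x[x > 0])  # mean of the positive values only
media_scaler = preprocessing.CustomScaler(divide_operation=div)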
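To show that any one-argument callable can be passed this way, here is a self-contained sketch. ColumnScaler is a hypothetical, simplified stand-in written only for illustration; it is not lightweight_mmm's CustomScaler, and it assumes the divide operation is applied to each column separately. NumPy again stands in for jax.numpy. The sketch also includes a mask-based variant: if the function is ever traced by jax.jit or jax.vmap, indexing with x[x > 0] fails because the result's shape depends on the data, whereas a masked sum divided by a count keeps all shapes static.

```python
import numpy as np


class ColumnScaler:
    """Hypothetical stand-in: divides each column by divide_operation(column)."""

    def __init__(self, divide_operation):
        # Any callable that takes a 1-D array and returns a scalar works here.
        self.divide_operation = divide_operation
        self.divide_by = None

    def fit(self, data):
        # Apply the operation to every column (axis=0) independently.
        self.divide_by = np.apply_along_axis(self.divide_operation, 0, data)
        return self

    def transform(self, data):
        return data / self.divide_by


# Option 1: name the lambda, then pass the name (as in the answer).
div = lambda x: np.mean(x[x > 0])


# Option 2: an equivalent named function defined with def.
def mean_of_positives(x):
    return np.mean(x[x > 0])


# Option 3: mask-based version with no data-dependent shapes.
# Assumes each column has at least one positive value (otherwise 0 / 0).
def mean_of_positives_masked(x):
    return np.sum(np.where(x > 0, x, 0.0)) / np.sum(x > 0)


media = np.array([[0.0, 10.0],
                  [2.0, 0.0],
                  [4.0, 30.0]])

for op in (div, mean_of_positives, mean_of_positives_masked):
    scaler = ColumnScaler(divide_operation=op).fit(media)
    print(scaler.divide_by)  # [ 3. 20.]
```

All three options produce the same divisors here; the choice between a named lambda and def is purely stylistic, while the masked form matters only if the function must run inside traced JAX code.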
Answered By: ACarter