Is there a model.summary() in Trax?

Question:

I’m working with Trax, a deep learning framework built by the Google Brain team, as an alternative to TensorFlow. As a TensorFlow developer, I’m pretty used to the model.summary() method (documented in the tf.keras Model API reference) for displaying a full model summary, for example:

model.summary()


Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 16, 303)]         0         
_________________________________________________________________
bidirectional (Bidirectional (None, 16, 256)           442368    
_________________________________________________________________
time_distributed (TimeDistri (None, 16, 22)            5654      
=================================================================
Total params: 448,022
Trainable params: 448,022
Non-trainable params: 0

Is there something equivalent in Trax?

Asked By: Emiliano Viotti


Answers:

Currently, Trax does not appear to have a method similar to .summary(); the closest you can get is to print the model. Adapting the example from the documentation:

from trax import layers as tl

model = tl.Serial(
    tl.Embedding(vocab_size=8192, d_feature=256),
    tl.Mean(axis=1),  # Average on axis 1 (length of sentence).
    tl.Dense(2),      # Classify 2 classes.
)

print(model)

Result:

Serial[
  Embedding_8192_256
  Mean
  Dense_2
]

Although nowhere near as detailed as TensorFlow’s model.summary(), the printout still carries useful information. The embedding layer’s constructor arguments (vocabulary size 8192 and feature dimension 256) are encoded in its name, Embedding_8192_256. Likewise, if you change the model’s last layer to tl.Dense(3), the corresponding line of the output changes to Dense_3.
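Because the printout exposes each layer’s configuration, the parameter counts that model.summary() would report can be worked out by hand. The following is a minimal plain-Python sketch, not a Trax API. It assumes the usual formulas: an embedding table holds vocab_size × d_feature weights, a dense layer holds d_in × d_out weights plus d_out biases, and Mean has no weights. The helpers embedding_params, dense_params and summarize are hypothetical names chosen for illustration. The output shapes also assume integer token inputs of shape (batch, length).

```python
# Hypothetical helpers: Trax has no built-in summary, so this only formats
# rows computed by hand from the layer configuration shown by print(model).

def embedding_params(vocab_size: int, d_feature: int) -> int:
    # One d_feature-dimensional vector per vocabulary entry.
    return vocab_size * d_feature


def dense_params(d_in: int, d_out: int, use_bias: bool = True) -> int:
    # Weight matrix (d_in x d_out) plus an optional bias vector of length d_out.
    return d_in * d_out + (d_out if use_bias else 0)


def summarize(rows):
    """Return a Keras-style summary string for (name, output_shape, params) rows."""
    lines = [f"{'Layer':<24}{'Output Shape':<22}{'Param #':>10}", "=" * 56]
    for name, shape, params in rows:
        lines.append(f"{name:<24}{shape:<22}{params:>10,}")
    lines.append("=" * 56)
    lines.append(f"Total params: {sum(p for _, _, p in rows):,}")
    return "\n".join(lines)


# Rows for the model above. None marks an unknown dimension (batch, length).
rows = [
    ("Embedding_8192_256", "(None, None, 256)", embedding_params(8192, 256)),
    ("Mean", "(None, 256)", 0),  # averaging over axis 1 has no weights
    # Dense sees 256 features: the embedding dimension after averaging.
    ("Dense_2", "(None, 2)", dense_params(256, 2)),
]
print(summarize(rows))
```

For this model the total comes to 2,097,666 parameters, almost all of them (2,097,152) in the embedding table.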

Answered By: desertnaut