model.parameters() alternative for TransUNet from the transunet Python library

Question:

I am trying to implement TransUNet for the BreakHis dataset. I was creating the optimizer like this:

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

My model is:

from transunet import TransUNet
model = TransUNet(image_size=224, pretrain=True)

But the parameters() method does not work with TransUNet.

This is the library I am using: https://github.com/awsaf49/TransUNet-tf

I tried named_parameters(), __dict__['parameters'], and state_dict(), but none of them work.

Asked By: Ankit Wankhede

||

Answers:

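The TransUNet model from the transunet library is implemented in TensorFlow, whereas .parameters() is a method of PyTorch's torch.nn.Module, so it does not exist on this model. The TensorFlow (Keras) counterpart is model.trainable_variables.

For TensorFlow we can use something like this: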

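import tensorflow as tf
optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)
# grads must be computed first inside a tf.GradientTape: grads = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))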
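
The snippet above assumes grads already exists. As a rough sketch of where it comes from, here is one full training step with tf.GradientTape. The small Sequential model, the MSE loss, and the random tensors are placeholders I made up so the example runs on its own; they are not part of the transunet library. In practice you would swap in the TransUNet model, a segmentation loss, and real batches, assuming the library returns a regular tf.keras.Model.

```python
import tensorflow as tf

# Hypothetical stand-in for TransUNet so the sketch runs without downloading
# pretrained weights. Any tf.keras.Model exposes trainable_variables the same way.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(8,)),
    tf.keras.layers.Dense(4, activation="relu"),
    tf.keras.layers.Dense(1),
])

optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
loss_fn = tf.keras.losses.MeanSquaredError()  # placeholder loss


def train_step(x, y):
    # Record the forward pass so gradients can be computed afterwards.
    with tf.GradientTape() as tape:
        preds = model(x, training=True)
        loss = loss_fn(y, preds)
    # trainable_variables plays the role of model.parameters() in PyTorch.
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss


# Placeholder data: 16 random samples with 8 features each.
x = tf.random.normal((16, 8))
y = tf.random.normal((16, 1))
loss = train_step(x, y)
print("loss:", float(loss))
```

If a custom loop is not needed, a Keras model can also be trained with model.compile(optimizer=optimizer, loss=...) followed by model.fit(...), which handles the gradient step internally.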
Answered By: Ankit Wankhede