Tensorflow_io: ValueError: Cannot infer argument `num` from shape (None, None, None)

Question:

I am trying to read and decode TIFF images in TensorFlow. I am using the tensorflow_io package as follows, but I am getting an error that I can't figure out.

import tensorflow as tf
import tensorflow_io as tfio
import os

def process_image(image):  # image: scalar string tensor holding one file path
  # Read one TIFF file and convert the decoded image from RGBA to RGB.
  image = tf.io.read_file(image)
  image = tfio.experimental.image.decode_tiff(image)
  image = tfio.experimental.color.rgba_to_rgb(image)  # raises the ValueError below when traced by Dataset.map
  return image

# Glob pattern for .TIF files in the current directory; list_files shuffles the matches by default.
path = os.path.join(os.curdir, '*.TIF')
files = tf.data.Dataset.list_files(path)

Output:

# Print the first five matched paths (each one is a scalar string tensor).
for file in files.take(5):
  print(file)

tf.Tensor(b'./SIMCEPImages_A01_C1_F1_s10_w1.TIF', shape=(), dtype=string)
tf.Tensor(b'./SIMCEPImages_A01_C1_F1_s04_w1.TIF', shape=(), dtype=string)
tf.Tensor(b'./SIMCEPImages_A01_C1_F1_s12_w1.TIF', shape=(), dtype=string)
tf.Tensor(b'./SIMCEPImages_A01_C1_F1_s04_w2.TIF', shape=(), dtype=string)
tf.Tensor(b'./SIMCEPImages_A01_C1_F1_s11_w1.TIF', shape=(), dtype=string)

Now if I call:

# The ValueError below is raised while Dataset.map traces process_image into a
# graph; inside that trace decode_tiff's output has static shape (None, None, None).
dataset = files.map(process_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)

for img in dataset.take(5):
  print(img.shape)

ValueError: in user code:

    File "<ipython-input-4-1d2deab36c6d>", line 5, in process_image  *
        image = tfio.experimental.color.rgba_to_rgb(image)
    File "/usr/local/lib/python3.7/dist-packages/tensorflow_io/python/experimental/color_ops.py", line 80, in rgba_to_rgb  *
        rgba = tf.unstack(input, axis=-1)

    ValueError: Cannot infer argument `num` from shape (None, None, None)
Asked By: Gopal Bhattrai

||

Answers:

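The problem is that tfio.experimental.color.rgba_to_rgb uses tf.unstack under the hood, and tf.unstack must know the size of the unstacked axis when the graph is built. Inside Dataset.map your function runs in graph mode, where the output of decode_tiff has the fully unknown static shape (None, None, None), so the number of channels cannot be inferred. One solution is to index the channels you want manually, as the source code of rgba_to_rgb does. Here is a working example: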

import numpy as np
from PIL import Image
import tensorflow as tf
import tensorflow_io as tfio
import os

# Create dummy data: one random 10x10 single-channel uint8 image, saved twice as TIFF.
# randint's upper bound is exclusive, so 256 is needed to cover the full uint8 range 0..255.
data = np.random.randint(0, 256, (10, 10), dtype=np.uint8)
im = Image.fromarray(data)
im.save('image1.tif')
im.save('image2.tif')

def process_image(image):
  """Read the TIFF file at path `image` and return channels 0, 1 and 2 stacked.

  The channels are sliced out explicitly instead of calling
  tfio.experimental.color.rgba_to_rgb, whose tf.unstack needs a statically
  known channel count that is not available while Dataset.map traces this.
  """
  raw_bytes = tf.io.read_file(image)
  decoded = tfio.experimental.image.decode_tiff(raw_bytes)
  # Keep the first three channels (R, G, B); the alpha channel is dropped.
  channels = [decoded[:, :, c] for c in range(3)]
  return tf.stack(channels, axis=-1)

path = os.path.join(os.curdir, '*.tif')
files = tf.data.Dataset.list_files(path)

# The dummy-data step wrote two files, so take(5) yields just those two paths.
for file in files.take(5):
  print(file)

# Explicit channel indexing traces fine inside Dataset.map; each element comes
# out with shape (10, 10, 3) as shown in the output below.
dataset = files.map(process_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)
for img in dataset.take(5):
  print(img.shape)
tf.Tensor(b'./image2.tif', shape=(), dtype=string)
tf.Tensor(b'./image1.tif', shape=(), dtype=string)
(10, 10, 3)
(10, 10, 3)

If you really want to use tfio.experimental.color.rgba_to_rgb, it will have to run outside graph mode, for example via tf.py_function.

Answered By: AloneTogether

I changed the code a bit, because the argument is not updated the way you expected; this version is easier to understand (note the explicit `index=0` argument passed to `decode_tiff`).

[ Sample ]:

import os
from os.path import exists

import tensorflow as tf
import tensorflow_io as tfio

import matplotlib.pyplot as plt

"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
None
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
physical_devices = tf.config.experimental.list_physical_devices('GPU')
assert len(physical_devices) > 0, "Not enough GPU hardware devices available"
config = tf.config.experimental.set_memory_growth(physical_devices[0], True)
print(physical_devices)
print(config)

"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
Variables
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
PATH = os.path.join('F:\datasets\downloads\Actors\train\Pikaploy', '*.tif')
PATH_2 = os.path.join('F:\datasets\downloads\Actors\train\Candidt Kibt', '*.tif')
files = tf.data.Dataset.list_files(PATH)
files_2 = tf.data.Dataset.list_files(PATH_2)

list_file = []
list_file_actual = []
list_label = []
list_label_actual = [ 'Pikaploy', 'Pikaploy', 'Pikaploy', 'Pikaploy', 'Pikaploy', 'Candidt Kibt', 'Candidt Kibt', 'Candidt Kibt', 'Candidt Kibt', 'Candidt Kibt' ]
for file in files.take(5):
    image = tf.io.read_file( file )
    image = tfio.experimental.image.decode_tiff(image, index=0)
    list_file_actual.append(image)
    image = tf.image.resize(image, [32,32], method='nearest')
    list_file.append(image)
    list_label.append(1)
    
for file in files_2.take(5):
    image = tf.io.read_file( file )
    image = tfio.experimental.image.decode_tiff(image, index=0)
    list_file_actual.append(image)
    image = tf.image.resize(image, [32,32], method='nearest')
    list_file.append(image)
    list_label.append(9)

# Build the checkpoint locations with os.path.join instead of concatenation:
# a literal such as "F:\models\checkpoint\" ends in a backslash that escapes
# its own closing quote, which is a SyntaxError.
script_name = os.path.basename(__file__).split('.')[0]
checkpoint_dir = os.path.join(r'F:\models\checkpoint', script_name)
checkpoint_path = os.path.join(checkpoint_dir, 'TF_DataSets_01.h5')
loggings = os.path.join(checkpoint_dir, 'loggings.log')

if not exists(checkpoint_dir):
    # makedirs also creates missing parent folders (os.mkdir would fail there).
    os.makedirs(checkpoint_dir)
    print("Create directory: " + checkpoint_dir)

log_dir = checkpoint_dir

"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
DataSet
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
dataset = tf.data.Dataset.from_tensor_slices((tf.constant(tf.cast(list_file, dtype=tf.int64), shape=(10, 1, 32, 32, 4), dtype=tf.int64),tf.constant(list_label, shape=(10, 1, 1), dtype=tf.int64)))

"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
: Model Initialize
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
model = tf.keras.models.Sequential([
    tf.keras.layers.InputLayer(input_shape=( 32, 32, 4 )),
    tf.keras.layers.Normalization(mean=3., variance=2.),
    tf.keras.layers.Normalization(mean=4., variance=6.),
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Reshape((128, 225)),
    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(96, return_sequences=True, return_state=False)),
    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(96)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(192, activation='relu'),
    tf.keras.layers.Dense(10),
])

"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
: Optimizer
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
optimizer = tf.keras.optimizers.Nadam(
    learning_rate=0.00001, beta_1=0.9, beta_2=0.999, epsilon=1e-07,
    name='Nadam'
)

"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
: Loss Fn
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""                               
lossfn = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=False,
    reduction=tf.keras.losses.Reduction.AUTO,
    name='sparse_categorical_crossentropy'
)

"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
: Model Summary
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
model.compile(optimizer=optimizer, loss=lossfn, metrics=['accuracy'])

"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
: FileWriter
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
if exists(checkpoint_path) :
    model.load_weights(checkpoint_path)
    print("model load: " + checkpoint_path)
    input("Press Any Key!")

"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
: Training
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
history = model.fit( dataset, batch_size=100, epochs=50 )
model.save_weights(checkpoint_path)

plt.figure(figsize=(5, 2))
# suptitle titles the whole figure; plt.title would instead create a full-figure
# axes (via gca) sitting underneath the 5x2 subplot grid below.
plt.suptitle("Actors recognitions")
for i, (image, original) in enumerate(zip(list_file, list_file_actual)):
    # Feed the model the same raw pixel values it was trained on. The former
    # array_to_img(scale=True) round trip min-max stretched every image to
    # 0..255, an input distribution the training data never had.
    batch = tf.expand_dims(tf.cast(image, tf.float32), 0)
    predictions = model.predict(batch)
    score = tf.nn.softmax(predictions[0])  # model outputs logits
    predicted = int(tf.math.argmax(score))
    plt.subplot(5, 2, i + 1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(original)
    # NOTE(review): this indexes the per-image name list by predicted class id;
    # it only reads correctly because class 1 falls in the 'Pikaploy' half and
    # class 9 in the 'Candidt Kibt' half of list_label_actual — confirm intent.
    plt.xlabel(str(round(score[predicted].numpy(), 2)) + ":" + str(list_label_actual[predicted]))

plt.show()

input('...')

[ Output ]:
Sample

Answered By: Jirayu Kaewprateep