Inference vs Training causes different outputs


This content is from Stack Overflow. Question asked by Kevin.

I’m using a custom layer that wraps a Dense layer and scales its single output (size 1) to the range between zero and one.

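import tensorflow as tf

class BoundingBoxNumber(tf.keras.layers.Layer):

    def __init__(self, self_input_shape):
        # Keras requires the base initializer to run before sublayers are assigned.
        super().__init__()

        self.input_shape_custom = self_input_shape

        # Total number of features once the input is flattened.
        self.input_shape_accu = 1
        for item in self_input_shape:
            self.input_shape_accu *= item

        self.flatten = tf.keras.layers.Flatten()
        self.internal_dense_layer = tf.keras.layers.Dense(1, activation='tanh')

    def call(self, inputs):
        inputs = self.flatten(inputs)

        # Assumes a batch size of exactly 1.
        inputs.set_shape((1, self.input_shape_accu))

        output = self.internal_dense_layer(inputs)

        # Rescale the tanh output from (-1, 1) to (0, 1).
        output = tf.divide(output, 2)
        output = tf.math.add(output, 0.5)

        return output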

I’m also using a custom training loop:

# `image`, `bb`, `dense_bb_num_layer`, and `call_optimizer` are set up earlier (not shown).
for epoch in range(50):

    print("Epoch:", epoch, "of 50")

    average_loss = 0

    for idx, item in enumerate(image):
        num_bounding_boxes = tf.shape(bb[idx])[0]

        # The target is the reciprocal of the number of bounding boxes.
        float_target = tf.cast(1 / num_bounding_boxes, tf.float32)

        # Only the forward pass and the loss need to be recorded on the tape.
        with tf.GradientTape() as tape:
            logits = dense_bb_num_layer(item)
            loss = NumLoss(logits, float_target)

        print("Logits:", logits, "Target:", float_target, "Loss:", loss)

        average_loss += loss

        call_gradients = tape.gradient(loss, dense_bb_num_layer.trainable_weights)
        call_optimizer.apply_gradients(zip(call_gradients, dense_bb_num_layer.trainable_weights))

    average_loss /= len(image)


In addition, I’m using a custom loss function:

def NumLoss(logits, expected):
    # Squared error between the layer output and the target.
    return (logits - expected) ** 2
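
To make the target and loss arithmetic concrete, here is a small plain-Python sketch. The names `target_for` and `num_loss` are hypothetical stand-ins for the `1/num_bounding_boxes` target and for `NumLoss`, and the prediction values 0.5 and 1.0 are made up for illustration.

```python
def target_for(num_boxes):
    # Reciprocal target, as in the training loop (1 / number of boxes).
    return 1.0 / num_boxes

def num_loss(prediction, expected):
    # Squared error, mirroring NumLoss.
    return (prediction - expected) ** 2

for n in (3, 4, 8):
    target = target_for(n)
    # Compare a mid-range prediction with one stuck at 1.0.
    print(n, target, num_loss(0.5, target), num_loss(1.0, target))
```

For four boxes the target is 0.25, so a prediction of 0.5 gives a loss of 0.0625, while a prediction stuck at 1.0 gives 0.5625.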

When I run the layer inside the training loop, the result is always a float with a value of exactly 0.0 or 1.0. However, when I call the layer on its own, the result is, as expected, a float between zero and one.
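
One possible explanation, which nothing in the question confirms, is tanh saturation: if the Dense pre-activation becomes very large in magnitude (for example with unnormalized 0-255 pixel inputs, or after a large gradient step), tanh rounds to exactly -1.0 or 1.0 in floating point, so the rescaled output is exactly 0.0 or 1.0 and the gradient vanishes. The sketch below illustrates this with plain Python floats. The pixel values, the weight of 0.001, and the helper names are all assumptions made for the demonstration.

```python
import math

def scaled_tanh(z):
    # Same rescaling as the layer: tanh output mapped from (-1, 1) to (0, 1).
    return math.tanh(z) / 2 + 0.5

def scaled_tanh_grad(z):
    # Derivative of scaled_tanh with respect to the pre-activation z.
    return 0.5 * (1.0 - math.tanh(z) ** 2)

def pre_activation(pixels, weight, bias=0.0):
    # A Dense(1) layer with one shared weight value, for illustration only.
    return sum(p * weight for p in pixels) + bias

raw_pixels = [200.0] * 1000                      # hypothetical unnormalized image
normalized_pixels = [p / 255.0 for p in raw_pixels]
weight = 0.001                                   # hypothetical weight value

z_raw = pre_activation(raw_pixels, weight)
z_norm = pre_activation(normalized_pixels, weight)

print("raw:       ", scaled_tanh(z_raw), scaled_tanh_grad(z_raw))
print("normalized:", scaled_tanh(z_norm), scaled_tanh_grad(z_norm))
```

With the raw pixels, the output saturates at exactly 1.0 and its gradient is 0.0, so further updates cannot move it. With the normalized pixels, the output stays strictly between zero and one. If this matches what the training loop shows, scaling the inputs and lowering the learning rate would be reasonable things to try.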

I’m not sure why this would be the case, and any help would be appreciated. I’m trying to train the layer to output values such as 0.333333, 0.25, and 0.125, which is difficult if it only outputs whole numbers.



This question and answer were collected from Stack Overflow and tested by the JTuto community, and are licensed under the terms of CC BY-SA 2.5, CC BY-SA 3.0, CC BY-SA 4.0.
