This content is from Stack Overflow. Question asked by Kevin.
I’m using a custom Dense layer that scales its output (size 1) to lie between zero and one:
    import tensorflow as tf

    class BoundingBoxNumber(tf.keras.layers.Layer):
        def __init__(self, self_input_shape):
            super(BoundingBoxNumber, self).__init__()
            self.input_shape_custom = self_input_shape
            # Total number of elements in one input (product of all dimensions)
            self.input_shape_accu = 1
            for item in self_input_shape:
                self.input_shape_accu *= item
            self.internal_dense_layer = tf.keras.layers.Dense(1, activation='tanh')

        @tf.function
        def call(self, inputs):
            inputs = tf.keras.layers.Flatten()(inputs)
            inputs.set_shape((1, self.input_shape_accu))
            output = self.internal_dense_layer(inputs)
            # Map tanh's (-1, 1) range onto (0, 1)
            output = tf.divide(output, 2)
            output = tf.math.add(output, 0.5)
            return output
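To make the layer's output scaling concrete, here is a minimal pure-Python sketch of the same arithmetic (tanh, divide by 2, add 0.5) applied to a single pre-activation value. The helper names `scaled_tanh` and `sigmoid` are made up for this illustration and are not part of the original code. It also shows that this mapping is numerically the same as a logistic sigmoid evaluated at 2z.

```python
import math

def scaled_tanh(z: float) -> float:
    """Mirror the layer's post-processing: tanh(z) / 2 + 0.5."""
    return math.tanh(z) / 2 + 0.5

def sigmoid(x: float) -> float:
    """Standard logistic function."""
    return 1.0 / (1.0 + math.exp(-x))

# Compare the two mappings for a few moderate pre-activation values.
for z in (-2.0, -0.5, 0.0, 0.5, 2.0):
    print(f"z={z:+.1f}  scaled_tanh={scaled_tanh(z):.6f}  sigmoid(2z)={sigmoid(2 * z):.6f}")
```

If this equivalence is acceptable for the use case, a Dense layer with a sigmoid activation would express the same range constraint in one step, although that is a design choice rather than something the original post uses.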
I’m also using a custom training loop:
    for epoch in range(50):
        print("Epoch:", epoch, "of 50")
        average_loss = 0
        for iter, item in enumerate(image):
            num_bounding_boxes = tf.shape(bb[iter])
            float_target = tf.cast(1/num_bounding_boxes, tf.float32)
            with tf.GradientTape() as tape:
                logits = dense_bb_num_layer(item)
                loss = NumLoss(logits, float_target)
            print("Logits:", logits, "Target:", float_target, "Loss:", loss)
            average_loss += loss
            call_gradients = tape.gradient(loss, dense_bb_num_layer.trainable_weights)
            call_optimizer.apply_gradients(zip(call_gradients, dense_bb_num_layer.trainable_weights))
        average_loss /= len(image)
        print(average_loss)
In addition, I’m using a custom loss function:
    def NumLoss(logits, expected):
        return (logits - expected)**2
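As a rough illustration of how this squared-error loss interacts with the tanh-based output, the sketch below computes the loss and its derivative with respect to the Dense layer's pre-activation for a single scalar. It is plain Python rather than TensorFlow, and the function names and the target value 0.25 are assumptions chosen for the example.

```python
import math

def num_loss(output: float, expected: float) -> float:
    """Scalar version of NumLoss: squared error."""
    return (output - expected) ** 2

def loss_grad_wrt_preactivation(z: float, expected: float) -> float:
    """Derivative of (tanh(z)/2 + 0.5 - expected)**2 with respect to z."""
    t = math.tanh(z)
    output = t / 2 + 0.5
    # Chain rule: dL/d(output) = 2 * (output - expected),
    #             d(output)/dz = (1 - t**2) / 2
    return 2 * (output - expected) * (1 - t * t) / 2

target = 0.25  # hypothetical target, i.e. 1 / 4 bounding boxes
for z in (0.0, 2.0, 5.0, 10.0):
    out = math.tanh(z) / 2 + 0.5
    print(f"z={z:5.1f}  loss={num_loss(out, target):.6f}  "
          f"dL/dz={loss_grad_wrt_preactivation(z, target):.3e}")
```

The derivative shrinks quickly as the magnitude of z grows, so once the pre-activation is large the loss can stay high while the weight updates become very small.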
When I run the layer inside the training loop, the result is always a float with a value of exactly 0.0 or 1.0. However, when I call the layer on its own, the result is, as expected, a float between zero and one.
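One possible explanation, which the post itself does not confirm, is tanh saturation. If the Dense layer's pre-activation becomes large in magnitude (for instance through un-normalized inputs or a large weight update, both of which are assumptions here because the inputs and optimizer settings are not shown), tanh rounds to exactly -1.0 or 1.0 in floating point. The scaled output is then exactly 0.0 or 1.0. The sketch below reproduces that effect in plain Python with a hypothetical `layer_output` helper:

```python
import math

def layer_output(pre_activation: float) -> float:
    """Same scaling as the custom layer: tanh(z) / 2 + 0.5."""
    return math.tanh(pre_activation) / 2 + 0.5

# Hypothetical pre-activation values of increasing magnitude.
for z in (0.3, 3.0, 30.0, 300.0, -300.0):
    print(f"z={z:7.1f}  output={layer_output(z)!r}")
```

Python floats are 64-bit. TensorFlow's default float32 has less precision, so it would reach exactly 0.0 or 1.0 at smaller magnitudes. Printing the Dense layer's pre-activation, or the range of the input values, inside the training loop would be one way to check whether this is what is happening.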
I’m not sure why this would be the case, and any help would be appreciated. I’m trying to train the layer to output values such as 0.333333, 0.25, 0.125, etc., so it’s difficult if it only outputs whole numbers.
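The target encoding described here, where the layer should output the reciprocal of the bounding-box count, can be sketched in plain Python as follows. The helpers `target_for` and `boxes_from` are hypothetical names introduced for this illustration and do not appear in the original code.

```python
def target_for(num_boxes: int) -> float:
    """Encode a bounding-box count as the target 1 / count."""
    if num_boxes < 1:
        raise ValueError("num_boxes must be at least 1")
    return 1.0 / num_boxes

def boxes_from(output: float) -> int:
    """Decode a layer output back into a count by rounding 1 / output."""
    return round(1.0 / output)

for n in (3, 4, 8):
    t = target_for(n)
    print(f"boxes={n}  target={t:.6f}  decoded={boxes_from(t)}")
```

Under this encoding an output of exactly 1.0 decodes to a single box, and an output of exactly 0.0 cannot be decoded at all, which is consistent with the difficulty described above.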
This question and answer are collected from Stack Overflow and tested by the JTuto community, and are licensed under the terms of CC BY-SA 2.5, CC BY-SA 3.0, and CC BY-SA 4.0.