Error getting gradients from a model converted from TF1 to TF2 with tfa.seq2seq layers: “'NoneType' object has no attribute 'outer_context'”

Issue

This question was asked on Stack Overflow by Noé Achache.

After converting a model from TF1 to TF2, I cannot compute gradients through the converted model with tf.GradientTape().

Some useful information:

  • The model uses tfa.seq2seq layers, which are most likely the source of the error (e.g. replacing the body of the network function below with a single tf.compat.v1.layers.Dense layer works fine; see the sketch after this list)
  • TensorFlow 2.9.1
  • Python 3.9
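
For reference, the working control experiment mentioned above would look roughly like this (a minimal sketch, not part of the original code; since a plain Dense output has no rnn_output attribute, the fetches dict in main would then use the returned tensor directly):

def network(input):
    # Hypothetical sanity check: replace the whole seq2seq stack with a
    # single Dense layer. With this body, gradients through the reloaded
    # model reportedly work fine.
    return tf.compat.v1.layers.Dense(10)(input)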

Here are two snippets that reproduce the error:

  • First, run this snippet to create and save the model:
import functools
import tensorflow as tf
import tensorflow_addons as tfa

MODEL_FOLDER = "model"
LSTM_UNITS = 3

def network(input):
    # Bahdanau attention over the (placeholder) input, used as the memory
    attention_mechanism = tfa.seq2seq.BahdanauAttention(
        1, input, memory_sequence_length=None
    )
    cell = tf.compat.v1.nn.rnn_cell.LSTMCell(LSTM_UNITS)
    attention_cell = tfa.seq2seq.AttentionWrapper(
        cell, attention_mechanism, output_attention=False
    )

    embedding_fn = functools.partial(tf.compat.v1.one_hot, depth=10)
    output_layer = tf.compat.v1.layers.Dense(10)

    # GreedyEmbeddingSampler feeds predictions back as the next input
    train_helper = tfa.seq2seq.GreedyEmbeddingSampler(embedding_fn)
    decoder = tfa.seq2seq.BasicDecoder(
        cell=attention_cell, sampler=train_helper, output_layer=output_layer
    )
    init_kwargs = {
        "start_tokens": tf.compat.v1.fill([1], 1),
        "end_token": 2,
        "initial_state": attention_cell.get_initial_state(
            batch_size=1, dtype=tf.compat.v1.float32
        ),
    }
    # dynamic_decode builds the decoding loop as a (v1) while loop
    outputs, _, output_lengths = tfa.seq2seq.dynamic_decode(
        decoder=decoder,
        output_time_major=False,
        impute_finished=False,
        maximum_iterations=1,
        decoder_init_kwargs=init_kwargs,
    )
    return outputs


def main(_):
    input = tf.compat.v1.placeholder(dtype=tf.compat.v1.float32, shape=[1, 1, LSTM_UNITS])

    fetches = {
        "output": network(input).rnn_output,
    }

    with tf.compat.v1.Session() as sess:
        builder = tf.compat.v1.saved_model.Builder(MODEL_FOLDER)

        sess.run(
            [
                tf.compat.v1.global_variables_initializer(),
                tf.compat.v1.local_variables_initializer(),
                tf.compat.v1.tables_initializer(),
            ]
        )

        sig_def = tf.compat.v1.saved_model.predict_signature_def(
            inputs={"input": input}, outputs=fetches
        )

        builder.add_meta_graph_and_variables(
            sess,
            tags=["serve"],
            signature_def_map={
                tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY: sig_def
            },
        )
        builder.save()

if __name__ == "__main__":

    # Save model
    tf.compat.v1.disable_eager_execution()
    tf.compat.v1.disable_v2_behavior()
    tf.compat.v1.app.run()
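
Running this writes a v1-style SavedModel into the model/ directory (a saved_model.pb plus a variables/ folder). Note that, since v2 behavior is disabled at save time, the decoding loop from dynamic_decode is serialized with v1 control-flow ops, which is presumably what the outer_context attribute in the error refers to (it belongs to v1 control-flow contexts).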

  • Then run this snippet to load and use the model:
import tensorflow as tf

MODEL_FOLDER = "model"
LSTM_UNITS = 3

if __name__ == "__main__":

    # Load model and use it
    input = tf.ones((1, 1, LSTM_UNITS))
    model = tf.saved_model.load(
        MODEL_FOLDER, tags="serve"
    ).signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]

    print(model(input))  # this works
    with tf.GradientTape():
        model(input)  # AttributeError: 'NoneType' object has no attribute 'outer_context'
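
For completeness, the behavior would presumably be the same when the tape is actually used to compute a gradient; in a minimal variant (assuming we want gradients with respect to the input tensor), the AttributeError is already raised by the model(input) call once a tape is active:

input = tf.ones((1, 1, LSTM_UNITS))
with tf.GradientTape() as tape:
    tape.watch(input)  # input is a constant tensor, so it must be watched explicitly
    output = model(input)["output"]  # raises the same AttributeError
grads = tape.gradient(output, input)  # never reached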

If anyone could help me, that would be great. Thank you!



Solution

This question has not yet been answered; once a confirmed answer is available, it will be published here as the solution.
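
In the meantime, one direction that is sometimes suggested for gradient errors involving v1 control flow, offered here only as an unverified sketch, is to build the gradient computation inside a tf.function, so that the backward pass is constructed as a graph rather than eagerly:

import tensorflow as tf

MODEL_FOLDER = "model"
LSTM_UNITS = 3

model = tf.saved_model.load(
    MODEL_FOLDER, tags="serve"
).signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]

@tf.function
def compute_gradients(x):
    # Record the forward pass and differentiate the serving signature's
    # "output" tensor with respect to the input.
    with tf.GradientTape() as tape:
        tape.watch(x)
        out = model(x)["output"]
    return tape.gradient(out, x)

grads = compute_gradients(tf.ones((1, 1, LSTM_UNITS)))

Whether this avoids the outer_context error for this particular saved graph has not been confirmed.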
