Issue
This Content is from Stack Overflow. Question asked by Noé Achache
After converting a model from tf1 to tf2, I cannot use the tf2 model with tf.GradientTape().
Some useful information:
- The model uses tfa.seq2seq layers, which are most likely the source of the error (e.g. replacing the body of the `network` function below with a single tf.compat.v1.layers.Dense works fine)
- tensorflow 2.9.1
- Python 3.9
Here are 2 snippets to reproduce the error:
- First run this snippet to create/save the model
import functools
import tensorflow as tf
import tensorflow_addons as tfa
# Directory the SavedModel builder writes the exported model to.
MODEL_FOLDER = "model"
# Number of LSTM units; also used as the last dimension of the input placeholder.
LSTM_UNITS = 3
def network(input):
    """Build a single-step greedy attention decoder that attends over ``input``.

    The graph is assembled from a Bahdanau attention mechanism (1 unit) whose
    memory is ``input``, an ``LSTM_UNITS``-wide LSTM cell wrapped in an
    ``AttentionWrapper``, a 10-unit dense output layer, and a greedy sampler
    that one-hot encodes token ids with depth 10. Decoding starts from token 1,
    treats token 2 as the end token, uses a batch size of 1 and is capped at
    one iteration.

    Args:
        input: Tensor used as the attention memory. The name shadows the
            ``input`` builtin but is kept so existing callers keep working.

    Returns:
        The first element of the ``tfa.seq2seq.dynamic_decode`` result (the
        decoder outputs; the caller reads its ``rnn_output`` attribute).
    """
    attention_mechanism = tfa.seq2seq.BahdanauAttention(
        1, input, memory_sequence_length=None
    )
    cell = tf.compat.v1.nn.rnn_cell.LSTMCell(LSTM_UNITS)
    attention_cell = tfa.seq2seq.AttentionWrapper(
        cell, attention_mechanism, output_attention=False
    )

    embedding_fn = functools.partial(tf.compat.v1.one_hot, depth=10)
    output_layer = tf.compat.v1.layers.Dense(10)
    train_helper = tfa.seq2seq.GreedyEmbeddingSampler(embedding_fn)
    decoder = tfa.seq2seq.BasicDecoder(
        cell=attention_cell, sampler=train_helper, output_layer=output_layer
    )

    # Keyword arguments forwarded to the decoder's initialisation by
    # dynamic_decode; entries are created in the same order as before.
    init_kwargs = {
        "start_tokens": tf.compat.v1.fill([1], 1),
        "end_token": 2,
        "initial_state": attention_cell.get_initial_state(
            batch_size=1, dtype=tf.compat.v1.float32
        ),
    }

    # Only the decoder outputs are needed; the final state and the output
    # lengths returned alongside them are discarded.
    outputs, _, _ = tfa.seq2seq.dynamic_decode(
        decoder=decoder,
        output_time_major=False,
        impute_finished=False,
        maximum_iterations=1,
        decoder_init_kwargs=init_kwargs,
    )
    return outputs
def main(_):
    """Build the decoder graph and export it as a SavedModel in ``MODEL_FOLDER``.

    A float32 placeholder of shape ``[1, 1, LSTM_UNITS]`` is fed through
    ``network``; its ``rnn_output`` is exposed as the ``"output"`` of a predict
    signature registered under the default serving signature key with the
    ``"serve"`` tag.

    Args:
        _: Unused positional argument supplied by ``tf.compat.v1.app.run``.
    """
    # Named to avoid shadowing the ``input`` builtin; the signature key below
    # is still the string "input".
    input_placeholder = tf.compat.v1.placeholder(
        dtype=tf.compat.v1.float32, shape=[1, 1, LSTM_UNITS]
    )
    fetches = {
        "output": network(input_placeholder).rnn_output,
    }

    with tf.compat.v1.Session() as sess:
        builder = tf.compat.v1.saved_model.Builder(MODEL_FOLDER)
        # Variables must be initialised in this session before they are saved.
        sess.run(
            [
                tf.compat.v1.global_variables_initializer(),
                tf.compat.v1.local_variables_initializer(),
                tf.compat.v1.tables_initializer(),
            ]
        )
        sig_def = tf.compat.v1.saved_model.predict_signature_def(
            inputs={"input": input_placeholder}, outputs=fetches
        )
        builder.add_meta_graph_and_variables(
            sess,
            tags=["serve"],
            signature_def_map={
                tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY: sig_def
            },
        )
        builder.save()
if __name__ == "__main__":
    # Save model
    # The graph is built with placeholders and a Session, so eager execution
    # and v2 behaviour are switched off before app.run() invokes main().
    tf.compat.v1.disable_eager_execution()
    tf.compat.v1.disable_v2_behavior()
    tf.compat.v1.app.run()
- Then run this snippet to load/use the model
import tensorflow as tf
# Directory the SavedModel is loaded from.
MODEL_FOLDER = "model"
# Last dimension of the all-ones tensor fed to the loaded model.
LSTM_UNITS = 3
if __name__ == "__main__":
    # Load model and use it
    # Named to avoid shadowing the ``input`` builtin.
    model_input = tf.ones((1, 1, LSTM_UNITS))
    # Fetch the default serving signature of the exported model as a callable.
    model = tf.saved_model.load(
        MODEL_FOLDER, tags="serve"
    ).signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
    print(model(model_input))  # this works
    # Reproduction of the reported failure: the same call raises once it is
    # made while a GradientTape is active.
    with tf.GradientTape():
        model(model_input)  # AttributeError: 'NoneType' object has no attribute 'outer_context'
If anyone could help me that would be great, thank you!
Solution
This question has not been answered yet; be the first to answer it in the comments. The confirmed answer will later be published as the solution.
This question and answer were collected from Stack Overflow and tested by the JTuto community, and are licensed under the terms of CC BY-SA 2.5, CC BY-SA 3.0, and CC BY-SA 4.0.