diff --git a/tensorflow_gnn/runner/orchestration_test.py b/tensorflow_gnn/runner/orchestration_test.py index e0162efc..90df3d09 100644 --- a/tensorflow_gnn/runner/orchestration_test.py +++ b/tensorflow_gnn/runner/orchestration_test.py @@ -281,6 +281,7 @@ class OrchestrationTests(tf.test.TestCase, parameterized.TestCase): def setUp(self): super().setUp() tfgnn.enable_graph_tensor_validation_at_runtime() + tf.keras.mixed_precision.set_global_policy("float32") @parameterized.named_parameters([ dict(