Skip to content

Commit 2118497

Browse files
authored
Merge pull request #632 from nfeybesse/custom/gradients-dispatch
Custom/gradients dispatch
2 parents 5ec6675 + da6c216 commit 2118497

12 files changed

Lines changed: 533 additions & 145 deletions

File tree

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractGradientAdapter.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,17 @@ private static TF_Output toNativeOutputs(List<Operand<?>> outputs) {
9292
new TF_Output(Pointer.malloc((long) outputs.size() * Pointer.sizeof(TF_Output.class)));
9393

9494
for (int i = 0; i < outputs.size(); ++i) {
95-
var output = outputs.get(i).asOutput();
95+
Operand<?> operand = outputs.get(i);
9696
var nativeOutput = nativeOutputs.getPointer(i);
97+
98+
// Convention: null Operand => NoGradient
99+
if (operand == null) {
100+
nativeOutput.oper((TF_Operation) null);
101+
nativeOutput.index(0);
102+
continue;
103+
}
104+
105+
var output = operand.asOutput();
97106
nativeOutput.oper(((GraphOperation) output.op()).getUnsafeNativeHandle());
98107
nativeOutput.index(output.index());
99108
}

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
import org.tensorflow.internal.c_api.TF_Library;
4040
import org.tensorflow.internal.c_api.TF_Status;
4141
import org.tensorflow.op.CustomGradient;
42+
import org.tensorflow.op.GradientDispatch;
4243
import org.tensorflow.op.RawCustomGradient;
4344
import org.tensorflow.op.RawOpInputs;
4445
import org.tensorflow.op.annotation.OpInputsMetadata;
@@ -207,7 +208,10 @@ public static synchronized boolean registerCustomGradient(
207208
if (hasGradient(opType)) {
208209
return false;
209210
}
210-
TFJ_GradFuncAdapter g = RawCustomGradient.adapter(gradient);
211+
212+
GradientDispatch.putRaw(opType, gradient);
213+
TFJ_GradFuncAdapter g = GradientDispatch.adapter();
214+
211215
if (!TFJ_RegisterCustomGradient(opType, g)) {
212216
return false;
213217
}
@@ -255,7 +259,10 @@ public static synchronized <T extends RawOpInputs<?>> boolean registerCustomGrad
255259
if (hasGradient(opType)) {
256260
return false;
257261
}
258-
TFJ_GradFuncAdapter g = CustomGradient.adapter(gradient, inputClass);
262+
263+
GradientDispatch.putTyped(opType, gradient, inputClass);
264+
TFJ_GradFuncAdapter g = GradientDispatch.adapter();
265+
259266
if (!TFJ_RegisterCustomGradient(opType, g)) {
260267
return false;
261268
}

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/CustomGradient.java

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
import org.tensorflow.Operand;
2121
import org.tensorflow.Output;
2222
import org.tensorflow.TensorFlow;
23-
import org.tensorflow.internal.c_api.TFJ_GradFuncAdapter;
2423

2524
/**
2625
* A custom gradient for ops of type {@link T}. Should be registered using {@link
@@ -48,15 +47,4 @@ public interface CustomGradient<T extends RawOpInputs> {
4847
* @return the gradients of the op's inputs.
4948
*/
5049
List<Operand<?>> call(Ops tf, T op, List<Output<?>> gradInputs);
51-
52-
/**
53-
* Create an adapter for the custom gradient so that it can be used by native code.
54-
*
55-
* <p>You should not be calling this yourself, use {@link TensorFlow#registerCustomGradient(Class,
56-
* CustomGradient)}.
57-
*/
58-
static <T extends RawOpInputs<?>> TFJ_GradFuncAdapter adapter(
59-
CustomGradient<T> gradient, Class<T> opClass) {
60-
return new TypedGradientAdapter<T>(gradient, opClass);
61-
}
6250
}
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
/* Copyright 2026 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
=======================================================================
15+
*/
16+
package org.tensorflow.op;
17+
18+
import java.lang.reflect.Constructor;
19+
import java.util.List;
20+
import java.util.concurrent.ConcurrentHashMap;
21+
import java.util.concurrent.ConcurrentMap;
22+
import org.tensorflow.AbstractGradientAdapter;
23+
import org.tensorflow.Graph;
24+
import org.tensorflow.GraphOperation;
25+
import org.tensorflow.Operand;
26+
import org.tensorflow.Output;
27+
import org.tensorflow.internal.c_api.TFJ_Scope;
28+
29+
/**
30+
* Dispatching adapter for Java-side custom gradient registration.
31+
*
32+
* <p>This class mirrors the behavior of TensorFlow Python's {@code tf.RegisterGradient} mechanism
33+
* by providing a centralized dispatch layer for custom gradients in the Java API.
34+
*
35+
* <p>Gradients may be registered in one of two forms for a given op type:
36+
*
37+
* <ul>
38+
* <li>A raw gradient ({@link RawCustomGradient}) operating directly on {@link GraphOperation} and
39+
* {@link Output} objects.
40+
* <li>A typed gradient ({@link CustomGradient}) operating on generated {@link RawOpInputs}
41+
* subclasses.
42+
* </ul>
43+
*
44+
* <p>For any given op type, exactly one gradient definition is permitted: either raw or typed.
45+
* Duplicate registrations, or attempts to mix raw and typed gradients for the same op type, are
46+
* rejected to prevent ambiguous dispatch behavior.
47+
*
48+
* <p>At runtime, {@link #apply(Graph, TFJ_Scope, GraphOperation, List)} determines the operation
49+
* type and dispatches to the corresponding registered gradient implementation.
50+
*/
51+
final class DispatchingGradientAdapter extends AbstractGradientAdapter {
52+
53+
private final ConcurrentMap<String, RawCustomGradient> raw = new ConcurrentHashMap<>();
54+
private final ConcurrentMap<String, TypedEntry<?>> typed = new ConcurrentHashMap<>();
55+
56+
private static String dupMsg(String opType, String existingKind, String newKind) {
57+
return "A "
58+
+ existingKind
59+
+ " gradient is already registered for op type '"
60+
+ opType
61+
+ "'. Raw and typed registrations are mutually exclusive; cannot register "
62+
+ newKind
63+
+ ".";
64+
}
65+
66+
static final class TypedEntry<T extends RawOpInputs<?>> {
67+
final CustomGradient<T> grad;
68+
final Class<T> inputClass;
69+
final Constructor<T> ctor;
70+
71+
TypedEntry(CustomGradient<T> grad, Class<T> inputClass) {
72+
this.grad = grad;
73+
this.inputClass = inputClass;
74+
try {
75+
this.ctor = inputClass.getConstructor(org.tensorflow.GraphOperation.class);
76+
} catch (NoSuchMethodException e) {
77+
throw new IllegalArgumentException(
78+
"Inputs class " + inputClass.getName() + " must have a public ctor(GraphOperation).",
79+
e);
80+
}
81+
}
82+
}
83+
84+
void putRaw(String opType, RawCustomGradient g) {
85+
if (typed.containsKey(opType)) {
86+
throw new IllegalStateException(dupMsg(opType, "typed", "raw"));
87+
}
88+
RawCustomGradient prev = raw.putIfAbsent(opType, g);
89+
if (prev != null) {
90+
throw new IllegalStateException(
91+
"A raw gradient is already registered for op type '" + opType + "'.");
92+
}
93+
}
94+
95+
<T extends RawOpInputs<?>> void putTyped(
96+
String opType, CustomGradient<T> g, Class<T> inputClass) {
97+
if (raw.containsKey(opType)) {
98+
throw new IllegalStateException(dupMsg(opType, "raw", "typed"));
99+
}
100+
TypedEntry<?> prev = typed.putIfAbsent(opType, new TypedEntry<>(g, inputClass));
101+
if (prev != null) {
102+
throw new IllegalStateException(
103+
"A typed gradient is already registered for op type '" + opType + "'.");
104+
}
105+
}
106+
107+
@Override
108+
protected List<Operand<?>> apply(
109+
Graph graph, TFJ_Scope scope, GraphOperation operation, List<Output<?>> gradInputs) {
110+
111+
final String opType = operation.type();
112+
113+
RawCustomGradient rg = raw.get(opType);
114+
if (rg != null) {
115+
// NativeScope & Ops constructors are package-private => must be in org.tensorflow.op
116+
Scope nativeScope =
117+
new NativeScope(scope, graph, operation.name()).withSubScope(operation.name());
118+
return rg.call(new Ops(nativeScope), operation, gradInputs);
119+
}
120+
121+
@SuppressWarnings("rawtypes")
122+
TypedEntry te = typed.get(opType);
123+
if (te != null) {
124+
return applyTyped(graph, scope, operation, gradInputs, te);
125+
}
126+
127+
throw new IllegalStateException("No Java custom gradient registered for op type: " + opType);
128+
}
129+
130+
private <T extends RawOpInputs<?>> List<Operand<?>> applyTyped(
131+
Graph graph,
132+
TFJ_Scope scope,
133+
GraphOperation operation,
134+
List<Output<?>> gradInputs,
135+
TypedEntry<T> te) {
136+
try {
137+
T inputs = te.ctor.newInstance(operation);
138+
Scope nativeScope =
139+
new NativeScope(scope, graph, operation.name()).withSubScope(operation.name());
140+
return te.grad.call(new Ops(nativeScope), inputs, gradInputs);
141+
} catch (ReflectiveOperationException e) {
142+
throw new RuntimeException("Failed to instantiate inputs for " + te.inputClass.getName(), e);
143+
}
144+
}
145+
}
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/* Copyright 2026 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
=======================================================================
15+
*/
16+
package org.tensorflow.op;
17+
18+
import org.tensorflow.internal.c_api.TFJ_GradFuncAdapter;
19+
20+
/** Public bridge to a single native gradient adapter. */
21+
public final class GradientDispatch {
22+
23+
// package-private adapter that can access NativeScope/Ops constructors
24+
static final DispatchingGradientAdapter ADAPTER = new DispatchingGradientAdapter();
25+
26+
private GradientDispatch() {}
27+
28+
public static TFJ_GradFuncAdapter adapter() {
29+
return ADAPTER;
30+
}
31+
32+
public static void putRaw(String opType, RawCustomGradient gradient) {
33+
ADAPTER.putRaw(opType, gradient);
34+
}
35+
36+
public static <T extends RawOpInputs<?>> void putTyped(
37+
String opType, CustomGradient<T> gradient, Class<T> inputClass) {
38+
ADAPTER.putTyped(opType, gradient, inputClass);
39+
}
40+
}

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/RawCustomGradient.java

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import org.tensorflow.Operand;
2222
import org.tensorflow.Output;
2323
import org.tensorflow.TensorFlow;
24-
import org.tensorflow.internal.c_api.TFJ_GradFuncAdapter;
2524

2625
/**
2726
* A custom gradient for an op of unspecified type. Should be registered using {@link
@@ -46,14 +45,4 @@ public interface RawCustomGradient {
4645
* @return the gradients of the op's inputs.
4746
*/
4847
List<Operand<?>> call(Ops tf, GraphOperation op, List<Output<?>> gradInputs);
49-
50-
/**
51-
* Create an adapter for the custom gradient so that it can be used by native code.
52-
*
53-
* <p>You should not be calling this yourself, use {@link
54-
* TensorFlow#registerCustomGradient(String, RawCustomGradient)}.
55-
*/
56-
static TFJ_GradFuncAdapter adapter(RawCustomGradient gradient) {
57-
return new RawGradientAdapter(gradient);
58-
}
5948
}

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/RawGradientAdapter.java

Lines changed: 0 additions & 44 deletions
This file was deleted.

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/TypedGradientAdapter.java

Lines changed: 0 additions & 58 deletions
This file was deleted.

0 commit comments

Comments
 (0)