|
| 1 | +/* Copyright 2026 The TensorFlow Authors. All Rights Reserved. |
| 2 | +
|
| 3 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +you may not use this file except in compliance with the License. |
| 5 | +You may obtain a copy of the License at |
| 6 | +
|
| 7 | + http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +
|
| 9 | +Unless required by applicable law or agreed to in writing, software |
| 10 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +See the License for the specific language governing permissions and |
| 13 | +limitations under the License. |
| 14 | +======================================================================= |
| 15 | +*/ |
| 16 | +package org.tensorflow.op; |
| 17 | + |
| 18 | +import java.lang.reflect.Constructor; |
| 19 | +import java.util.List; |
| 20 | +import java.util.concurrent.ConcurrentHashMap; |
| 21 | +import java.util.concurrent.ConcurrentMap; |
| 22 | +import org.tensorflow.AbstractGradientAdapter; |
| 23 | +import org.tensorflow.Graph; |
| 24 | +import org.tensorflow.GraphOperation; |
| 25 | +import org.tensorflow.Operand; |
| 26 | +import org.tensorflow.Output; |
| 27 | +import org.tensorflow.internal.c_api.TFJ_Scope; |
| 28 | + |
| 29 | +/** |
| 30 | + * Dispatching adapter for Java-side custom gradient registration. |
| 31 | + * |
| 32 | + * <p>This class mirrors the behavior of TensorFlow Python's {@code tf.RegisterGradient} mechanism |
| 33 | + * by providing a centralized dispatch layer for custom gradients in the Java API. |
| 34 | + * |
| 35 | + * <p>Gradients may be registered in one of two forms for a given op type: |
| 36 | + * |
| 37 | + * <ul> |
| 38 | + * <li>A raw gradient ({@link RawCustomGradient}) operating directly on {@link GraphOperation} and |
| 39 | + * {@link Output} objects. |
| 40 | + * <li>A typed gradient ({@link CustomGradient}) operating on generated {@link RawOpInputs} |
| 41 | + * subclasses. |
| 42 | + * </ul> |
| 43 | + * |
| 44 | + * <p>For any given op type, exactly one gradient definition is permitted: either raw or typed. |
| 45 | + * Duplicate registrations, or attempts to mix raw and typed gradients for the same op type, are |
| 46 | + * rejected to prevent ambiguous dispatch behavior. |
| 47 | + * |
| 48 | + * <p>At runtime, {@link #apply(Graph, TFJ_Scope, GraphOperation, List)} determines the operation |
| 49 | + * type and dispatches to the corresponding registered gradient implementation. |
| 50 | + */ |
| 51 | +final class DispatchingGradientAdapter extends AbstractGradientAdapter { |
| 52 | + |
| 53 | + private final ConcurrentMap<String, RawCustomGradient> raw = new ConcurrentHashMap<>(); |
| 54 | + private final ConcurrentMap<String, TypedEntry<?>> typed = new ConcurrentHashMap<>(); |
| 55 | + |
| 56 | + private static String dupMsg(String opType, String existingKind, String newKind) { |
| 57 | + return "A " |
| 58 | + + existingKind |
| 59 | + + " gradient is already registered for op type '" |
| 60 | + + opType |
| 61 | + + "'. Raw and typed registrations are mutually exclusive; cannot register " |
| 62 | + + newKind |
| 63 | + + "."; |
| 64 | + } |
| 65 | + |
| 66 | + static final class TypedEntry<T extends RawOpInputs<?>> { |
| 67 | + final CustomGradient<T> grad; |
| 68 | + final Class<T> inputClass; |
| 69 | + final Constructor<T> ctor; |
| 70 | + |
| 71 | + TypedEntry(CustomGradient<T> grad, Class<T> inputClass) { |
| 72 | + this.grad = grad; |
| 73 | + this.inputClass = inputClass; |
| 74 | + try { |
| 75 | + this.ctor = inputClass.getConstructor(org.tensorflow.GraphOperation.class); |
| 76 | + } catch (NoSuchMethodException e) { |
| 77 | + throw new IllegalArgumentException( |
| 78 | + "Inputs class " + inputClass.getName() + " must have a public ctor(GraphOperation).", |
| 79 | + e); |
| 80 | + } |
| 81 | + } |
| 82 | + } |
| 83 | + |
| 84 | + void putRaw(String opType, RawCustomGradient g) { |
| 85 | + if (typed.containsKey(opType)) { |
| 86 | + throw new IllegalStateException(dupMsg(opType, "typed", "raw")); |
| 87 | + } |
| 88 | + RawCustomGradient prev = raw.putIfAbsent(opType, g); |
| 89 | + if (prev != null) { |
| 90 | + throw new IllegalStateException( |
| 91 | + "A raw gradient is already registered for op type '" + opType + "'."); |
| 92 | + } |
| 93 | + } |
| 94 | + |
| 95 | + <T extends RawOpInputs<?>> void putTyped( |
| 96 | + String opType, CustomGradient<T> g, Class<T> inputClass) { |
| 97 | + if (raw.containsKey(opType)) { |
| 98 | + throw new IllegalStateException(dupMsg(opType, "raw", "typed")); |
| 99 | + } |
| 100 | + TypedEntry<?> prev = typed.putIfAbsent(opType, new TypedEntry<>(g, inputClass)); |
| 101 | + if (prev != null) { |
| 102 | + throw new IllegalStateException( |
| 103 | + "A typed gradient is already registered for op type '" + opType + "'."); |
| 104 | + } |
| 105 | + } |
| 106 | + |
| 107 | + @Override |
| 108 | + protected List<Operand<?>> apply( |
| 109 | + Graph graph, TFJ_Scope scope, GraphOperation operation, List<Output<?>> gradInputs) { |
| 110 | + |
| 111 | + final String opType = operation.type(); |
| 112 | + |
| 113 | + RawCustomGradient rg = raw.get(opType); |
| 114 | + if (rg != null) { |
| 115 | + // NativeScope & Ops constructors are package-private => must be in org.tensorflow.op |
| 116 | + Scope nativeScope = |
| 117 | + new NativeScope(scope, graph, operation.name()).withSubScope(operation.name()); |
| 118 | + return rg.call(new Ops(nativeScope), operation, gradInputs); |
| 119 | + } |
| 120 | + |
| 121 | + @SuppressWarnings("rawtypes") |
| 122 | + TypedEntry te = typed.get(opType); |
| 123 | + if (te != null) { |
| 124 | + return applyTyped(graph, scope, operation, gradInputs, te); |
| 125 | + } |
| 126 | + |
| 127 | + throw new IllegalStateException("No Java custom gradient registered for op type: " + opType); |
| 128 | + } |
| 129 | + |
| 130 | + private <T extends RawOpInputs<?>> List<Operand<?>> applyTyped( |
| 131 | + Graph graph, |
| 132 | + TFJ_Scope scope, |
| 133 | + GraphOperation operation, |
| 134 | + List<Output<?>> gradInputs, |
| 135 | + TypedEntry<T> te) { |
| 136 | + try { |
| 137 | + T inputs = te.ctor.newInstance(operation); |
| 138 | + Scope nativeScope = |
| 139 | + new NativeScope(scope, graph, operation.name()).withSubScope(operation.name()); |
| 140 | + return te.grad.call(new Ops(nativeScope), inputs, gradInputs); |
| 141 | + } catch (ReflectiveOperationException e) { |
| 142 | + throw new RuntimeException("Failed to instantiate inputs for " + te.inputClass.getName(), e); |
| 143 | + } |
| 144 | + } |
| 145 | +} |
0 commit comments