In this case, the output is always zeros regardless of the input: after the max op, dim 1 has only one element, so `x - x.mean(dim=1)` is always 0, and `gelu(0) == 0`.
```python
def forward(self, x):
    """
    Args:
        x: Input tensor of shape (batch_size, in_features)

    Returns:
        Output tensor of shape (batch_size, out_features)
    """
    x = self.gemm(x)
    x = torch.max(x, dim=self.max_dim, keepdim=True).values
    x = x - x.mean(dim=1, keepdim=True)
    x = torch.nn.functional.gelu(x)
    return x
```
KernelBench/KernelBench/level2/80_Gemm_Max_Subtract_GELU.py, lines 13 to 25 at 768d52c
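The always-zero behavior is easy to check numerically. A minimal sketch that reconstructs the GEMM → max → mean-subtract → GELU pipeline; the hyperparameters (`in_features=4`, `out_features=8`, `max_dim=1`) are illustrative assumptions, not the benchmark's actual values:

```python
import torch
import torch.nn as nn

class Model(nn.Module):
    # Illustrative reconstruction of the quoted forward pass;
    # sizes are arbitrary, max_dim=1 matches the snippet's reduction.
    def __init__(self, in_features=4, out_features=8, max_dim=1):
        super().__init__()
        self.gemm = nn.Linear(in_features, out_features)
        self.max_dim = max_dim

    def forward(self, x):
        x = self.gemm(x)
        # After reducing dim 1 with keepdim, shape is (batch_size, 1)
        x = torch.max(x, dim=self.max_dim, keepdim=True).values
        # Mean over a single element equals that element -> difference is 0
        x = x - x.mean(dim=1, keepdim=True)
        # gelu(0) == 0, so the output is identically zero
        x = torch.nn.functional.gelu(x)
        return x

out = Model()(torch.randn(16, 4))
print(torch.all(out == 0).item())  # True for any input
```

Any correct "kernel" for this task can therefore just return zeros, which is why the reduction placement makes the benchmark degenerate.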