I implemented the following function:
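import torch

def t_asy(self, data, beta: float):
    # Exponent grows with the position along the last dimension and with sqrt(data)
    power = 1 + (beta * torch.linspace(0, 1, data.shape[-1], device=data.device)) * data.sqrt()
    # Transformed values, intended to be used only for the positive entries
    positive_values_transformed = data**power
    # Keep entries that are not positive (negative or zero) unchanged
    return torch.where(data > 0, positive_values_transformed, data)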
I tried to compute the gradient with respect to data using torch.func.grad(self.t_asy)(data). However, wherever data is negative, the gradient comes out as nan.
Why is that?
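For reference, here is a self-contained sketch that should reproduce the behaviour. It makes a few assumptions that are not in the original. The method is rewritten as a free function. beta = 0.5 and the input tensor are arbitrary example values. The output is summed before differentiating, because torch.func.grad only accepts functions that return a scalar. It shows that the forward output is finite, since torch.where picks data for the negative entry. The gradient entry for that same input is still nan. Note that data.sqrt() is already nan for negative entries, so power (and data**power) is nan there too.

```python
import torch
from torch.func import grad


def t_asy(data: torch.Tensor, beta: float) -> torch.Tensor:
    # Same computation as the method above, written as a free function (assumption).
    power = 1 + (beta * torch.linspace(0, 1, data.shape[-1], device=data.device)) * data.sqrt()
    positive_values_transformed = data**power
    return torch.where(data > 0, positive_values_transformed, data)


data = torch.tensor([-1.0, 0.5, 2.0])  # example input with one negative entry
beta = 0.5  # arbitrary example value

# Forward pass: finite, because torch.where returns `data` for the negative entry.
out = t_asy(data, beta)

# torch.func.grad needs a scalar-valued function, so the output is summed here.
g = grad(lambda x: t_asy(x, beta).sum())(data)

print(out)
print(g)  # the entry for the negative input is nan; the others are finite
```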
Thanks