I am working on a PointNet++ deep learning model. I have completed my preprocessing steps and defined the neural network and the PointNet++ utility functions. However, one function, query_ball_point, keeps raising an error.
The error is as follows:
  File "/home/aniruddha/PycharmProjects/Pointnet_Pointnet2_pytorch/models/pointnet2_utils.py", line 138, in query_ball_point
    group_idx[mask] = group_first[mask]
IndexError: The shape of the mask [8, 512, 3] at index 2 does not match the shape of the indexed tensor [8, 512, 32] at index 2
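The traceback points at group_idx[mask] = group_first[mask]. The minimal sketch below reproduces the same IndexError outside the model. It assumes the cloud reaching query_ball_point has only N = 3 points, matching the mask width in the message, and nsample = 32. After slicing with [:, :, :nsample], the last dimension of group_idx (and therefore of mask) is min(N, nsample). By contrast, group_first is always repeated to nsample, so group_first[mask] cannot be indexed.

```python
import torch

B, S, N, nsample = 8, 512, 3, 32  # N = 3 mirrors the mask width in the error message

# Same index bookkeeping as query_ball_point, minus the distance test
# (the distance test is not needed to trigger the error)
group_idx = torch.arange(N, dtype=torch.long).view(1, 1, N).repeat(B, S, 1)
group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]  # width is min(N, nsample) = 3
group_first = group_idx[:, :, 0].view(B, S, 1).repeat(1, 1, nsample)  # width is always 32
mask = group_idx == N

print(tuple(group_idx.shape), tuple(group_first.shape))  # (8, 512, 3) (8, 512, 32)

try:
    group_idx[mask] = group_first[mask]  # indexing group_first with a width-3 mask fails
    error = None
except IndexError as exc:
    error = exc

print(type(error).__name__)  # IndexError
```

Printing xyz.shape and new_xyz.shape just before the call shows whether the input really has only 3 points. It also shows whether the tensor is laid out differently from the (B, N, 3) layout that the docstring expects.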
This is my query_ball_point function. I have tried every fix I could think of, and I cannot find a correct solution:
import torch


def query_ball_point(radius, nsample, xyz, new_xyz):
    """
    Group points within a specified radius around the new_xyz points.

    Args:
        radius (float): The radius within which to look for neighboring points.
        nsample (int): The maximum number of points to sample in each group.
        xyz (torch.Tensor): All points, shape (B, N, 3).
        new_xyz (torch.Tensor): Query points, shape (B, S, 3).

    Returns:
        torch.Tensor: Grouped point indices, shape (B, S, nsample).
    """
    B, N, _ = xyz.shape
    _, S, _ = new_xyz.shape
    # Initialize group indices: every query row starts as [0, 1, ..., N-1]
    group_idx = torch.arange(N, dtype=torch.long).to(xyz.device).view(1, 1, N).repeat(
        [B, S, 1])
    # Squared distances between new_xyz and xyz, shape (B, S, N);
    # square_distance is defined elsewhere in pointnet2_utils.py (not shown)
    sqrdists = square_distance(new_xyz, xyz)
    # Mark points outside the radius with the sentinel value N
    group_idx[sqrdists > radius ** 2] = N
    # Sort by index value so the sentinels move to the end, then keep the first
    # `nsample` columns; the resulting last dimension is min(N, nsample)
    group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
    # Handle cases where there are fewer than `nsample` valid points in the neighborhood;
    # group_first always has last dimension `nsample`
    group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
    mask = group_idx == N
    group_idx[mask] = group_first[mask]
    return group_idx
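One hedged way to make the function tolerate clouds with fewer than nsample points is to pad group_idx to width nsample before building the mask. The sketch below is self-contained, with a few caveats:

- square_distance is an assumed stand-in built on torch.cdist.
- query_ball_point_safe is a hypothetical name.
- The shape check and padding are illustrative choices, not the PointNet++ reference code.

If the inputs are simply in the wrong layout, fixing the layout upstream is the better remedy.

```python
import torch


def square_distance(src, dst):
    # Assumed stand-in: pairwise squared Euclidean distances,
    # (B, S, 3) x (B, N, 3) -> (B, S, N)
    return torch.cdist(src, dst) ** 2


def query_ball_point_safe(radius, nsample, xyz, new_xyz):
    # Hypothetical variant of query_ball_point; expects xyz (B, N, 3), new_xyz (B, S, 3)
    if xyz.dim() != 3 or xyz.shape[-1] != 3 or new_xyz.shape[-1] != 3:
        raise ValueError(
            f"expected xyz (B, N, 3) and new_xyz (B, S, 3), "
            f"got {tuple(xyz.shape)} and {tuple(new_xyz.shape)}"
        )
    B, N, _ = xyz.shape
    _, S, _ = new_xyz.shape
    group_idx = torch.arange(N, dtype=torch.long, device=xyz.device).view(1, 1, N).repeat(B, S, 1)
    # Points outside the ball get the sentinel value N
    group_idx[square_distance(new_xyz, xyz) > radius ** 2] = N
    # Sentinels sort to the end; the width after slicing is min(N, nsample)
    group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
    width = group_idx.shape[-1]
    if width < nsample:
        # Pad with sentinels so every row has exactly nsample slots
        pad = torch.full((B, S, nsample - width), N, dtype=torch.long, device=xyz.device)
        group_idx = torch.cat([group_idx, pad], dim=-1)
    # Fill sentinel slots with the smallest in-radius index of each row.
    # Assumes every query has at least one neighbour in the ball
    # (true when new_xyz is sampled from xyz).
    group_first = group_idx[:, :, :1].repeat(1, 1, nsample)
    mask = group_idx == N
    group_idx[mask] = group_first[mask]
    return group_idx


# Usage: three points per cloud and nsample = 32, as in the error message
xyz = torch.tensor([[[0.0, 0.0, 0.0], [0.1, 0.0, 0.0], [5.0, 5.0, 5.0]]])  # (1, 3, 3)
new_xyz = xyz[:, :2, :]  # two query points, (1, 2, 3)
idx = query_ball_point_safe(0.5, 32, xyz, new_xyz)
print(tuple(idx.shape))  # (1, 2, 32)
print(idx[0, 0, :4].tolist())  # [0, 1, 0, 0]
```

Each row comes back with 32 entries. Slots with no in-radius point are filled with the smallest in-radius index, which keeps the same convention as the original function.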