ECCV 2020 Submission
class QuickCumsum(torch.autograd.Function):
    """Sum-pool rows of ``x`` into bins given by ``ranks`` via the cumsum trick.

    Assumes the rows are pre-sorted so that equal bin ids in ``ranks`` are
    contiguous; each bin's pooled sum is then the difference of the running
    cumulative sum at consecutive bin boundaries, avoiding a Python loop.
    """

    @staticmethod
    def forward(ctx, x, ranks):
        """Perform sum pooling where each bin has a variable
        number of points to be pooled.

        x: N x D tensor of features to be pooled
        ranks: N tensor of bin ids (sorted so equal ids are contiguous)

        Returns an M x D tensor with one row of pooled sums per bin,
        where M is the number of distinct bins.
        """
        x = x.cumsum(0)
        # A row is the last of its bin when the next row's rank differs;
        # the final row is always a bin boundary.
        kept = torch.ones(x.shape[0], device=x.device, dtype=torch.bool)
        kept[:-1] = ranks[1:] != ranks[:-1]
        x = x[kept]
        # Undo the running sum between consecutive boundaries -> per-bin sums.
        x = torch.cat((x[:1], x[1:] - x[:-1]))
        # Save the boundary mask so backward can route gradients to inputs.
        ctx.save_for_backward(kept)
        # BUG FIX: original code returned `x, geom_feats`, but `geom_feats`
        # is undefined in this scope and raised NameError on every call.
        return x

    @staticmethod
    def backward(ctx, gradx):
        """Broadcast each bin's output gradient to every input row of the bin."""
        kept, = ctx.saved_tensors
        # Map each input row to the index of its bin in the pooled output:
        # inclusive scan of the boundary mask, shifted to be zero-based.
        back = torch.cumsum(kept, 0)
        back[kept] -= 1
        val = gradx[back]
        # No gradient for `ranks` (integer bin ids are not differentiable).
        return val, None