The target_centers tensor is normalized before the center offset is calculated:
    target_centers = torch.nn.functional.normalize(target_centers, dim=-1)
    center_offset = (1 - alpha) * (features.detach() - target_centers)
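However, if the features' norm is greater than 1, the centers will grow toward infinity, and if it is smaller than 1, they will shrink toward 0. Because the offset is measured from the unit-norm copy of the center rather than from the center itself, it only vanishes when the feature has norm exactly 1 and points in the center's direction, so for any other feature norm the update never settles. You may therefore want to drop the normalization before calculating the center offset.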
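A minimal pure-Python sketch of the drift described above, under stated assumptions: the center is updated by adding center_offset to it (the two lines quoted above only compute the offset), alpha = 0.5, and one fixed feature vector is fed repeatedly. The helpers normalize, l2_norm and simulate are hypothetical and only mirror the quoted code; normalize divides by max(norm, eps) in the same way torch.nn.functional.normalize does.

```python
import math

def normalize(v, eps=1e-12):
    # Scale v to unit L2 norm, dividing by max(norm, eps) to avoid division by zero
    n = max(math.sqrt(sum(x * x for x in v)), eps)
    return [x / n for x in v]

def l2_norm(v):
    return math.sqrt(sum(x * x for x in v))

def simulate(feature, center, alpha=0.5, steps=100, normalize_first=True):
    """Apply the center update `steps` times against one fixed feature vector.

    ASSUMPTION (hypothetical, not from the original code): the center is
    updated in place by adding center_offset, and the feature never changes.
    """
    for _ in range(steps):
        # target_centers: either the unit-norm copy or the raw center
        target = normalize(center) if normalize_first else center
        # center_offset = (1 - alpha) * (features - target_centers)
        offset = [(1 - alpha) * (f - t) for f, t in zip(feature, target)]
        center = [c + o for c, o in zip(center, offset)]
    return center

big, small = [3.0, 0.0], [0.2, 0.0]

# With normalization: the center norm keeps growing when ||feature|| > 1 ...
print(l2_norm(simulate(big, [0.5, 0.5])))
# ... and collapses into a small neighborhood of 0 when ||feature|| < 1.
print(l2_norm(simulate(small, [5.0, 0.0])))
# Without normalization: the center converges to the feature itself.
print(l2_norm(simulate(big, [0.5, 0.5], normalize_first=False)))
print(l2_norm(simulate(small, [5.0, 0.0], normalize_first=False)))
```

Without the normalization the update becomes c ← c + (1 − α)(f − c), an exponential moving average whose fixed point is the feature itself, which is why the last two runs settle at norms 3.0 and 0.2 instead of drifting.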