[ENH] Add support for nn losses to ptf-v2 #1970

@phoeenniixx

Is your feature request related to a problem? Please describe.
Currently, ptf-v2 only supports metrics as loss functions. As suggested by @fkiraly, we should give users more flexibility by also allowing nn losses (PyTorch `torch.nn` loss modules).
However, this poses a few problems:

  • nn losses cannot handle lists, but for multi-target forecasting ptf-v2 currently passes a list of tensors.
  • Some nn losses require the target (ground truth) and y_pred to have the same shape, whereas in ptf-v2 they have different shapes (target is 2D, y_pred is 3D). Other nn losses, in contrast, require the two to have different shapes. The required shapes therefore depend on which loss is used.
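To make both problems concrete, here is a minimal PyTorch sketch. It assumes (for illustration only, not a confirmed ptf-v2 layout) that target has shape (batch, horizon) and y_pred has shape (batch, horizon, n_outputs). `nn.MSELoss` needs the trailing output dimension squeezed so both shapes match, while `nn.CrossEntropyLoss` needs the class dimension moved to position 1, so its two shapes stay different. A list of per-target tensors has to be unpacked before any nn loss can be called.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Assumed ptf-v2-style layout (illustrative only):
#   target: (batch, horizon), y_pred: (batch, horizon, n_outputs)
batch, horizon = 4, 6

# Regression: MSELoss needs identical shapes, so drop the size-1 output dim.
y_pred_reg = torch.randn(batch, horizon, 1)
target_reg = torch.randn(batch, horizon)
mse = nn.MSELoss()(y_pred_reg.squeeze(-1), target_reg)

# Classification: CrossEntropyLoss needs different shapes, with the class
# dim second: input (batch, n_classes, horizon), target (batch, horizon).
n_classes = 3
y_pred_cls = torch.randn(batch, horizon, n_classes)
target_cls = torch.randint(0, n_classes, (batch, horizon))
ce = nn.CrossEntropyLoss()(y_pred_cls.permute(0, 2, 1), target_cls)

# Multi-target data held as lists of tensors must be unpacked first:
# each nn loss call receives exactly one (y_pred, target) tensor pair.
preds = [y_pred_reg, torch.randn(batch, horizon, 1)]
targets = [target_reg, torch.randn(batch, horizon)]
per_target = [nn.MSELoss()(p.squeeze(-1), t) for p, t in zip(preds, targets)]
```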

Describe the solution you'd like
We could:

  • Create an adapter class for nn losses that handles shapes internally, so that the output contract of the model layer stays consistent. Any shape handling specific to a given loss happens inside its adapter.
  • Create a wrapper around MultiLoss inside the BaseModel for nn losses (or for all losses?). In the multi-target case, if only a single loss is passed, the wrapper would build a list containing that same nn loss (or any other loss) once per target, so the list length equals num_targets. The wrapper would then pass only individual tensors, never lists, to the nn losses.
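A minimal sketch of the adapter idea from the first bullet, assuming the same (batch, horizon, n_outputs) / (batch, horizon) layout. `NNLossAdapter` and `ClassIndexLossAdapter` are hypothetical names, not existing ptf-v2 classes. Each subclass overrides a single `reshape` hook, so the loss-specific shape handling lives in the adapter while the model output stays unchanged.

```python
import torch
import torch.nn as nn


class NNLossAdapter(nn.Module):
    """Hypothetical adapter for elementwise torch.nn losses (MSE, L1, Huber, ...).

    Assumes y_pred is (batch, horizon, 1) and target is (batch, horizon).
    """

    def __init__(self, loss: nn.Module):
        super().__init__()
        self.loss = loss

    def reshape(self, y_pred: torch.Tensor, target: torch.Tensor):
        # Elementwise losses need matching shapes: drop the size-1 output dim.
        if y_pred.size(-1) != 1:
            raise ValueError(f"expected one output, got y_pred shape {tuple(y_pred.shape)}")
        return y_pred.squeeze(-1), target

    def forward(self, y_pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        y_pred, target = self.reshape(y_pred, target)
        return self.loss(y_pred, target)


class ClassIndexLossAdapter(NNLossAdapter):
    """Hypothetical adapter for losses such as nn.CrossEntropyLoss or nn.NLLLoss.

    Assumes y_pred is (batch, horizon, n_classes) and target holds class indices.
    """

    def reshape(self, y_pred: torch.Tensor, target: torch.Tensor):
        # These losses want the class dim second: (batch, n_classes, horizon).
        return y_pred.permute(0, 2, 1), target.long()


# Usage: the model output layout is the same; only the adapter differs.
y_pred = torch.randn(4, 6, 1)
target = torch.randn(4, 6)
mse_value = NNLossAdapter(nn.MSELoss())(y_pred, target)

logits = torch.randn(4, 6, 3)
labels = torch.randint(0, 3, (4, 6))
ce_value = ClassIndexLossAdapter(nn.CrossEntropyLoss())(logits, labels)
```

Subclassing per loss family keeps the dispatch explicit; a lookup table keyed on the loss type would be an equally valid alternative.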
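The multi-target wrapper from the second bullet could look roughly like this sketch. `MultiTargetLoss` is a hypothetical name (it is not ptf-v2's `MultiLoss`), and summing the per-target losses is an assumption; a weighted sum or a mean would slot in the same way. If a single loss is passed, it is copied once per target so the list length equals `n_targets`, and each wrapped loss only ever receives one tensor pair.

```python
import copy
from typing import List, Sequence, Union

import torch
import torch.nn as nn


class MultiTargetLoss(nn.Module):
    """Hypothetical multi-target wrapper: one loss per target, tensors only."""

    def __init__(self, loss: Union[nn.Module, Sequence[nn.Module]], n_targets: int):
        super().__init__()
        if isinstance(loss, (list, tuple)):
            losses = list(loss)
        else:
            # Single loss for several targets: replicate it, one copy per target.
            losses = [copy.deepcopy(loss) for _ in range(n_targets)]
        if len(losses) != n_targets:
            raise ValueError(f"got {len(losses)} losses for {n_targets} targets")
        self.losses = nn.ModuleList(losses)

    def forward(self, y_pred: List[torch.Tensor], target: List[torch.Tensor]) -> torch.Tensor:
        if not (len(y_pred) == len(target) == len(self.losses)):
            raise ValueError("y_pred, target and losses must have the same length")
        # Unpack the lists so each nn loss sees a single (y_pred, target) pair.
        per_target = [fn(p, t) for fn, p, t in zip(self.losses, y_pred, target)]
        # Assumption: combine per-target losses by a plain sum.
        return torch.stack(per_target).sum()


# Usage: the same MSELoss applied to two targets.
wrapper = MultiTargetLoss(nn.MSELoss(), n_targets=2)
preds = [torch.zeros(4, 6), torch.ones(4, 6)]
targets = [torch.zeros(4, 6), torch.zeros(4, 6)]
total = wrapper(preds, targets)  # 0.0 + 1.0
```

Each entry in the list could itself be an adapter of the kind described in the first bullet, so that lists and shapes are both handled before the raw nn loss is called.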
