Is your feature request related to a problem? Please describe.
Currently, ptf-v2 only supports metrics as loss functions. As suggested by @fkiraly, we should give users more flexibility by also allowing them to use `nn` losses (PyTorch `torch.nn` loss modules).
However, this poses some problems:
- `nn` losses cannot handle a `list`, but for multi-target forecasting `ptf-v2` currently uses a `list` of tensors.
- Some `nn` losses need the `target` (ground truth) and `y_pred` to have the same shape, while in `ptf-v2` they have different shapes (`target` is 2D, `y_pred` is 3D). Other `nn` losses require them to have different shapes. The right shape handling therefore depends on which loss is being used.
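To make the shape problem concrete, here is a small sketch using two standard PyTorch losses. It assumes the `ptf-v2` layout is `y_pred` of shape `(batch, time, n_outputs)` and `target` of shape `(batch, time)`; the exact layout inside `ptf-v2` may differ.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, horizon = 4, 6

# Assumed ptf-v2 layout: y_pred is 3D (batch, time, n_outputs), target is 2D (batch, time).
y_pred = torch.randn(batch, horizon, 1)
target = torch.randn(batch, horizon)

# MSELoss wants input and target of the same shape. Passing (4, 6, 1) and (4, 6)
# directly makes PyTorch warn about the size mismatch and attempt broadcasting,
# which fails for these sizes. Dropping the trailing singleton dimension gives
# the intended elementwise loss.
mse_value = nn.MSELoss()(y_pred.squeeze(-1), target)

# CrossEntropyLoss instead wants *different* shapes: logits (N, C) and
# integer class labels (N,). Here batch and time are flattened into N.
n_classes = 3
logits = torch.randn(batch, horizon, n_classes)
labels = torch.randint(0, n_classes, (batch, horizon))
ce_value = nn.CrossEntropyLoss()(logits.reshape(-1, n_classes), labels.reshape(-1))

# Neither loss accepts a list of per-target tensors: each call takes one tensor pair.
```

The reshaping needed is different for each loss, which is why a single fixed output contract cannot satisfy every `nn` loss without some loss-specific glue.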
Describe the solution you'd like
We could:
- Create an adapter class for `nn` losses that handles the shapes internally, so that the output contract of the model layer stays consistent. The shape handling specific to each loss would live inside these adapters.
- Create a wrapper for `MultiLoss` inside the `BaseModel` for `nn` losses (or for all losses?) which, in the multi-target case, passes a list of the same `nn` loss (or any other loss) when only one loss is given. This list would have the same `len` as `num_targets`. The wrapper would then pass only tensors, never a `list`, to the `nn` losses.
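A minimal sketch of both ideas, under stated assumptions. `NNLossAdapter` and `MultiTargetNNLoss` are hypothetical names, not existing `ptf-v2` classes, and the sketch does not use the real `MultiLoss` API. It assumes `y_pred` is `(batch, time, n_outputs)` and `target` is `(batch, time)`, supports only two illustrative shape modes, and simply sums the per-target losses (a weighted sum would be an equally reasonable choice).

```python
import torch
import torch.nn as nn


class NNLossAdapter(nn.Module):
    """Hypothetical adapter that reconciles assumed ptf-v2 shapes for a torch.nn loss.

    Assumes y_pred is (batch, time, n_outputs) and target is (batch, time).
    """

    def __init__(self, loss: nn.Module, mode: str = "pointwise"):
        super().__init__()
        if mode not in ("pointwise", "classification"):
            raise ValueError(f"unknown mode: {mode}")
        self.loss = loss
        self.mode = mode

    def forward(self, y_pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        if self.mode == "pointwise":
            # Losses like MSELoss / L1Loss need identical shapes:
            # drop the singleton output dimension.
            return self.loss(y_pred.squeeze(-1), target)
        # Losses like CrossEntropyLoss need (N, C) logits and (N,) integer labels.
        n_classes = y_pred.size(-1)
        return self.loss(y_pred.reshape(-1, n_classes), target.reshape(-1).long())


class MultiTargetNNLoss(nn.Module):
    """Hypothetical wrapper: one loss per target, applied to lists of tensors."""

    def __init__(self, losses, num_targets: int):
        super().__init__()
        if isinstance(losses, nn.Module):
            # A single loss was passed: reuse it for every target.
            losses = [losses] * num_targets
        if len(losses) != num_targets:
            raise ValueError("need exactly one loss per target")
        self.losses = nn.ModuleList(losses)

    def forward(self, y_pred, target) -> torch.Tensor:
        # Each wrapped loss only ever sees a single tensor pair, never a list.
        per_target = [loss(p, t) for loss, p, t in zip(self.losses, y_pred, target)]
        return torch.stack(per_target).sum()


# Usage: two targets, a single MSE loss reused for both.
torch.manual_seed(0)
batch, horizon = 4, 6
multi = MultiTargetNNLoss(NNLossAdapter(nn.MSELoss()), num_targets=2)
y_pred = [torch.randn(batch, horizon, 1) for _ in range(2)]
target = [torch.randn(batch, horizon) for _ in range(2)]
total = multi(y_pred, target)

# Classification-style loss through the same adapter.
ce = NNLossAdapter(nn.CrossEntropyLoss(), mode="classification")
ce_value = ce(torch.randn(batch, horizon, 3), torch.randint(0, 3, (batch, horizon)))
```

Keeping the shape logic in the adapter and the list handling in the wrapper means the model layer can keep a single output contract, and each `nn` loss only receives plain tensors.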