@@ -186,6 +186,9 @@ private void CheckTrainingParameters(ImageClassificationEstimator.Options option
186186
187187 if ( _session . graph . OperationByName ( _labelTensor . name . Split ( ':' ) [ 0 ] ) == null )
188188 throw Host . ExceptParam ( nameof ( options . TensorFlowLabel ) , $ "'{ options . TensorFlowLabel } ' does not exist in the model") ;
189+ if ( options . EarlyStoppingCriteria != null && options . ValidationSet == null && options . TestOnTrainSet == false )
190+ throw Host . ExceptParam ( nameof ( options . EarlyStoppingCriteria ) , $ "Early stopping enabled but unable to find a validation" +
191+ $ " set and/or train set testing disabled. Please disable early stopping or either provide a validation set or enable train set training.") ;
189192 }
190193
191194 private ( Tensor , Tensor ) AddJpegDecoding ( int height , int width , int depth )
@@ -381,6 +384,7 @@ private void TrainAndEvaluateClassificationLayer(string trainBottleneckFilePath,
381384 float crossentropy = 0 ;
382385 for ( int epoch = 0 ; epoch < epochs ; epoch += 1 )
383386 {
387+ batchIndex = 0 ;
384388 metrics . Train . Accuracy = 0 ;
385389 metrics . Train . CrossEntropy = 0 ;
386390 metrics . Train . BatchProcessedCount = 0 ;
@@ -432,6 +436,42 @@ private void TrainAndEvaluateClassificationLayer(string trainBottleneckFilePath,
432436 }
433437 }
434438
439+ //Process last incomplete batch
440+ if ( batchIndex > 0 )
441+ {
442+ featureTensorShape [ 0 ] = batchIndex ;
443+ featureBatchSizeInBytes = sizeof ( float ) * featureLength * batchIndex ;
444+ labelTensorShape [ 0 ] = batchIndex ;
445+ labelBatchSizeInBytes = sizeof ( long ) * batchIndex ;
446+ runner . AddInput ( new Tensor ( featureBatchPtr , featureTensorShape , TF_DataType . TF_FLOAT , featureBatchSizeInBytes ) , 0 )
447+ . AddInput ( new Tensor ( labelBatchPtr , labelTensorShape , TF_DataType . TF_INT64 , labelBatchSizeInBytes ) , 1 )
448+ . Run ( ) ;
449+
450+ metrics . Train . BatchProcessedCount += 1 ;
451+
452+ if ( options . TestOnTrainSet && statisticsCallback != null )
453+ {
454+ var outputTensors = testEvalRunner
455+ . AddInput ( new Tensor ( featureBatchPtr , featureTensorShape , TF_DataType . TF_FLOAT , featureBatchSizeInBytes ) , 0 )
456+ . AddInput ( new Tensor ( labelBatchPtr , labelTensorShape , TF_DataType . TF_INT64 , labelBatchSizeInBytes ) , 1 )
457+ . Run ( ) ;
458+
459+ outputTensors [ 0 ] . ToScalar < float > ( ref accuracy ) ;
460+ outputTensors [ 1 ] . ToScalar < float > ( ref crossentropy ) ;
461+ metrics . Train . Accuracy += accuracy ;
462+ metrics . Train . CrossEntropy += crossentropy ;
463+
464+ outputTensors [ 0 ] . Dispose ( ) ;
465+ outputTensors [ 1 ] . Dispose ( ) ;
466+ }
467+
468+ batchIndex = 0 ;
469+ featureTensorShape [ 0 ] = batchSize ;
470+ featureBatchSizeInBytes = sizeof ( float ) * featureBatch . Length ;
471+ labelTensorShape [ 0 ] = batchSize ;
472+ labelBatchSizeInBytes = sizeof ( long ) * batchSize ;
473+ }
474+
435475 if ( options . TestOnTrainSet && statisticsCallback != null )
436476 {
437477 metrics . Train . Epoch = epoch ;
@@ -443,7 +483,15 @@ private void TrainAndEvaluateClassificationLayer(string trainBottleneckFilePath,
443483 }
444484
445485 if ( validationSet == null )
486+ {
487+ //Early stopping check
488+ if ( options . EarlyStoppingCriteria != null )
489+ {
490+ if ( options . EarlyStoppingCriteria . ShouldStop ( metrics . Train ) )
491+ break ;
492+ }
446493 continue ;
494+ }
447495
448496 batchIndex = 0 ;
449497 metrics . Train . BatchProcessedCount = 0 ;
@@ -481,6 +529,31 @@ private void TrainAndEvaluateClassificationLayer(string trainBottleneckFilePath,
481529 }
482530 }
483531
532+ //Process last incomplete batch
533+ if ( batchIndex > 0 )
534+ {
535+ featureTensorShape [ 0 ] = batchIndex ;
536+ featureBatchSizeInBytes = sizeof ( float ) * featureLength * batchIndex ;
537+ labelTensorShape [ 0 ] = batchIndex ;
538+ labelBatchSizeInBytes = sizeof ( long ) * batchIndex ;
539+ var outputTensors = validationEvalRunner
540+ . AddInput ( new Tensor ( featureBatchPtr , featureTensorShape , TF_DataType . TF_FLOAT , featureBatchSizeInBytes ) , 0 )
541+ . AddInput ( new Tensor ( labelBatchPtr , labelTensorShape , TF_DataType . TF_INT64 , labelBatchSizeInBytes ) , 1 )
542+ . Run ( ) ;
543+
544+ outputTensors [ 0 ] . ToScalar < float > ( ref accuracy ) ;
545+ metrics . Train . Accuracy += accuracy ;
546+ metrics . Train . BatchProcessedCount += 1 ;
547+ batchIndex = 0 ;
548+
549+ featureTensorShape [ 0 ] = batchSize ;
550+ featureBatchSizeInBytes = sizeof ( float ) * featureBatch . Length ;
551+ labelTensorShape [ 0 ] = batchSize ;
552+ labelBatchSizeInBytes = sizeof ( long ) * batchSize ;
553+
554+ outputTensors [ 0 ] . Dispose ( ) ;
555+ }
556+
484557 if ( statisticsCallback != null )
485558 {
486559 metrics . Train . Epoch = epoch ;
0 commit comments