@@ -66,7 +66,8 @@ namespace Microsoft.ML.Vision
6666 ///
6767 /// ### Training Algorithm Details
6868 /// Trains a Deep Neural Network(DNN) by leveraging an existing pre-trained model such as Resnet50 for the purpose
69- /// of classifying images.
69+ /// of classifying images. The technique was inspired by
70+ /// [TensorFlow's retrain image classification tutorial](https://www.tensorflow.org/hub/tutorials/image_retraining).
7071 /// ]]>
7172 /// </format>
7273 /// </remarks>
@@ -392,7 +393,7 @@ public sealed class Options : TrainerInputBaseWithLabel
392393 public Action < ImageClassificationMetrics > MetricsCallback = null ;
393394
394395 /// <summary>
395- /// Indicates the path where the models get downloaded to and cache files saved, default is a new temporary directory
396+ /// Indicates the path where the image bottleneck cache files and trained model are saved. Defaults to a new temporary directory.
396397 /// </summary>
397398 [ Argument ( ArgumentType . AtMostOnce , HelpText = "Indicates the path where the models get downloaded to and cache files saved, default is a new temporary directory." , SortOrder = 15 ) ]
398399 public string WorkspacePath = null ;
@@ -591,6 +592,7 @@ private void InitializeTrainingGraph(IDataView input)
591592 _classCount = labelCount == 1 ? 2 : ( int ) labelCount ;
592593 var imageSize = ImagePreprocessingSize [ _options . Arch ] ;
593594 _session = LoadTensorFlowSessionFromMetaGraph ( Host , _options . Arch ) . Session ;
595+ _session . graph . as_default ( ) ;
594596 ( _jpegData , _resizedImage ) = AddJpegDecoding ( imageSize . Item1 , imageSize . Item2 , 3 ) ;
595597 _jpegDataTensorName = _jpegData . name ;
596598 _resizedImageTensorName = _resizedImage . name ;
@@ -631,6 +633,14 @@ private protected override MulticlassPredictionTransformer<ImageClassificationMo
631633
632634 private protected override ImageClassificationModelParameters TrainModelCore ( TrainContext trainContext )
633635 {
636+ // Workspace directory is cleaned after training run. However, the pipeline can be re-used by calling
637+ // fit() again after transform(), in which case we must ensure workspace directory exists. This scenario
638+ // is typical in the case of cross-validation.
639+ if ( ! Directory . Exists ( _options . WorkspacePath ) )
640+ {
641+ Directory . CreateDirectory ( _options . WorkspacePath ) ;
642+ }
643+
634644 InitializeTrainingGraph ( trainContext . TrainingSet . Data ) ;
635645 CheckTrainingParameters ( _options ) ;
636646 var validationSet = trainContext . ValidationSet ? . Data ?? _options . ValidationSet ;
@@ -1301,7 +1311,7 @@ private void VariableSummaries(RefVariable var)
13011311 var optimizer = useLearningRateScheduler ? tf . train . GradientDescentOptimizer ( _learningRateInput ) :
13021312 tf . train . GradientDescentOptimizer ( learningRate ) ;
13031313
1304- _trainStep = optimizer . minimize ( crossEntropyMean ) ;
1314+ _trainStep = optimizer . minimize ( crossEntropyMean ) ;
13051315 } ) ;
13061316
13071317 return ( _trainStep , crossEntropyMean , _labelTensor , _softMaxTensor ) ;
@@ -1341,6 +1351,11 @@ private void Dispose(bool disposing)
13411351 {
13421352 _session . close ( ) ;
13431353 }
1354+
1355+ if ( _session != null && _session . graph != IntPtr . Zero )
1356+ {
1357+ _session . graph . Dispose ( ) ;
1358+ }
13441359 }
13451360
13461361 /// <summary>
@@ -1527,6 +1542,11 @@ private void Dispose(bool disposing)
15271542 {
15281543 _session . close ( ) ;
15291544 }
1545+
1546+ if ( _session != null && _session . graph != IntPtr . Zero )
1547+ {
1548+ _session . graph . Dispose ( ) ;
1549+ }
15301550 }
15311551 }
15321552}
0 commit comments