1+ function [ind ,value ] = cnn_classifier(A ,dims ,classifier ,thr )
2+
3+ % cnn_classifer classify spatial components using a pretrained CNN
4+ % classifier using the keras importer add on.
5+ % IND = cnn_classifier(A,dims,classifier,thr) returns a binary vector indicating
6+ % whether the set of spatial components A, with dimensions of the field
7+ % of view DIMS, pass the threshold THR for the given CLASSIFIER
8+ %
9+ % [IND,VALUE] = cnn_classifier(A,dims,classifier,thr) also returns the
10+ % output value of the classifier
11+ %
12+ % INPUTS:
13+ % A: 2d matrix
14+ % dims: vector with dimensions of the FOV
15+ % classifier: path to pretrained classifier model
16+ % thr: threshold for accepting component (default: 0.2)
17+ %
18+ % note: The function requires Matlab version 2017b (9.3) or later, Neural
19+ % Networks toolbox version 2017b (11.0) or later, the Neural Network
20+ % Toolbox(TM) Importer for TensorFlow-Keras Models.
21+
22+ % Written by Eftychios A. Pnevmatikakis. Classifier trained by Andrea
23+ % Giovannucci, Flatiron Institute, 2017
24+
25+ if verLessThan(' matlab' ,' 9.3' ) || verLessThan(' nnet' ,' 11.0' ) || isempty(which(' importKerasNetwork' ))
26+ error(strcat(' The function requires Matlab version 2017b (9.3) or later, Neural\n ' , ...
27+ ' Networks toolbox version 2017b (11.0) or lvalater, the Neural Networks ' , ...
28+ ' Toolbox(TM) Importer for TensorFlow-Keras Models.' ))
29+ end
30+
31+ if ~exist(' thr' ,' var' ); thr = 0.2 ; end
32+
33+ K = size(A ,2 ); % number of components
34+ A = A / spdiags(sqrt(sum(A .^ 2 ,1 ))' +eps ,0 ,K ,K ); % normalize to sum 1 for each compoennt
35+ A_com = extract_patch(A ,dims ,[50 ,50 ]); % extract 50 x 50 patches
36+
37+ net_classifier = importKerasNetwork(classifier );
38+ out = predict(net_classifier ,double(A_com ));
39+ value = out(: ,2 );
40+ ind = (value >= thr );
0 commit comments