Skip to content

Commit 4c213d4

Browse files
committed
function for classifier
1 parent 470d213 commit 4c213d4

File tree

1 file changed

+40
-0
lines changed

1 file changed

+40
-0
lines changed

utilities/cnn_classifier.m

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
function [ind,value] = cnn_classifier(A,dims,classifier,thr)
2+
3+
%cnn_classifer classify spatial components using a pretrained CNN
4+
%classifier using the keras importer add on.
5+
% IND = cnn_classifier(A,dims,classifier,thr) returns a binary vector indicating
6+
% whether the set of spatial components A, with dimensions of the field
7+
% of view DIMS, pass the threshold THR for the given CLASSIFIER
8+
%
9+
% [IND,VALUE] = cnn_classifier(A,dims,classifier,thr) also returns the
10+
% output value of the classifier
11+
%
12+
% INPUTS:
13+
% A: 2d matrix
14+
% dims: vector with dimensions of the FOV
15+
% classifier: path to pretrained classifier model
16+
% thr: threshold for accepting component (default: 0.2)
17+
%
18+
% note: The function requires Matlab version 2017b (9.3) or later, Neural
19+
% Networks toolbox version 2017b (11.0) or later, the Neural Network
20+
% Toolbox(TM) Importer for TensorFlow-Keras Models.
21+
22+
% Written by Eftychios A. Pnevmatikakis. Classifier trained by Andrea
23+
% Giovannucci, Flatiron Institute, 2017
24+
25+
if verLessThan('matlab','9.3') || verLessThan('nnet','11.0') || isempty(which('importKerasNetwork'))
26+
error(strcat('The function requires Matlab version 2017b (9.3) or later, Neural\n', ...
27+
'Networks toolbox version 2017b (11.0) or lvalater, the Neural Networks ', ...
28+
'Toolbox(TM) Importer for TensorFlow-Keras Models.'))
29+
end
30+
31+
if ~exist('thr','var'); thr = 0.2; end
32+
33+
K = size(A,2); % number of components
34+
A = A/spdiags(sqrt(sum(A.^2,1))'+eps,0,K,K); % normalize to sum 1 for each compoennt
35+
A_com = extract_patch(A,dims,[50,50]); % extract 50 x 50 patches
36+
37+
net_classifier = importKerasNetwork(classifier);
38+
out = predict(net_classifier,double(A_com));
39+
value = out(:,2);
40+
ind = (value >= thr);

0 commit comments

Comments
 (0)