File size: 4,912 Bytes
3bbb319 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
classdef Net < handle
  % Wrapper class of caffe::Net in matlab
  %
  % Create it either by forwarding arguments to caffe.get_net, or by
  % wrapping a single Net handle struct (see the constructor). Layers and
  % blobs are looked up by name with layers() / blobs(); inputs and
  % outputs hold the names of the net's input and output blobs.

  properties (Access = private)
    hNet_self
    attributes
    % attribute fields
    %     hLayer_layers
    %     hBlob_blobs
    %     input_blob_indices
    %     output_blob_indices
    %     layer_names
    %     blob_names
  end
  properties (SetAccess = private)
    layer_vec        % caffe.Layer array, one entry per attributes.hLayer_layers
    blob_vec         % caffe.Blob array, one entry per attributes.hBlob_blobs
    inputs           % names of the input blobs
    outputs          % names of the output blobs
    name2layer_index % containers.Map: layer name -> index into layer_vec
    name2blob_index  % containers.Map: blob name -> index into blob_vec
    layer_names
    blob_names
  end

  methods
    function self = Net(varargin)
      % Net  Construct a Net wrapper.
      %   Called with anything other than exactly one struct argument, all
      %   arguments are forwarded to caffe.get_net and its result is
      %   returned. Called with a single struct, that struct is treated as
      %   a Net handle: it is validated, its attributes are fetched through
      %   caffe_('net_get_attr', ...), and the layer/blob wrappers and
      %   name lookup maps are built from them.

      % decide whether to construct a net from model_file or handle
      if ~(nargin == 1 && isstruct(varargin{1}))
        % construct a net from model_file
        self = caffe.get_net(varargin{:});
        return
      end
      % construct a net from handle
      hNet_net = varargin{1};
      CHECK(is_valid_handle(hNet_net), 'invalid Net handle');
      % setup self handle and attributes
      self.hNet_self = hNet_net;
      self.attributes = caffe_('net_get_attr', self.hNet_self);
      % setup layer_vec
      self.layer_vec = caffe.Layer.empty();
      for n = 1:length(self.attributes.hLayer_layers)
        self.layer_vec(n) = caffe.Layer(self.attributes.hLayer_layers(n));
      end
      % setup blob_vec
      self.blob_vec = caffe.Blob.empty();
      for n = 1:length(self.attributes.hBlob_blobs)
        self.blob_vec(n) = caffe.Blob(self.attributes.hBlob_blobs(n));
      end
      % setup input and output blob and their names
      % note: add 1 to indices as matlab is 1-indexed while C++ is 0-indexed
      self.inputs = ...
        self.attributes.blob_names(self.attributes.input_blob_indices + 1);
      self.outputs = ...
        self.attributes.blob_names(self.attributes.output_blob_indices + 1);
      % create map objects to map from name to layers and blobs
      self.name2layer_index = containers.Map(self.attributes.layer_names, ...
        1:length(self.attributes.layer_names));
      self.name2blob_index = containers.Map(self.attributes.blob_names, ...
        1:length(self.attributes.blob_names));
      % expose layer_names and blob_names for public read access
      self.layer_names = self.attributes.layer_names;
      self.blob_names = self.attributes.blob_names;
    end

    function delete (self)
      % delete  Release the underlying net via caffe_('delete_net', ...).
      %   Skipped when hNet_self is empty. NOTE(review): presumably this
      %   covers the placeholder object left behind when the constructor
      %   reassigns self to the result of caffe.get_net - confirm.
      if ~isempty(self.hNet_self)
        caffe_('delete_net', self.hNet_self);
      end
    end

    function layer = layers(self, layer_name)
      % layers  Return the caffe.Layer whose name is layer_name.
      %   Errors if layer_name is not a char array or names no layer.
      CHECK(ischar(layer_name), 'layer_name must be a string');
      % check explicitly so the error names the missing key, instead of
      % containers.Map's generic "key not present" error
      CHECK(isKey(self.name2layer_index, layer_name), ...
        sprintf('layer_name ''%s'' does not exist in the net', layer_name));
      layer = self.layer_vec(self.name2layer_index(layer_name));
    end

    function blob = blobs(self, blob_name)
      % blobs  Return the caffe.Blob whose name is blob_name.
      %   Errors if blob_name is not a char array or names no blob.
      CHECK(ischar(blob_name), 'blob_name must be a string');
      % check explicitly so the error names the missing key, instead of
      % containers.Map's generic "key not present" error
      CHECK(isKey(self.name2blob_index, blob_name), ...
        sprintf('blob_name ''%s'' does not exist in the net', blob_name));
      blob = self.blob_vec(self.name2blob_index(blob_name));
    end

    function blob = params(self, layer_name, blob_index)
      % params  Return params(blob_index) of the layer named layer_name.
      %   NOTE(review): presumably a caffe.Blob holding a learnable
      %   parameter (e.g. weights or bias) - confirm against caffe.Layer.
      CHECK(ischar(layer_name), 'layer_name must be a string');
      CHECK(isscalar(blob_index), 'blob_index must be a scalar');
      % reuse layers() so unknown layer names get the same clear error
      layer = self.layers(layer_name);
      blob = layer.params(blob_index);
    end

    function forward_prefilled(self)
      % forward_prefilled  Run caffe_('net_forward', ...) on whatever data
      %   is already in the blobs; nothing is copied in or out.
      caffe_('net_forward', self.hNet_self);
    end

    function backward_prefilled(self)
      % backward_prefilled  Run caffe_('net_backward', ...) on whatever
      %   diffs are already in the blobs; nothing is copied in or out.
      caffe_('net_backward', self.hNet_self);
    end

    function res = forward(self, input_data)
      % forward  Fill the input blobs, run forward, return output data.
      %   input_data - cell array with one entry per name in self.inputs,
      %                in the same order; entry n is passed to set_data of
      %                blob self.inputs{n}.
      %   res        - column cell array; res{n} is get_data() of blob
      %                self.outputs{n} after the forward pass.
      CHECK(iscell(input_data), 'input_data must be a cell array');
      CHECK(length(input_data) == length(self.inputs), ...
        'input data cell length must match input blob number');
      % copy data to input blobs
      for n = 1:length(self.inputs)
        self.blobs(self.inputs{n}).set_data(input_data{n});
      end
      self.forward_prefilled();
      % retrieve data from output blobs
      res = cell(length(self.outputs), 1);
      for n = 1:length(self.outputs)
        res{n} = self.blobs(self.outputs{n}).get_data();
      end
    end

    function res = backward(self, output_diff)
      % backward  Fill output-blob diffs, run backward, return input diffs.
      %   output_diff - cell array with one entry per name in self.outputs,
      %                 in the same order; entry n is passed to set_diff of
      %                 blob self.outputs{n}.
      %   res         - column cell array; res{n} is get_diff() of blob
      %                 self.inputs{n} after the backward pass.
      CHECK(iscell(output_diff), 'output_diff must be a cell array');
      CHECK(length(output_diff) == length(self.outputs), ...
        'output diff cell length must match output blob number');
      % copy diff to output blobs
      for n = 1:length(self.outputs)
        self.blobs(self.outputs{n}).set_diff(output_diff{n});
      end
      self.backward_prefilled();
      % retrieve diff from input blobs
      res = cell(length(self.inputs), 1);
      for n = 1:length(self.inputs)
        res{n} = self.blobs(self.inputs{n}).get_diff();
      end
    end

    function copy_from(self, weights_file)
      % copy_from  Pass weights_file to caffe_('net_copy_from', ...).
      %   The file must exist (checked via CHECK_FILE_EXIST). NOTE(review):
      %   presumably loads trained weights into matching layers - confirm.
      CHECK(ischar(weights_file), 'weights_file must be a string');
      CHECK_FILE_EXIST(weights_file);
      caffe_('net_copy_from', self.hNet_self, weights_file);
    end

    function reshape(self)
      % reshape  Call caffe_('net_reshape', ...) on the underlying net.
      caffe_('net_reshape', self.hNet_self);
    end

    function save(self, weights_file)
      % save  Call caffe_('net_save', ...) with weights_file as the target.
      %   Unlike copy_from, the file is not required to exist beforehand.
      CHECK(ischar(weights_file), 'weights_file must be a string');
      caffe_('net_save', self.hNet_self, weights_file);
    end
  end
end
|