|
classdef Net < handle
  % NET MATLAB-side wrapper around a network handle driven through caffe_.
  %
  %   net = caffe.Net(...) with anything other than a single struct
  %   argument forwards the arguments to caffe.get_net and returns the
  %   object it produces.
  %
  %   net = caffe.Net(hNet) with a single struct argument wraps that
  %   handle: it validates it, queries its attributes through
  %   caffe_('net_get_attr', ...), wraps every layer/blob handle in
  %   caffe.Layer / caffe.Blob, and builds name -> index lookups.
  %
  %   Layers, blobs and layer parameters are fetched by name with
  %   layers(), blobs() and params(); forward()/backward() move data
  %   through the input/output blobs; copy_from()/save() read and write
  %   weight files through caffe_.

  properties (Access = private)
    % Handle struct given to the constructor; passed to every caffe_ call.
    % Stays empty on an object that never wrapped a handle.
    hNet_self
    % Struct returned by caffe_('net_get_attr', hNet_self). Fields read
    % by the constructor:
    %   hLayer_layers        - per-layer handles (wrapped in caffe.Layer)
    %   hBlob_blobs          - per-blob handles (wrapped in caffe.Blob)
    %   input_blob_indices   - 0-based indices into blob_names
    %   output_blob_indices  - 0-based indices into blob_names
    %   layer_names          - cell array of layer name strings
    %   blob_names           - cell array of blob name strings
    attributes
  end

  properties (SetAccess = private)
    layer_vec         % caffe.Layer array, in attributes.hLayer_layers order
    blob_vec          % caffe.Blob array, in attributes.hBlob_blobs order
    inputs            % names of the input blobs (cell array of strings)
    outputs           % names of the output blobs (cell array of strings)
    name2layer_index  % containers.Map: layer name -> index into layer_vec
    name2blob_index   % containers.Map: blob name -> index into blob_vec
    layer_names       % public copy of attributes.layer_names
    blob_names        % public copy of attributes.blob_names
  end

  methods
    function self = Net(varargin)
      % NET Build a net from arguments for caffe.get_net, or wrap a handle.
      %   Unless called with exactly one struct argument, all arguments
      %   are forwarded to caffe.get_net and its result is returned as-is.

      if ~(nargin == 1 && isstruct(varargin{1}))
        self = caffe.get_net(varargin{:});
        return
      end

      % Wrap an existing network handle.
      hNet_net = varargin{1};
      CHECK(is_valid_handle(hNet_net), 'invalid Net handle');

      self.hNet_self = hNet_net;
      self.attributes = caffe_('net_get_attr', self.hNet_self);

      % Wrap every layer handle and every blob handle in its MATLAB class.
      self.layer_vec = caffe.Layer.empty();
      for n = 1:length(self.attributes.hLayer_layers)
        self.layer_vec(n) = caffe.Layer(self.attributes.hLayer_layers(n));
      end

      self.blob_vec = caffe.Blob.empty();
      for n = 1:length(self.attributes.hBlob_blobs)
        self.blob_vec(n) = caffe.Blob(self.attributes.hBlob_blobs(n));
      end

      % The blob indices reported by caffe_ are 0-based; add 1 to use them
      % with MATLAB's 1-based indexing.
      self.inputs = ...
        self.attributes.blob_names(self.attributes.input_blob_indices + 1);
      self.outputs = ...
        self.attributes.blob_names(self.attributes.output_blob_indices + 1);

      % Name -> position lookups used by layers(), blobs() and params().
      self.name2layer_index = containers.Map(self.attributes.layer_names, ...
        1:length(self.attributes.layer_names));
      self.name2blob_index = containers.Map(self.attributes.blob_names, ...
        1:length(self.attributes.blob_names));

      % Expose the names for public read access.
      self.layer_names = self.attributes.layer_names;
      self.blob_names = self.attributes.blob_names;
    end

    function delete(self)
      % DELETE Ask caffe_ to delete the wrapped net when this object dies.
      %   hNet_self is empty on an object that never wrapped a handle (for
      %   example the placeholder discarded when the constructor returns
      %   caffe.get_net's result), so only call delete_net when it is set.
      if ~isempty(self.hNet_self)
        caffe_('delete_net', self.hNet_self);
      end
    end

    function layer = layers(self, layer_name)
      % LAYERS Return the caffe.Layer registered under LAYER_NAME.
      %   Errors (via CHECK) if LAYER_NAME is not a string or is not the
      %   name of a layer in this net.
      CHECK(ischar(layer_name), 'layer_name must be a string');
      % Check membership first so an unknown name produces a message that
      % says which layer was requested, instead of containers.Map's
      % generic missing-key error.
      CHECK(isKey(self.name2layer_index, layer_name), ...
        ['no layer named ''' layer_name ''' in this net']);
      layer = self.layer_vec(self.name2layer_index(layer_name));
    end

    function blob = blobs(self, blob_name)
      % BLOBS Return the caffe.Blob registered under BLOB_NAME.
      %   Errors (via CHECK) if BLOB_NAME is not a string or is not the
      %   name of a blob in this net.
      CHECK(ischar(blob_name), 'blob_name must be a string');
      % Same reasoning as in layers(): name the missing blob explicitly.
      CHECK(isKey(self.name2blob_index, blob_name), ...
        ['no blob named ''' blob_name ''' in this net']);
      blob = self.blob_vec(self.name2blob_index(blob_name));
    end

    function blob = params(self, layer_name, blob_index)
      % PARAMS Return params(BLOB_INDEX) of the layer named LAYER_NAME.
      %   layer_name - string naming a layer of this net
      %   blob_index - scalar forwarded to the layer's params()
      CHECK(ischar(layer_name), 'layer_name must be a string');
      CHECK(isscalar(blob_index), 'blob_index must be a scalar');
      % Reuse layers() so the name lookup and its checks live in one place.
      layer = self.layers(layer_name);
      blob = layer.params(blob_index);
    end

    function forward_prefilled(self)
      % FORWARD_PREFILLED Run caffe_('net_forward') on the data already
      % stored in the input blobs.
      caffe_('net_forward', self.hNet_self);
    end

    function backward_prefilled(self)
      % BACKWARD_PREFILLED Run caffe_('net_backward') on the diffs already
      % stored in the output blobs.
      caffe_('net_backward', self.hNet_self);
    end

    function res = forward(self, input_data)
      % FORWARD Fill the input blobs, run a forward pass, collect outputs.
      %   input_data - cell array with exactly one entry per name in
      %                self.inputs; entry n is written with set_data into
      %                the blob named self.inputs{n}.
      %   res        - column cell array; res{n} is get_data() of the blob
      %                named self.outputs{n}.
      CHECK(iscell(input_data), 'input_data must be a cell array');
      CHECK(length(input_data) == length(self.inputs), ...
        'input data cell length must match input blob number');

      for n = 1:length(self.inputs)
        self.blobs(self.inputs{n}).set_data(input_data{n});
      end

      self.forward_prefilled();

      res = cell(length(self.outputs), 1);
      for n = 1:length(self.outputs)
        res{n} = self.blobs(self.outputs{n}).get_data();
      end
    end

    function res = backward(self, output_diff)
      % BACKWARD Fill the output diffs, run a backward pass, collect input
      % diffs.
      %   output_diff - cell array with exactly one entry per name in
      %                 self.outputs; entry n is written with set_diff into
      %                 the blob named self.outputs{n}.
      %   res         - column cell array; res{n} is get_diff() of the blob
      %                 named self.inputs{n}.
      CHECK(iscell(output_diff), 'output_diff must be a cell array');
      CHECK(length(output_diff) == length(self.outputs), ...
        'output diff cell length must match output blob number');

      for n = 1:length(self.outputs)
        self.blobs(self.outputs{n}).set_diff(output_diff{n});
      end

      self.backward_prefilled();

      res = cell(length(self.inputs), 1);
      for n = 1:length(self.inputs)
        res{n} = self.blobs(self.inputs{n}).get_diff();
      end
    end

    function copy_from(self, weights_file)
      % COPY_FROM Load weights from WEIGHTS_FILE via caffe_('net_copy_from').
      %   The path must be a string naming an existing file (checked by
      %   CHECK_FILE_EXIST before calling caffe_).
      CHECK(ischar(weights_file), 'weights_file must be a string');
      CHECK_FILE_EXIST(weights_file);
      caffe_('net_copy_from', self.hNet_self, weights_file);
    end

    function reshape(self)
      % RESHAPE Call caffe_('net_reshape') on the wrapped net.
      caffe_('net_reshape', self.hNet_self);
    end

    function save(self, weights_file)
      % SAVE Write the net's weights to WEIGHTS_FILE via caffe_('net_save').
      CHECK(ischar(weights_file), 'weights_file must be a string');
      caffe_('net_save', self.hNet_self, weights_file);
    end
  end
end
|
|