This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 4170b23

Add attach_grad to hybridized blocks

Squashed fixups: format; deduplicate; MarkDCVariables rebase error in block.py; format; manual format error fix; add attach_grad to gluon
1 parent e36c9f0 commit 4170b23

File tree: 15 files changed, +228 −24 lines


include/mxnet/c_api.h

Lines changed: 7 additions & 0 deletions
@@ -1276,6 +1276,13 @@ MXNET_DLL int MXAutogradMarkVariables(uint32_t num_var,
                                       NDArrayHandle* var_handles,
                                       uint32_t* reqs_array,
                                       NDArrayHandle* grad_handles);
+/*!
+ * \brief mark nonleaf NDArrays as variables during deferred computation
+ * \param num_nleafs number of nonleaf NDArrays
+ * \param cnt_var count of existing marked nonleaf variables
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXNDArrayMarkDCVariables(NDArrayHandle* nleaf_handles, int num_nleafs, int cnt_var);
 /*!
  * \brief unmark nonleaf NDArrays to free the memory
  * \param num_var number of variable NDArrays

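For readers who want to drive this new C entry point from Python over ctypes, a minimal prototype declaration is sketched below. It only mirrors the signature documented above; mxnet's own frontend (see the block.py change later in this commit) calls the function without declaring argtypes, so this is illustrative rather than part of the commit.

import ctypes
from mxnet.base import _LIB, NDArrayHandle

# Illustrative ctypes prototype for the new C API; NDArrayHandle is mxnet's
# alias for ctypes.c_void_p. Not part of this commit.
_LIB.MXNDArrayMarkDCVariables.argtypes = [
    ctypes.POINTER(NDArrayHandle),  # nleaf_handles: array of NDArray handles
    ctypes.c_int,                   # num_nleafs: number of nonleaf NDArrays
    ctypes.c_int,                   # cnt_var: count of already-marked nonleaf variables
]
_LIB.MXNDArrayMarkDCVariables.restype = ctypes.c_int  # 0 on success, -1 on failure
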
include/mxnet/imperative.h

Lines changed: 2 additions & 0 deletions
@@ -290,6 +290,8 @@ class Imperative {
   void MarkVariables(const std::vector<NDArray*>& variables,
                      const std::vector<uint32_t>& grad_reqs,
                      const std::vector<NDArray*>& gradients);
+  /*! \brief mark nonleaf variables during DC for computing gradients. */
+  void MarkDCVariables(const std::vector<NDArray*>& nleafs, int cnt_vars);
   /*! \brief unmark nonleaf variables to free the memory. */
   void DropGrads(const std::vector<NDArray*>& variables);
   /*! \brief compute the gradient of outputs w.r.t variables. */

include/mxnet/ndarray.h

Lines changed: 2 additions & 0 deletions
@@ -351,6 +351,8 @@ class NDArray {
   bool fresh_out_grad() const;
   /*! \return updated grad state in autograd_entry_ */
   void set_fresh_out_grad(bool state) const;
+  /*! \brief copy the autograd_entry_ from src NDArray */
+  void copy_autograd_entry_(const NDArray* src);
   /*! \brief Returns true if a sparse ndarray's aux_data and storage are initialized
    * Throws an exception if the indices array shape is inconsistent
    * Returns false if the indices array is empty(nnz = 0) for csr/row_sparse

python/mxnet/_ctypes/cached_op.py

Lines changed: 5 additions & 1 deletion
@@ -77,6 +77,7 @@ def __call__(self, *args, **kwargs):
         if not default_device:
             default_device = kwargs.pop('default_ctx', None)
         out = kwargs.pop('out', None)
+        nleaf_vars = [container.data() for container in kwargs.pop('_nleaf_vars', [])]
         if kwargs:
             raise TypeError(
                 "CachedOp.__call__ got unexpected keyword argument(s): " + \
@@ -93,7 +94,10 @@ def __call__(self, *args, **kwargs):
                 *args,
                 type_id,
                 device_id,
-                *out_arg
+                len(out_arg),
+                *out_arg,
+                len(nleaf_vars),
+                *nleaf_vars
             )
         if out is not None:
             return out

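The change above reorders the positional payload sent to the backend invoke call: an explicit output count now precedes the output arrays, and a nonleaf count plus the nonleaf arrays are appended at the end. A small sketch of the tail of that list follows; all values are placeholders and the leading handle/input bookkeeping is omitted, so this is illustrative rather than code from the commit.

# Placeholder values standing in for NDArrays; only the ordering matters here.
args = ['in0', 'in1']              # user inputs (*args)
out_arg = ['out0']                 # explicit output arrays, if any
nleaf_vars = ['nleaf0']            # .data() of each registered Intermediate
type_id, device_id = 2, 0          # default device type/id, unchanged

tail = [*args, type_id, device_id,
        len(out_arg), *out_arg,           # new: output count before the outputs
        len(nleaf_vars), *nleaf_vars]     # new: nonleaf count and nonleaf arrays
print(tail)  # ['in0', 'in1', 2, 0, 1, 'out0', 1, 'nleaf0']
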
python/mxnet/gluon/block.py

Lines changed: 91 additions & 3 deletions
@@ -33,13 +33,14 @@
 import json
 import numpy as np

-from ..base import mx_real_t, MXNetError, NDArrayHandle, SymbolHandle, py_str, check_call, _LIB
+from ..base import mx_real_t, MXNetError, NDArrayHandle, SymbolHandle, py_str, check_call, _LIB, \
+    _as_list
 from .. import symbol, ndarray, initializer, autograd, _deferred_compute as dc, name as _name, \
     profiler as _profiler, device as _device
 from ..symbol.numpy import _symbol as np_symbol
 from ..symbol import Symbol, fromjson
 from ..ndarray import NDArray, get_dtype_name
-from .parameter import Parameter, DeferredInitializationError
+from .parameter import Parameter, DeferredInitializationError, Intermediate
 from .utils import _indent, _brief_print_list, HookHandle, shape_is_known
 from .utils import _check_same_symbol_type, _check_all_np_ndarrays, _check_block_input_np_ndarrays
 from .. import numpy_extension as _mx_npx
@@ -1091,6 +1092,7 @@ def __init__(self):
         self._backend_opts = {}
         self._partition_if_dynamic = True
         self._first_forward = True
+        self._nleaf_vars = OrderedDict()

     def __setattr__(self, name, value):
         """Registers parameters."""
@@ -1302,7 +1304,7 @@ def _call_cached_op(self, *args):
         args_without_none = [ele for ele in args if ele is not None]
         cargs = [args_without_none[i] if is_arg else i.data()
                  for is_arg, name, i in self._cached_op_args]
-        out = self._cached_op(*cargs)
+        out = self._cached_op(*cargs, _nleaf_vars=self._nleaf_vars.values())
         if isinstance(out, NDArray):
             out = [out]
         return _regroup(out, self._out_format)
@@ -1678,6 +1680,92 @@ def reset_ctx(self, ctx):
         self.reset_device(ctx)


+    def intermediate(self, names, var_arrays_inp, grad_req='write'):
+        """Mark the intermediate variables.
+
+        Parameters
+        ----------
+        names : str or tuple[str], name of the registered intermediate variable
+        var_arrays_inp : ndarray or tuple[ndarray], the output of the expression
+        grad_req : str, gradient request
+        """
+        if not self._active:
+            var_arrays = _as_list(var_arrays_inp)
+            names = _as_list(names)
+            self._nleaf_vars.update(
+                {name : Intermediate(name, array, grad_req) for name, array in zip(names, var_arrays)})
+        else:
+            prev_val = dc.set_deferred_compute(False)
+            var_arrays = _as_list(var_arrays_inp)
+            names = _as_list(names)
+            # Prepare ctypes array types
+            import ctypes
+            var_handles_type = ctypes.c_void_p * len(var_arrays)
+            # Convert handles
+            var_handles = var_handles_type(*[arr.handle for arr in var_arrays])
+            check_call(_LIB.MXNDArrayMarkDCVariables(var_handles, len(var_arrays), len(self._nleaf_vars)))
+            self._nleaf_vars.update(
+                {name : Intermediate(name, array, grad_req) for name, array in zip(names, var_arrays)})
+            dc.set_deferred_compute(prev_val)
+        return var_arrays_inp
+
+    def attach_grad_intermediate(self):
+        """Attach gradient to all the intermediate variables."""
+        for val in self._nleaf_vars.values():
+            val.data().attach_grad(grad_req=val.grad_req)
+
+    def get_intermediate(self, names):
+        """Get the intermediate variables by names."""
+        if isinstance(names, list):
+            return [self._nleaf_vars[n] for n in names]
+        else:
+            return self._nleaf_vars[names]
+
 class SymbolBlock(HybridBlock):
     """Construct block from symbol. This is useful for using pre-trained models
     as feature extractors. For example, you may want to extract the output

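Taken together, the new HybridBlock methods let a user mark an intermediate result inside forward, attach gradient buffers to every marked array after the first (cached) forward pass, and read the gradients back after backward. A hedged usage sketch follows; the network definition, shapes, and the name 'h' are illustrative and not part of the commit.

import mxnet as mx
from mxnet import autograd
from mxnet.gluon import nn

class Net(nn.HybridBlock):
    def __init__(self):
        super().__init__()
        self.dense0 = nn.Dense(4)
        self.dense1 = nn.Dense(2)

    def forward(self, x):
        h = self.dense0(x)
        # Register the nonleaf result under the name 'h'; returns the array unchanged.
        h = self.intermediate('h', h)
        return self.dense1(h)

net = Net()
net.initialize()
net.hybridize()

x = mx.np.ones((3, 8))
net(x)                           # first forward builds the cached graph and marks 'h'
net.attach_grad_intermediate()   # attach grad buffers to all marked intermediates

with autograd.record():
    out = net(x)
out.backward()

print(net.get_intermediate('h').data().grad)   # gradient flowing into 'h'
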
python/mxnet/gluon/parameter.py

Lines changed: 37 additions & 0 deletions
@@ -773,3 +773,40 @@ def grad_req(self, req):
         warnings.warn('Constant parameter "{}" does not support '
                       'grad_req other than "null", and new value "{}" '
                       'is ignored.'.format(self.name, req))
+
+class Intermediate:
+    """A container holding marked intermediate variables of Blocks.
+
+    Parameters
+    ----------
+    name : str
+        Name of this parameter. It is used to retrieve the marked variables.
+    grad_req : {'write', 'add', 'null'}, default 'write'
+        Specifies how to update gradient to grad arrays.
+
+        - ``'write'`` means every time gradient is written to grad :py:class:`NDArray`.
+        - ``'add'`` means every time gradient is added to the grad :py:class:`NDArray`. You need
+          to manually call ``zero_grad()`` to clear the gradient buffer before each
+          iteration when using this option.
+        - ``'null'`` means gradient is not requested for this parameter. Gradient arrays
+          will not be allocated.
+    """
+    def __init__(self, name, data=None, grad_req='write'):
+        self._name = name
+        self._data = data
+        self._grad_req = grad_req
+
+    def __repr__(self):
+        s = 'Intermediate name={name}'
+        return s.format(name=self._name)
+
+    def data(self):
+        return self._data
+
+    @property
+    def name(self):
+        return self._name
+
+    @property
+    def grad_req(self):
+        return self._grad_req

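Intermediate is a plain container: it records a name, the marked array, and the grad_req to use when attach_grad_intermediate() later attaches a gradient buffer. A minimal illustration, with a hypothetical array and name:

from mxnet import np
from mxnet.gluon.parameter import Intermediate

arr = np.zeros((2, 3))                                     # hypothetical marked array
inter = Intermediate('hidden', data=arr, grad_req='write')

print(inter)                 # Intermediate name=hidden
print(inter.name)            # 'hidden'
print(inter.grad_req)        # 'write'
print(inter.data() is arr)   # True: data() returns the stored NDArray
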
src/api/cached_op_api.cc

Lines changed: 12 additions & 3 deletions
@@ -44,16 +44,18 @@ MXNET_REGISTER_GLOBAL("cached_op.invoke")
       ndinputs.push_back(static_cast<mxnet::NDArray*>(args[i]));
     }

+    int num_outputs = args[num_inputs + 4];
+    int num_nleafs = args[num_inputs + num_outputs + 5];
     std::vector<NDArray*> ndoutputs;
     ndoutputs.reserve(op->num_outputs());
-    if (args[num_inputs + 4].type_code() == kNull) {
+    if (args[num_inputs + 5].type_code() == kNull) {
       for (int i = 0; i < op->num_outputs(); ++i)
         ndoutputs.push_back(new NDArray());
     } else {
-      int array_size = args_size - num_inputs - 4;
+      int array_size = args_size - num_inputs - num_nleafs - 6;
       CHECK_EQ(array_size, op->num_outputs()) << "CachedOp expects " << op->num_outputs()
                                               << " outputs, but " << array_size << " was given.";
-      for (int i = num_inputs + 4; i < array_size; ++i) {
+      for (int i = num_inputs + 5; i < num_inputs + num_outputs + 5; ++i) {
        ndoutputs.push_back(args[i].operator mxnet::NDArray*());
       }
     }
@@ -69,6 +71,13 @@ MXNET_REGISTER_GLOBAL("cached_op.invoke")
       default_dev_id = ctx.dev_id;
     }

+    std::vector<NDArray*> nleafs;
+    nleafs.reserve(num_nleafs);
+    for (int i = 0; i < num_nleafs; ++i) {
+      nleafs.push_back(static_cast<mxnet::NDArray*>(args[i + num_inputs + num_outputs + 6]));
+    }
+    op->set_nleafs(nleafs);
+
     // construct default context
     Context ctx =
         Context::Create(static_cast<Context::DeviceType>(default_dev_type), default_dev_id);

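The index arithmetic above assumes the packed layout built by CachedOp.__call__: num_outputs sits at args[num_inputs + 4], the output arrays follow, then num_nleafs and the nonleaf arrays. A small worked example of the resulting indices, pure arithmetic with assumed counts:

# Assume 2 inputs, 1 explicit output and 1 nonleaf array.
num_inputs, num_outputs, num_nleafs = 2, 1, 1

idx_num_outputs = num_inputs + 4                          # args[6] -> num_outputs
output_idxs = list(range(num_inputs + 5,
                         num_inputs + 5 + num_outputs))   # args[7]
idx_num_nleafs = num_inputs + num_outputs + 5             # args[8] -> num_nleafs
nleaf_idxs = list(range(num_inputs + num_outputs + 6,
                        num_inputs + num_outputs + 6 + num_nleafs))  # args[9]

print(idx_num_outputs, output_idxs, idx_num_nleafs, nleaf_idxs)  # 6 [7] 8 [9]
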
src/c_api/c_api_ndarray.cc

Lines changed: 12 additions & 0 deletions
@@ -495,3 +495,15 @@ int MXNDArrayGetDeferredComputeSymbol(NDArrayHandle* output_handles,
   *out = s;
   API_END_HANDLE_ERROR(delete s;);
 }
+
+int MXNDArrayMarkDCVariables(NDArrayHandle* nleaf_handles, int num_nleafs, int cnt_var) {
+  API_BEGIN();
+  std::vector<NDArray*> nleafs;
+  nleafs.reserve(num_nleafs);
+  for (int i = 0; i < num_nleafs; ++i) {
+    NDArray* array = reinterpret_cast<NDArray*>(nleaf_handles[i]);
+    nleafs.emplace_back(array);
+  }
+  Imperative::Get()->MarkDCVariables(nleafs, cnt_var);
+  API_END();
+}

src/imperative/cached_op.cc

Lines changed: 4 additions & 1 deletion
@@ -801,7 +801,8 @@ OpStatePtr CachedOp::DynamicForward(const Context& default_ctx,
              recording && inlining_,
              nullptr,
              monitor_callback_,
-             monitor_all_);
+             monitor_all_,
+             nleafs_);
   } else {
     mxnet::ShapeVector shapes = g.GetAttr<mxnet::ShapeVector>("shape");
     NaiveRunGraph(false,
@@ -1063,6 +1064,7 @@ void CachedOp::StaticBackward(const bool retain_graph,
       if (!idx.exist(entry.node.get()))
         continue;
       auto eid = idx.entry_id(entry);
+      state.array_reqs[eid] = reqs[iter->second];
       // An input and an output may share the same array.
       INIT_DETACHED(outputs[iter->second], arrays[eid]);
       arrays[eid] = outputs[iter->second];
@@ -1073,6 +1075,7 @@ void CachedOp::StaticBackward(const bool retain_graph,
       if (!idx.exist(entry.node.get()))
         continue;
       auto eid = idx.entry_id(entry);
+      state.array_reqs[eid] = reqs[i];
       // An input and an output may share the same array.
       INIT_DETACHED(outputs[i], arrays[eid]);
       arrays[eid] = outputs[i];

src/imperative/cached_op.h

Lines changed: 4 additions & 0 deletions
@@ -491,6 +491,9 @@ class CachedOp {
   const std::unordered_set<uint32_t>& mutable_input_nodes() const {
     return fwd_graph_.indexed_graph().mutable_input_nodes();
   }
+  void set_nleafs(const std::vector<NDArray*>& nleafs) {
+    nleafs_ = nleafs;
+  }
   virtual std::vector<nnvm::NodeEntry> Gradient(const nnvm::ObjectPtr& node,
                                                 const std::vector<nnvm::NodeEntry>& ograds) const;
   virtual OpStatePtr Forward(const std::shared_ptr<CachedOp>& op_ptr,
@@ -649,6 +652,7 @@ class CachedOp {
   std::vector<uint32_t> bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_;
   std::vector<bool> save_inputs_, save_outputs_;
   std::vector<OpReqType> bwd_output_reqs_;
+  std::vector<NDArray*> nleafs_;

   std::function<void(const char*, const char*, NDArrayHandle)> monitor_callback_{nullptr};
   bool monitor_all_{false};
