|
33 | 33 | import json |
34 | 34 | import numpy as np |
35 | 35 |
|
36 | | -from ..base import mx_real_t, MXNetError, NDArrayHandle, SymbolHandle, py_str, check_call, _LIB |
| 36 | +from ..base import mx_real_t, MXNetError, NDArrayHandle, SymbolHandle, py_str, check_call, _LIB, \ |
| 37 | + _as_list |
37 | 38 | from .. import symbol, ndarray, initializer, autograd, _deferred_compute as dc, name as _name, \ |
38 | 39 | profiler as _profiler, device as _device |
39 | 40 | from ..symbol.numpy import _symbol as np_symbol |
40 | 41 | from ..symbol import Symbol, fromjson |
41 | 42 | from ..ndarray import NDArray, get_dtype_name |
42 | | -from .parameter import Parameter, DeferredInitializationError |
| 43 | +from .parameter import Parameter, DeferredInitializationError, Intermediate |
43 | 44 | from .utils import _indent, _brief_print_list, HookHandle, shape_is_known |
44 | 45 | from .utils import _check_same_symbol_type, _check_all_np_ndarrays, _check_block_input_np_ndarrays |
45 | 46 | from .. import numpy_extension as _mx_npx |
@@ -1091,6 +1092,7 @@ def __init__(self): |
1091 | 1092 | self._backend_opts = {} |
1092 | 1093 | self._partition_if_dynamic = True |
1093 | 1094 | self._first_forward = True |
| 1095 | + self._nleaf_vars = OrderedDict() |
1094 | 1096 |
|
1095 | 1097 | def __setattr__(self, name, value): |
1096 | 1098 | """Registers parameters.""" |
@@ -1302,7 +1304,7 @@ def _call_cached_op(self, *args): |
1302 | 1304 | args_without_none = [ele for ele in args if ele is not None] |
1303 | 1305 | cargs = [args_without_none[i] if is_arg else i.data() |
1304 | 1306 | for is_arg, name, i in self._cached_op_args] |
1305 | | - out = self._cached_op(*cargs) |
| 1307 | + out = self._cached_op(*cargs, _nleaf_vars=self._nleaf_vars.values()) |
1306 | 1308 | if isinstance(out, NDArray): |
1307 | 1309 | out = [out] |
1308 | 1310 | return _regroup(out, self._out_format) |
@@ -1678,6 +1680,49 @@ def reset_ctx(self, ctx): |
1678 | 1680 | self.reset_device(ctx) |
1679 | 1681 |
|
1680 | 1682 |
|
| 1683 | + def intermediate(self, names, var_arrays_inp, grad_req='write'): |
| 1684 | + """Mark the intermediate variables. |
| 1685 | +
|
| 1686 | + Parameters |
| 1687 | + ---------- |
| 1688 | + names : str or tuple[str], names of the registered intermediate variables |
| 1689 | + var_arrays_inp : ndarray or tuple[ndarray], output array(s) of the expression to mark |
| 1690 | + grad_req : str, gradient request passed to attach_grad ('write', 'add', or 'null') |
| 1691 | + """ |
| 1692 | + if not self._active: |
| 1693 | + var_arrays = _as_list(var_arrays_inp) |
| 1694 | + names = _as_list(names) |
| 1695 | + self._nleaf_vars.update( |
| 1696 | + {name: Intermediate(name, array, grad_req) for name, array in zip(names, var_arrays)}) |
| 1697 | + else: |
| 1698 | + prev_val = dc.set_deferred_compute(False) |
| 1699 | + var_arrays = _as_list(var_arrays_inp) |
| 1700 | + names = _as_list(names) |
| 1701 | + # Prepare ctypes array types |
| 1702 | + import ctypes |
| 1703 | + var_handles_type = ctypes.c_void_p * len(var_arrays) |
| 1704 | + # Convert handles |
| 1705 | + var_handles = var_handles_type(*[arr.handle for arr in var_arrays]) |
| 1706 | + check_call(_LIB.MXNDArrayMarkDCVariables(var_handles, len(var_arrays), len(self._nleaf_vars))) |
| 1707 | + self._nleaf_vars.update( |
| 1708 | + {name: Intermediate(name, array, grad_req) for name, array in zip(names, var_arrays)}) |
| 1709 | + dc.set_deferred_compute(prev_val) |
| 1710 | + return var_arrays_inp |
| 1711 | + |
| 1712 | + def attach_grad_intermediate(self): |
| 1713 | + """Attach gradients to all the marked intermediate variables. |
| 1714 | + """ |
| 1715 | + for val in self._nleaf_vars.values(): |
| 1716 | + val.data().attach_grad(grad_req=val.grad_req) |
| 1717 | + |
| 1718 | + def get_intermediate(self, names): |
| 1719 | + """Get the marked intermediate variables by name or list of names. |
| 1720 | + """ |
| 1721 | + if isinstance(names, list): |
| 1722 | + return [self._nleaf_vars[n] for n in names] |
| 1723 | + else: |
| 1724 | + return self._nleaf_vars[names] |
| 1725 | + |
1681 | 1726 | class SymbolBlock(HybridBlock): |
1682 | 1727 | """Construct block from symbol. This is useful for using pre-trained models |
1683 | 1728 | as feature extractors. For example, you may want to extract the output |
|
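For review context, here is a minimal usage sketch of the API added above. The `Net` block, the layer sizes, and the variable name `'proj'` are hypothetical; the sketch assumes `Intermediate` exposes `data()` and `grad_req` as used in `attach_grad_intermediate`, and that gradients reach marked variables in the hybridized path via the `MXNDArrayMarkDCVariables` / `_nleaf_vars` plumbing introduced in this PR.

```python
import mxnet as mx
from mxnet.gluon import nn

class Net(nn.HybridBlock):
    """Toy block that registers its hidden activation as an intermediate."""
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Dense(16)
        self.fc2 = nn.Dense(4)

    def forward(self, x):
        h = self.fc1(x)
        # Mark the hidden activation under the name 'proj'; returns `h` unchanged.
        h = self.intermediate('proj', h)
        return self.fc2(h)

net = Net()
net.initialize()
net.hybridize()

x = mx.np.ones((2, 8))
net(x)                           # first forward traces the graph and marks 'proj'
net.attach_grad_intermediate()   # attach grads to every marked intermediate
with mx.autograd.record():
    out = net(x)
out.backward()
# Retrieve the marked variable and read its gradient.
print(net.get_intermediate('proj').data().grad)
```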