Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 5be12f7

Browse files
committed
self consistent cpp addition
1 parent bbb258a commit 5be12f7

File tree

6 files changed

+39
-0
lines changed

6 files changed

+39
-0
lines changed

include/mxnet/c_api.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1276,6 +1276,13 @@ MXNET_DLL int MXAutogradMarkVariables(uint32_t num_var,
12761276
NDArrayHandle* var_handles,
12771277
uint32_t* reqs_array,
12781278
NDArrayHandle* grad_handles);
1279+
/*!
1280+
* \brief mark nonleaf NDArrays as variables during deferredcomputation
1281+
* \param num_nleafs number of nonleaf NDArrays
1282+
* \param cnt_var count of existing marked nonleaf variables
1283+
* \return 0 when success, -1 when failure happens
1284+
*/
1285+
MXNET_DLL int MXNDArrayMarkDCVariables(NDArrayHandle* nleaf_handles, int num_nleafs, int cnt_var);
12791286
/*!
12801287
* \brief unmark nonleaf NDArrays to free the memory
12811288
* \param num_var number of variable NDArrays

include/mxnet/imperative.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,8 @@ class Imperative {
290290
void MarkVariables(const std::vector<NDArray*>& variables,
291291
const std::vector<uint32_t>& grad_reqs,
292292
const std::vector<NDArray*>& gradients);
293+
/*! \brief mark nonleaf variables during DC for computing gradients. */
294+
void MarkDCVariables(const std::vector<NDArray*>& nleafs, int cnt_vars);
293295
/*! \brief unmark nonleaf variables to free the memory. */
294296
void DropGrads(const std::vector<NDArray*>& variables);
295297
/*! \brief compute the gradient of outputs w.r.t variables. */

include/mxnet/ndarray.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,8 @@ class NDArray {
351351
bool fresh_out_grad() const;
352352
/*! \return updated grad state in autograd_entry_ */
353353
void set_fresh_out_grad(bool state) const;
354+
/*! \brief copy the autograd_entry_ from src NDArray */
355+
void copy_autograd_entry_(const NDArray* src);
354356
/*! \brief Returns true if a sparse ndarray's aux_data and storage are initialized
355357
* Throws an exception if the indices array shape is inconsistent
356358
* Returns false if the indices array is empty(nnz = 0) for csr/row_sparse

src/c_api/c_api_ndarray.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -495,3 +495,15 @@ int MXNDArrayGetDeferredComputeSymbol(NDArrayHandle* output_handles,
495495
*out = s;
496496
API_END_HANDLE_ERROR(delete s;);
497497
}
498+
499+
int MXNDArrayMarkDCVariables(NDArrayHandle* nleaf_handles, int num_nleafs, int cnt_var) {
500+
API_BEGIN();
501+
std::vector<NDArray*> nleafs;
502+
nleafs.reserve(num_nleafs);
503+
for (int i = 0; i < num_nleafs; ++i) {
504+
NDArray* array = reinterpret_cast<NDArray*>(nleaf_handles[i]);
505+
nleafs.emplace_back(array);
506+
}
507+
Imperative::Get()->MarkDCVariables(nleafs, cnt_var);
508+
API_END();
509+
}

src/imperative/imperative.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,18 @@ void Imperative::MarkVariables(const std::vector<NDArray*>& variables,
171171
}
172172
}
173173

174+
void Imperative::MarkDCVariables(const std::vector<NDArray*>& nleafs, int cnt_vars) {
175+
for (NDArray* nleaf : nleafs) {
176+
if (Imperative::DCInfo::IsNone(*nleaf)) {
177+
LOG(WARNING) << "The marked node doesn't have deferred compute history.";
178+
} else {
179+
nnvm::ObjectPtr node = nleaf->deferredcompute_entry_.node;
180+
node->attrs.dict["mark_id"] = std::to_string(cnt_vars);
181+
}
182+
cnt_vars++;
183+
}
184+
}
185+
174186
// Unmark the variables to free the memory.
175187
void Imperative::DropGrads(const std::vector<NDArray*>& variables) {
176188
for (auto variable : variables) {

src/ndarray/ndarray.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -513,6 +513,10 @@ void NDArray::set_fresh_out_grad(bool state) const {
513513
info.fresh_out_grad = state;
514514
}
515515

516+
void NDArray::copy_autograd_entry_(const NDArray* src) {
517+
autograd_entry_ = nnvm::NodeEntry{src->autograd_entry_.node, 0, 0};
518+
}
519+
516520
#if MXNET_USE_ONEDNN == 1
517521

518522
bool NDArray::Chunk::IsDNNL() const {

0 commit comments

Comments
 (0)