Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,9 @@ class EvictionPolicy(NamedTuple):
eviction_free_mem_check_interval_batch: Optional[int] = (
None # Number of batches between checks for free memory threshold when using free_mem trigger mode.
)
enable_eviction_for_feature_score_eviction_policy: Optional[list[bool]] = (
None # enable eviction if eviction policy is feature score, false means no eviction
)

def validate(self) -> None:
assert self.eviction_trigger_mode in [0, 1, 2, 3, 4, 5], (
Expand Down Expand Up @@ -217,13 +220,17 @@ def validate(self) -> None:
"threshold_calculation_bucket_num must be set if eviction_strategy is 5,"
f"actual {self.threshold_calculation_bucket_num}"
)
assert self.enable_eviction_for_feature_score_eviction_policy is not None, (
"enable_eviction_for_feature_score_eviction_policy must be set if eviction_strategy is 5,"
f"actual {self.enable_eviction_for_feature_score_eviction_policy}"
)
assert (
len(self.training_id_keep_count)
len(self.enable_eviction_for_feature_score_eviction_policy)
== len(self.training_id_keep_count)
== len(self.feature_score_counter_decay_rates)
== len(self.training_id_eviction_trigger_count)
), (
"feature_score_thresholds, training_id_eviction_trigger_count and training_id_keep_count must have the same length, "
f"actual {self.training_id_keep_count} vs {self.feature_score_counter_decay_rates} vs {self.training_id_eviction_trigger_count}"
"feature_score_thresholds, enable_eviction_for_feature_score_eviction_policy, and training_id_keep_count must have the same length, "
f"actual {self.training_id_keep_count} vs {self.feature_score_counter_decay_rates} vs {self.enable_eviction_for_feature_score_eviction_policy}"
)


Expand Down
10 changes: 10 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,6 +707,15 @@ def __init__(
# If trigger mode is free_mem(5), populate config
self.set_free_mem_eviction_trigger_config(eviction_policy)

# Per-table enable-eviction flags, converted from bool to int before being passed to the C++ FeatureEvictConfig.
# When the policy leaves the field unset (None), None is forwarded unchanged.
enable_eviction_for_feature_score_eviction_policy = ( # pytorch api in c++ doesn't support vector<bool>, convert to int here, 0: no eviction, 1: eviction
[
int(x)
for x in eviction_policy.enable_eviction_for_feature_score_eviction_policy
]
if eviction_policy.enable_eviction_for_feature_score_eviction_policy
is not None
else None
)
# Please refer to https://fburl.com/gdoc/nuupjwqq for the following eviction parameters.
eviction_config = torch.classes.fbgemm.FeatureEvictConfig(
eviction_policy.eviction_trigger_mode, # eviction is disabled, 0: disabled, 1: iteration, 2: mem_util, 3: manual, 4: id count
Expand All @@ -719,6 +728,7 @@ def __init__(
eviction_policy.feature_score_counter_decay_rates, # feature_score_counter_decay_rates for each table if eviction strategy is feature score
eviction_policy.training_id_eviction_trigger_count, # training_id_eviction_trigger_count for each table
eviction_policy.training_id_keep_count, # training_id_keep_count for each table
enable_eviction_for_feature_score_eviction_policy, # no eviction setting for feature score eviction policy
eviction_policy.l2_weight_thresholds, # l2_weight_thresholds for each table if eviction strategy is feature l2 norm
table_dims.tolist() if table_dims is not None else None,
eviction_policy.threshold_calculation_bucket_stride, # threshold_calculation_bucket_stride if eviction strategy is feature score
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ void DramKVEmbeddingInferenceWrapper::init(
std::nullopt /* feature_score_counter_decay_rates */,
std::nullopt /* training_id_eviction_trigger_count */,
std::nullopt /* training_id_keep_count */,
std::nullopt /* enable_eviction_for_feature_score_eviction_policy */,
std::nullopt /* l2_weight_thresholds */,
std::nullopt /* embedding_dims */,
std::nullopt /* threshold_calculation_bucket_stride */,
Expand Down
54 changes: 52 additions & 2 deletions fbgemm_gpu/src/dram_kv_embedding_cache/feature_evict.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ struct FeatureEvictConfig : public torch::jit::CustomClassHolder {
std::optional<std::vector<double>> feature_score_counter_decay_rates,
std::optional<std::vector<int64_t>> training_id_eviction_trigger_count,
std::optional<std::vector<int64_t>> training_id_keep_count,
std::optional<std::vector<int8_t>>
enable_eviction_for_feature_score_eviction_policy, // 0: no eviction,
// 1: evict
std::optional<std::vector<double>> l2_weight_thresholds,
std::optional<std::vector<int64_t>> embedding_dims,
std::optional<double> threshold_calculation_bucket_stride = 0.2,
Expand All @@ -129,6 +132,8 @@ struct FeatureEvictConfig : public torch::jit::CustomClassHolder {
training_id_eviction_trigger_count_(
std::move(training_id_eviction_trigger_count)),
training_id_keep_count_(std::move(training_id_keep_count)),
enable_eviction_for_feature_score_eviction_policy_(
std::move(enable_eviction_for_feature_score_eviction_policy)),
l2_weight_thresholds_(l2_weight_thresholds),
embedding_dims_(embedding_dims),
threshold_calculation_bucket_stride_(
Expand Down Expand Up @@ -169,10 +174,17 @@ struct FeatureEvictConfig : public torch::jit::CustomClassHolder {
CHECK(
training_id_eviction_trigger_count_.has_value() &&
!training_id_eviction_trigger_count_.value().empty());
CHECK(enable_eviction_for_feature_score_eviction_policy_.has_value());
const auto& enable_eviction_vec =
enable_eviction_for_feature_score_eviction_policy_.value();
const auto& vec = training_id_eviction_trigger_count_.value();
eviction_trigger_stats_log = ", training_id_eviction_trigger_count: [";
total_id_eviction_trigger_count_ = 0;
for (size_t i = 0; i < vec.size(); ++i) {
// ID_COUNT trigger mode requires eviction to be enabled for every table.
// Use bounds-checked access: `i` iterates over
// training_id_eviction_trigger_count, and nothing visible here guarantees
// that the enable-eviction vector has the same length, so a mismatch should
// raise std::out_of_range instead of reading out of bounds.
if (enable_eviction_vec.at(i) == 0) {
  throw std::runtime_error(
      "ID_COUNT trigger mode doesn't support enable_eviction=False, please use FREE_MEM trigger mode instead");
}
total_id_eviction_trigger_count_ =
total_id_eviction_trigger_count_.value() + vec[i];
if (vec[i] <= 0) {
Expand Down Expand Up @@ -212,6 +224,7 @@ struct FeatureEvictConfig : public torch::jit::CustomClassHolder {
CHECK(threshold_calculation_bucket_stride_.has_value());
CHECK(threshold_calculation_bucket_num_.has_value());
CHECK(ttls_in_mins_.has_value());
CHECK(enable_eviction_for_feature_score_eviction_policy_.has_value());
LOG(INFO) << "eviction config, trigger mode:"
<< to_string(trigger_mode_) << eviction_trigger_stats_log
<< ", strategy: " << to_string(trigger_strategy_)
Expand All @@ -223,7 +236,9 @@ struct FeatureEvictConfig : public torch::jit::CustomClassHolder {
<< ", threshold_calculation_bucket_num: "
<< threshold_calculation_bucket_num_.value()
<< ", feature_score_counter_decay_rates: "
<< feature_score_counter_decay_rates_.value();
<< feature_score_counter_decay_rates_.value()
<< ", enable_eviction_for_feature_score_eviction_policy: "
<< enable_eviction_for_feature_score_eviction_policy_.value();
return;
}

Expand Down Expand Up @@ -281,6 +296,8 @@ struct FeatureEvictConfig : public torch::jit::CustomClassHolder {
std::optional<std::vector<double>> feature_score_counter_decay_rates_;
std::optional<std::vector<int64_t>> training_id_eviction_trigger_count_;
std::optional<std::vector<int64_t>> training_id_keep_count_;
std::optional<std::vector<int8_t>>
enable_eviction_for_feature_score_eviction_policy_;
std::optional<int64_t> total_id_eviction_trigger_count_;
std::optional<std::vector<double>> l2_weight_thresholds_;
std::optional<std::vector<int64_t>> embedding_dims_;
Expand Down Expand Up @@ -984,6 +1001,8 @@ class FeatureScoreBasedEvict : public FeatureEvict<weight_type> {
const std::vector<int64_t>& training_id_eviction_trigger_count,
const std::vector<int64_t>& training_id_keep_count,
const std::vector<int64_t>& ttls_in_mins,
const std::vector<int8_t>&
enable_eviction_for_feature_score_eviction_policy,
const double threshold_calculation_bucket_stride,
const int64_t threshold_calculation_bucket_num,
int64_t interval_for_insufficient_eviction_s,
Expand All @@ -1003,6 +1022,8 @@ class FeatureScoreBasedEvict : public FeatureEvict<weight_type> {
training_id_eviction_trigger_count_(training_id_eviction_trigger_count),
training_id_keep_count_(training_id_keep_count),
ttls_in_mins_(ttls_in_mins),
enable_eviction_for_feature_score_eviction_policy_(
enable_eviction_for_feature_score_eviction_policy),
threshold_calculation_bucket_stride_(
threshold_calculation_bucket_stride),
num_buckets_(threshold_calculation_bucket_num),
Expand Down Expand Up @@ -1071,6 +1092,13 @@ class FeatureScoreBasedEvict : public FeatureEvict<weight_type> {
protected:
bool evict_block(weight_type* block, int sub_table_id, int shard_id)
override {
int8_t enable_eviction =
enable_eviction_for_feature_score_eviction_policy_[sub_table_id];
if (enable_eviction == 0) {
// If enable_eviction is set to 0, we don't evict any block.
return false;
}

double ttls_threshold = ttls_in_mins_[sub_table_id];
if (ttls_threshold > 0) {
auto current_time = FixedBlockPool::current_timestamp();
Expand Down Expand Up @@ -1145,6 +1173,15 @@ class FeatureScoreBasedEvict : public FeatureEvict<weight_type> {

void compute_thresholds_from_buckets() {
for (size_t table_id = 0; table_id < num_tables_; ++table_id) {
int8_t enable_eviction =
enable_eviction_for_feature_score_eviction_policy_[table_id];
if (enable_eviction == 0) {
// If enable_eviction is set to 0, we don't evict any block.
thresholds_[table_id] = 0.0;
evict_modes_[table_id] = EvictMode::NONE;
continue;
}

int64_t total = 0;

if (ttls_in_mins_[table_id] > 0) {
Expand Down Expand Up @@ -1209,7 +1246,8 @@ class FeatureScoreBasedEvict : public FeatureEvict<weight_type> {
<< " threshold bucket: " << threshold_bucket
<< " actual evict count: " << acc_count
<< " target evict count: " << evict_count
<< " total count: " << total;
<< " total count: " << total
<< " evict mode: " << to_string(evict_modes_[table_id]);

for (int table_id = 0; table_id < num_tables_; ++table_id) {
this->metrics_.eviction_threshold_with_dry_run[table_id] =
Expand All @@ -1226,6 +1264,16 @@ class FeatureScoreBasedEvict : public FeatureEvict<weight_type> {
THRESHOLD // blocks with scores below the computed threshold will be
// evicted
};
// Returns a human-readable name for an EvictMode value (used when logging
// the per-table evict mode).
//
// The trailing return handles values that match none of the enumerators.
// Without it, control would flow off the end of a non-void function, which
// is undefined behavior and triggers -Wreturn-type.
inline std::string to_string(EvictMode mode) {
  switch (mode) {
    case EvictMode::NONE:
      return "NONE";
    case EvictMode::ONLY_ZERO:
      return "ONLY_ZERO";
    case EvictMode::THRESHOLD:
      return "THRESHOLD";
  }
  // No `default:` label above, so the compiler can still warn when a new
  // enumerator is added without a matching case.
  return "UNKNOWN";
}
std::vector<EvictMode> evict_modes_;

const int num_tables_ = static_cast<int>(this->sub_table_hash_cumsum_.size());
Expand All @@ -1240,6 +1288,7 @@ class FeatureScoreBasedEvict : public FeatureEvict<weight_type> {
// eviction.

const std::vector<int64_t>& ttls_in_mins_; // Time-to-live for eviction.
const std::vector<int8_t>& enable_eviction_for_feature_score_eviction_policy_;
std::vector<std::vector<std::vector<size_t>>>
local_buckets_per_shard_per_table_;
std::vector<std::vector<size_t>> local_blocks_num_per_shard_per_table_;
Expand Down Expand Up @@ -1489,6 +1538,7 @@ std::unique_ptr<FeatureEvict<weight_type>> create_feature_evict(
config->training_id_eviction_trigger_count_.value(),
config->training_id_keep_count_.value(),
config->ttls_in_mins_.value(),
config->enable_eviction_for_feature_score_eviction_policy_.value(),
config->threshold_calculation_bucket_stride_.value(),
config->threshold_calculation_bucket_num_.value(),
config->interval_for_insufficient_eviction_s_,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,7 @@ static auto feature_evict_config =
std::optional<std::vector<double>>,
std::optional<std::vector<int64_t>>,
std::optional<std::vector<int64_t>>,
std::optional<std::vector<int8_t>>,
std::optional<std::vector<double>>,
std::optional<std::vector<int64_t>>,
std::optional<double>,
Expand All @@ -756,6 +757,9 @@ static auto feature_evict_config =
torch::arg("feature_score_counter_decay_rates") = std::nullopt,
torch::arg("training_id_eviction_trigger_count") = std::nullopt,
torch::arg("training_id_keep_count") = std::nullopt,
torch::arg(
"enable_eviction_for_feature_score_eviction_policy") =
std::nullopt,
torch::arg("l2_weight_thresholds") = std::nullopt,
torch::arg("embedding_dims") = std::nullopt,
torch::arg("threshold_calculation_bucket_stride") = 0.2,
Expand Down
Loading
Loading