Skip to content

Commit e69a5f6

Browse files
authored
Fix perf bugs (#167)
* Tackle perf degradation
* Revert a wrong fix from #166
* Add `ARK_ENFORCE_KERNEL_CODE_PATH` feature for debugging
1 parent 7f61f7b commit e69a5f6

File tree

8 files changed

+57
-83
lines changed

8 files changed

+57
-83
lines changed

ark/env.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#define DEFAULT_ARK_DISABLE_GRAPH_OPT false
2020
#define DEFAULT_ARK_IGNORE_BINARY_CACHE false
2121
#define DEFAULT_ARK_SHM_NAME_PREFIX "ark."
22+
#define DEFAULT_ARK_ENFORCE_KERNEL_CODE_PATH ""
2223
#define DEFAULT_ARK_USE_MSLL false
2324
#define DEFAULT_ARK_MSLL_INCLUDE_DIR "/usr/local/msll/include"
2425
#define DEFAULT_ARK_MSLL_PORT 50051
@@ -75,6 +76,9 @@ Env::Env() {
7576
//
7677
this->shm_name_prefix =
7778
env<std::string>("ARK_SHM_NAME_PREFIX", DEFAULT_ARK_SHM_NAME_PREFIX);
79+
//
80+
this->enforce_kernel_code_path = env<std::string>(
81+
"ARK_ENFORCE_KERNEL_CODE_PATH", DEFAULT_ARK_ENFORCE_KERNEL_CODE_PATH);
7882
// If `ARK_USE_MSLL=1`, we use MSLL.
7983
this->use_msll = env<bool>("ARK_USE_MSLL", DEFAULT_ARK_USE_MSLL);
8084
// Get the MSLL include directory path.

ark/env.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ struct Env {
3535
bool ignore_binary_cache;
3636
// Prefix of shared memory file names.
3737
std::string shm_name_prefix;
38+
// Enforce to compile a specific kernel code file.
39+
std::string enforce_kernel_code_path;
3840
// Use MSLL.
3941
bool use_msll;
4042
// MSLL include directory path.

ark/gpu/gpu_kernel.cc

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
#include "cpu_timer.h"
1313
#include "env.h"
14+
#include "file_io.h"
1415
#include "gpu/gpu_compile.h"
1516
#include "gpu/gpu_logging.h"
1617
#include "include/ark.h"
@@ -158,7 +159,11 @@ GpuLoopKernel::GpuLoopKernel(const string &name_,
158159

159160
*(GpuPtr *)this->params[0] = this->flag->ref(0);
160161

161-
if (codes_body.size() > 0) {
162+
auto &code_path = get_env().enforce_kernel_code_path;
163+
if (!code_path.empty()) {
164+
LOG(INFO, "Enforce kernel code path: ", code_path);
165+
this->codes.emplace_back(read_file(code_path));
166+
} else if (codes_body.size() > 0) {
162167
const string *ark_loop_body_code = nullptr;
163168
for (auto &code : codes_body) {
164169
if (code.find("ark_loop_body") == string::npos) {

ark/ops/ops_add.cc

Lines changed: 1 addition & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -68,41 +68,7 @@ Tensor *Model::add(Tensor *input, Tensor *other, Tensor *output,
6868
}
6969

7070
const OpConfigMap ArithmeticConfigMap = {
71-
{{OP_ARCH_ANY, "fp32"},
72-
{
73-
// NumWarps, SmemBytes, InDepsTiles, OutDepsTiles, SyncPre, SyncPost
74-
{8, 0, {{128, 256}, {128, 256}}, {{128, 256}}, false, false},
75-
{8, 0, {{256, 128}, {256, 128}}, {{256, 128}}, false, false},
76-
{8, 0, {{128, 128}, {128, 128}}, {{128, 128}}, false, false},
77-
{4, 0, {{64, 64}, {64, 64}}, {{64, 64}}, false, false},
78-
{2, 0, {{32, 64}, {32, 64}}, {{32, 64}}, false, false},
79-
{1, 0, {{16, 64}, {16, 64}}, {{16, 64}}, false, false},
80-
{1, 0, {{8, 64}, {8, 64}}, {{8, 64}}, false, false},
81-
{1, 0, {{2, 128}, {2, 128}}, {{2, 128}}, false, false},
82-
{1, 0, {{4, 64}, {4, 64}}, {{4, 64}}, false, false},
83-
{1, 0, {{2, 64}, {2, 64}}, {{2, 64}}, false, false},
84-
{1, 0, {{1, 128}, {1, 128}}, {{1, 128}}, false, false},
85-
{1, 0, {{1, 64}, {1, 64}}, {{1, 64}}, false, false},
86-
{1, 0, {{1, 32}, {1, 32}}, {{1, 32}}, false, false},
87-
}},
88-
{{OP_ARCH_ANY, "fp16"},
89-
{
90-
// NumWarps, SmemBytes, InDepsTiles, OutDepsTiles, SyncPre, SyncPost
91-
{8, 0, {{128, 256}, {128, 256}}, {{128, 256}}, false, false},
92-
{8, 0, {{256, 128}, {256, 128}}, {{256, 128}}, false, false},
93-
{8, 0, {{128, 128}, {128, 128}}, {{128, 128}}, false, false},
94-
{4, 0, {{64, 64}, {64, 64}}, {{64, 64}}, false, false},
95-
{2, 0, {{32, 64}, {32, 64}}, {{32, 64}}, false, false},
96-
{1, 0, {{16, 64}, {16, 64}}, {{16, 64}}, false, false},
97-
{1, 0, {{8, 64}, {8, 64}}, {{8, 64}}, false, false},
98-
{1, 0, {{2, 128}, {2, 128}}, {{2, 128}}, false, false},
99-
{1, 0, {{4, 64}, {4, 64}}, {{4, 64}}, false, false},
100-
{1, 0, {{2, 64}, {2, 64}}, {{2, 64}}, false, false},
101-
{1, 0, {{1, 256}, {1, 256}}, {{1, 256}}, false, false},
102-
{1, 0, {{1, 128}, {1, 128}}, {{1, 128}}, false, false},
103-
{1, 0, {{1, 64}, {1, 64}}, {{1, 64}}, false, false},
104-
}},
105-
{{OP_ARCH_ANY, "bf16"},
71+
{{OP_ARCH_ANY, "any"},
10672
{
10773
// NumWarps, SmemBytes, InDepsTiles, OutDepsTiles, SyncPre, SyncPost
10874
{8, 0, {{128, 256}, {128, 256}}, {{128, 256}}, false, false},

ark/ops/ops_cast.cc

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -159,19 +159,9 @@ const OpConfigMap CastConfigMap = {
159159
{{OP_ARCH_ANY, "none"},
160160
{
161161
// NumWarps, SmemBytes, InDepsTiles, OutDepsTiles, SyncPre, SyncPost
162-
{8, 0, {{128, 256}}, {{128, 256}}, false, false},
163-
{8, 0, {{256, 128}}, {{256, 128}}, false, false},
164-
{8, 0, {{128, 128}}, {{128, 128}}, false, false},
165-
{4, 0, {{64, 64}}, {{64, 64}}, false, false},
166-
{2, 0, {{32, 64}}, {{32, 64}}, false, false},
167-
{1, 0, {{16, 64}}, {{16, 64}}, false, false},
168-
{1, 0, {{8, 64}}, {{8, 64}}, false, false},
169-
{1, 0, {{2, 128}}, {{2, 128}}, false, false},
170-
{1, 0, {{4, 64}}, {{4, 64}}, false, false},
171-
{1, 0, {{2, 64}}, {{2, 64}}, false, false},
162+
{1, 0, {{1, 256}}, {{1, 256}}, false, false},
172163
{1, 0, {{1, 128}}, {{1, 128}}, false, false},
173164
{1, 0, {{1, 64}}, {{1, 64}}, false, false},
174-
{1, 0, {{1, 32}}, {{1, 32}}, false, false},
175165
}},
176166
};
177167

ark/sched/sched.cc

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -102,20 +102,16 @@ const OpConfig *BaseScheduler::sched_op_config(const Op *op) {
102102
}
103103
}
104104
// Heuristic auto-selection of granularity level
105-
unsigned int min_wps =
106-
gpu_info.min_threads_per_block / gpu_info.threads_per_warp;
107105
Dims shape4 = output->shape.dims4();
108106
Dims ldims4 = output->ldims.dims4();
107+
DimType shape_x = shape4[2];
108+
DimType shape_y = shape4[3];
109109
std::vector<std::tuple<const OpConfig *, Dims, int>> config_candidates;
110-
std::vector<std::tuple<const OpConfig *, Dims, int>>
111-
high_priority_candidates;
112110
for (auto &cfg : feasible_configs) {
113111
assert(cfg->output_tiles.size() > 0);
114112
const OpTile &ot = cfg->output_tiles[0];
115113
DimType ot_x = (ot.x == -1) ? ldims4[2] : ot.x;
116114
DimType ot_y = (ot.y == -1) ? ldims4[3] : ot.y;
117-
DimType shape_x = shape4[2];
118-
DimType shape_y = shape4[3];
119115
if (output->shape.ndims() == 1 && ot_x != 1) {
120116
// Output is 1D, but tile is 2D. Cannot use this tile shape.
121117
continue;
@@ -126,13 +122,6 @@ const OpConfig *BaseScheduler::sched_op_config(const Op *op) {
126122

127123
// This config is OK to use
128124
config_candidates.emplace_back(cfg, Dims(ot_x, ot_y), num_tiles);
129-
130-
// magic condition
131-
if ((shape_y * 2 > ot_y) && (shape_x * 2 > ot_x) &&
132-
((num_tiles * cfg->num_warps) >= (min_wps * gpu_info.num_sm / 2))) {
133-
high_priority_candidates.emplace_back(cfg, Dims(ot_x, ot_y),
134-
num_tiles);
135-
}
136125
}
137126
if (config_candidates.empty()) {
138127
stringstream configs_str;
@@ -152,14 +141,33 @@ const OpConfig *BaseScheduler::sched_op_config(const Op *op) {
152141
ERR(SchedulerError, "no valid tile configuration found. Output shape ",
153142
output->shape, ", available tiles: ", configs_str.str());
154143
}
144+
std::vector<std::tuple<const OpConfig *, Dims, int>>
145+
high_priority_candidates;
146+
int min_wps = gpu_info.min_threads_per_block / gpu_info.threads_per_warp;
147+
int target_concurrent_num_warps = min_wps * gpu_info.num_sm;
148+
for (auto &c : config_candidates) {
149+
auto &cfg = std::get<0>(c);
150+
auto &tile = std::get<1>(c);
151+
auto &num_tiles = std::get<2>(c);
152+
153+
if ((shape_x < tile[0]) || (shape_y < tile[1])) {
154+
// too large tile.
155+
continue;
156+
}
157+
auto num_total_warps = num_tiles * cfg->num_warps;
158+
if (num_total_warps >= target_concurrent_num_warps / 2) {
159+
high_priority_candidates.push_back(c);
160+
}
161+
}
155162
auto &candidates = high_priority_candidates.empty()
156163
? config_candidates
157164
: high_priority_candidates;
158-
// prefer smaller tiles here to minimize paddings
165+
159166
std::sort(candidates.begin(), candidates.end(),
160167
[](const std::tuple<const OpConfig *, Dims, int> &a,
161168
const std::tuple<const OpConfig *, Dims, int> &b) {
162-
return std::get<1>(a).size() < std::get<1>(b).size();
169+
return std::get<2>(a) * std::get<0>(a)->num_warps <
170+
std::get<2>(b) * std::get<0>(b)->num_warps;
163171
});
164172
return std::get<0>(candidates[0]);
165173
}

ark/sched/sched/sched_default.cc

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -621,10 +621,8 @@ std::vector<std::string> DefaultScheduler::gen_code() {
621621
auto comp_streams = this->comp_stream[i]->get_streams();
622622
for (size_t j = 0; j < comp_streams.size(); ++j) {
623623
auto &stream = comp_streams[j];
624-
int prev_sm_id_end = -1;
625624
for (auto &branch : stream.branches) {
626-
this->codegen->branch(code, branch, prev_sm_id_end);
627-
prev_sm_id_end = branch.sm_id_end;
625+
this->codegen->branch(code, branch);
628626
}
629627
if (!stream.branches.empty() && j != comp_streams.size() - 1) {
630628
code << " ";
@@ -634,10 +632,8 @@ std::vector<std::string> DefaultScheduler::gen_code() {
634632
auto comm_streams = this->comm_stream[i]->get_streams();
635633
for (size_t j = 0; j < comm_streams.size(); ++j) {
636634
auto &stream = comm_streams[j];
637-
int prev_sm_id_end = -1;
638635
for (auto &branch : stream.branches) {
639-
this->codegen->branch(code, branch, prev_sm_id_end);
640-
prev_sm_id_end = branch.sm_id_end;
636+
this->codegen->branch(code, branch);
641637
}
642638
if (!stream.branches.empty() && j != comm_streams.size() - 1) {
643639
code << " ";

examples/llama/model_test.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ def test_module(
128128
module_name_prefix: str = "",
129129
test_thru: bool = False,
130130
test_thru_iterations: int = 100,
131+
test_thru_ark_only: bool = False,
131132
):
132133
if test_thru:
133134
print(f"Throughput test (iterations: {test_thru_iterations})")
@@ -164,21 +165,25 @@ def test_module(
164165
iterations=test_thru_iterations if test_thru else 1,
165166
)
166167

167-
# PyTorch module
168-
module_pt: torch.nn.Module = module_class_pt(*module_args_pt)
168+
if not test_thru_ark_only:
169+
# PyTorch module
170+
module_pt: torch.nn.Module = module_class_pt(*module_args_pt)
169171

170-
# Run the PyTorch module
171-
res_pt = run_pt(
172-
module_pt,
173-
state_dict_pt,
174-
inputs_pt,
175-
iterations=test_thru_iterations if test_thru else 1,
176-
)
177-
178-
if test_thru:
179-
print(
180-
f" PyTorch: {res_pt.runtime:.4f} seconds, ARK: {res_ark.runtime:.4f} seconds"
172+
# Run the PyTorch module
173+
res_pt = run_pt(
174+
module_pt,
175+
state_dict_pt,
176+
inputs_pt,
177+
iterations=test_thru_iterations if test_thru else 1,
181178
)
179+
180+
if test_thru:
181+
print(
182+
f" PyTorch: {res_pt.runtime:.4f} seconds, ARK: {res_ark.runtime:.4f} seconds"
183+
)
184+
return
185+
elif test_thru:
186+
print(f" ARK: {res_ark.runtime:.4f} seconds")
182187
return
183188

184189
# Compare the outputs
@@ -454,8 +459,6 @@ def test_transformer(
454459
module_class_pt=model_pt.Transformer,
455460
module_args_pt=[args],
456461
inputs_pt=[tokens, start_pos],
457-
test_thru=True,
458-
test_thru_iterations=10,
459462
)
460463

461464

0 commit comments

Comments
 (0)