Skip to content

Commit 4f8977e

Browse files
authored
Merge pull request #1566 from stan-dev/soa-explanations
Add explanations to `--debug-mem-patterns`
2 parents a58dba6 + afe5255 commit 4f8977e

File tree

4 files changed

+234
-44
lines changed

4 files changed

+234
-44
lines changed

src/analysis_and_optimization/Memory_patterns.ml

Lines changed: 143 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,30 @@ open Core
22
open Core.Poly
33
open Middle
44

5+
type demotion = int * Mem_pattern.t * string [@@deriving compare]
6+
7+
let demotion_reasons = ref []
8+
9+
let get_warnings () =
10+
let mem_name pattern =
11+
match pattern with Mem_pattern.SoA -> "SoA" | AoS -> "AoS" in
12+
!demotion_reasons
13+
|> List.dedup_and_sort ~compare:compare_demotion
14+
|> List.map ~f:(fun (linenum, pattern, msg) ->
15+
Printf.sprintf "Optimization hazard warning (Line %i): %s warning: %s"
16+
linenum (mem_name pattern) msg)
17+
18+
let user_warning_op (mem_pattern : Mem_pattern.t) (linenum : int) (msg : string)
19+
(names : string) =
20+
if not (String.is_empty names || String.is_empty msg) then
21+
demotion_reasons :=
22+
(linenum, mem_pattern, msg ^ " " ^ names) :: !demotion_reasons
23+
24+
let concat_set_str (set : string Set.Poly.t) =
25+
Set.fold
26+
~f:(fun acc elem -> if acc = "" then acc ^ elem else acc ^ ", " ^ elem)
27+
~init:"" set
28+
529
(**
630
Return a Var expression of the name for each type
731
containing an eigen matrix
@@ -98,7 +122,7 @@ let query_stan_math_mem_pattern_support (name : string)
98122
Frontend.SignatureMismatch.check_compatible_arguments_mod_conv x args
99123
|> Result.is_ok)
100124
namematches in
101-
let is_soa = function _, _, _, Mem_pattern.SoA -> true | _ -> false in
125+
let is_soa (_, _, _, p) = p = Mem_pattern.SoA in
102126
List.exists ~f:is_soa filteredmatches
103127

104128
(*Validate whether a function can support SoA matrices*)
@@ -116,13 +140,13 @@ let is_fun_soa_supported name exprs =
116140
will be returned if the matrix or vector is accessed by single
117141
cell indexing.
118142
*)
119-
let rec query_initial_demotable_expr (in_loop : bool) ~(acc : string Set.Poly.t)
120-
Expr.{pattern; _} : string Set.Poly.t =
143+
let rec query_initial_demotable_expr (in_loop : bool) (stmt_linenum : int)
144+
~(acc : string Set.Poly.t) Expr.{pattern; _} : string Set.Poly.t =
121145
let query_expr (accum : string Set.Poly.t) =
122-
query_initial_demotable_expr in_loop ~acc:accum in
146+
query_initial_demotable_expr in_loop stmt_linenum ~acc:accum in
123147
match pattern with
124148
| FunApp (kind, (exprs : Expr.Typed.t list)) ->
125-
query_initial_demotable_funs in_loop acc kind exprs
149+
query_initial_demotable_funs in_loop stmt_linenum acc kind exprs
126150
| Indexed ((Expr.{meta= {type_; _}; _} as expr), indexed) ->
127151
let index_set =
128152
Set.Poly.union_list
@@ -132,18 +156,29 @@ let rec query_initial_demotable_expr (in_loop : bool) ~(acc : string Set.Poly.t)
132156
(query_expr acc))
133157
indexed) in
134158
let index_demotes =
135-
if is_uni_eigen_loop_indexing in_loop type_ indexed then
136-
Set.union (query_var_eigen_names expr) index_set
159+
if is_uni_eigen_loop_indexing in_loop type_ indexed then (
160+
let single_index_set = query_var_eigen_names expr in
161+
let failure_str = concat_set_str (Set.inter acc single_index_set) in
162+
let msg = "Accessed by element in a for loop:" in
163+
user_warning_op SoA stmt_linenum msg failure_str;
164+
Set.union single_index_set index_set)
137165
else Set.union (query_expr acc expr) index_set in
138166
Set.union acc index_demotes
139167
| Var (_ : string) | Lit ((_ : Expr.Pattern.litType), (_ : string)) -> acc
140168
| Promotion (expr, _, _) -> query_expr acc expr
141169
| TupleProjection (expr, _) -> query_expr acc expr
142170
| TernaryIf (predicate, texpr, fexpr) ->
143171
let predicate_demotes = query_expr acc predicate in
144-
Set.union
145-
(Set.union predicate_demotes (query_var_eigen_names texpr))
146-
(query_var_eigen_names fexpr)
172+
let full_set =
173+
Set.union
174+
(Set.union predicate_demotes (query_var_eigen_names texpr))
175+
(query_var_eigen_names fexpr) in
176+
if Set.is_empty full_set then full_set
177+
else
178+
let failure_str = concat_set_str (Set.inter acc full_set) in
179+
let msg = "Used in a ternary operator which is not allowed:" in
180+
user_warning_op SoA stmt_linenum msg failure_str;
181+
full_set
147182
| EAnd (lhs, rhs) | EOr (lhs, rhs) ->
148183
(*We need to get the demotes from both sides*)
149184
let full_lhs_rhs = Set.union (query_expr acc lhs) (query_expr acc rhs) in
@@ -166,9 +201,11 @@ let rec query_initial_demotable_expr (in_loop : bool) ~(acc : string Set.Poly.t)
166201
to the UDF.
167202
exprs The expression list passed to the functions.
168203
*)
169-
and query_initial_demotable_funs (in_loop : bool) (acc : string Set.Poly.t)
170-
(kind : 'a Fun_kind.t) (exprs : Expr.Typed.t list) : string Set.Poly.t =
171-
let query_expr accum = query_initial_demotable_expr in_loop ~acc:accum in
204+
and query_initial_demotable_funs (in_loop : bool) (stmt_linenum : int)
205+
(acc : string Set.Poly.t) (kind : 'a Fun_kind.t) (exprs : Expr.Typed.t list)
206+
: string Set.Poly.t =
207+
let query_expr accum =
208+
query_initial_demotable_expr in_loop stmt_linenum ~acc:accum in
172209
let top_level_eigen_names =
173210
Set.Poly.union_list (List.map ~f:query_var_eigen_names exprs) in
174211
let demoted_eigen_names = List.fold ~init:acc ~f:query_expr exprs in
@@ -181,11 +218,26 @@ and query_initial_demotable_funs (in_loop : bool) (acc : string Set.Poly.t)
181218
| name -> (
182219
match is_fun_soa_supported name exprs with
183220
| true -> Set.union acc demoted_eigen_names
184-
| false -> Set.union acc demoted_and_top_level_names))
221+
| false ->
222+
let fail_names =
223+
concat_set_str (Set.inter acc top_level_eigen_names) in
224+
user_warning_op SoA stmt_linenum
225+
("Function " ^ name ^ " is not supported:")
226+
fail_names;
227+
Set.union acc demoted_and_top_level_names))
185228
| CompilerInternal (Internal_fun.FnMakeArray | FnMakeRowVec | FnMakeTuple) ->
229+
let fail_names =
230+
concat_set_str (Set.inter acc demoted_and_top_level_names) in
231+
user_warning_op SoA stmt_linenum
232+
"Used in {} make array or make row vector compiler functions:"
233+
fail_names;
186234
Set.union acc demoted_and_top_level_names
187235
| CompilerInternal (_ : 'a Internal_fun.t) -> acc
188236
| UserDefined ((_ : string), (_ : bool Fun_kind.suffix)) ->
237+
let fail_names =
238+
concat_set_str (Set.inter acc demoted_and_top_level_names) in
239+
user_warning_op SoA stmt_linenum "Used in user defined function:"
240+
fail_names;
189241
Set.union acc demoted_and_top_level_names
190242

191243
(**
@@ -283,9 +335,10 @@ let contains_at_least_one_ad_matrix_or_all_data
283335
[query_initial_demotable_expr] for an explanation of the logic.
284336
*)
285337
let rec query_initial_demotable_stmt (in_loop : bool) (acc : string Set.Poly.t)
286-
(Stmt.{pattern; _} : Stmt.Located.t) : string Set.Poly.t =
338+
(Stmt.{pattern; meta} : Stmt.Located.t) : string Set.Poly.t =
339+
let linenum = meta.end_loc.line_num in
287340
let query_expr (accum : string Set.Poly.t) =
288-
query_initial_demotable_expr in_loop ~acc:accum in
341+
query_initial_demotable_expr in_loop linenum ~acc:accum in
289342
match pattern with
290343
| Stmt.Pattern.Assignment
291344
( lval
@@ -299,21 +352,25 @@ let rec query_initial_demotable_stmt (in_loop : bool) (acc : string Set.Poly.t)
299352
List.fold ~init:acc
300353
~f:(fun accum x ->
301354
Index.folder accum
302-
(fun acc -> query_initial_demotable_expr in_loop ~acc)
355+
(fun acc -> query_initial_demotable_expr in_loop linenum ~acc)
303356
x)
304357
idx in
305358
match is_uni_eigen_loop_indexing in_loop ut idx with
306-
| true -> Set.add idx_list name
359+
| true ->
360+
user_warning_op SoA linenum "Accessed by element in a for loop:"
361+
(if Set.mem acc name then "" else name);
362+
Set.add idx_list name
307363
| false -> idx_list in
308364
let rhs_demotable_names = query_expr acc rhs in
309365
let rhs_and_idx_demotions = Set.union idx_demotable rhs_demotable_names in
310366
(* RHS (1)*)
311367
let tuple_demotions =
312368
match lval with
313369
| LTupleProjection _, _ ->
314-
Set.add
315-
(Set.union rhs_and_idx_demotions (query_var_eigen_names rhs))
316-
name
370+
let tuple_set = query_var_eigen_names rhs in
371+
let fail_set = concat_set_str tuple_set in
372+
user_warning_op SoA linenum "Used in tuple:" fail_set;
373+
Set.add (Set.union rhs_and_idx_demotions tuple_set) name
317374
| _ -> rhs_and_idx_demotions in
318375
let assign_demotions =
319376
let is_eigen_stmt = UnsizedType.contains_eigen_type rhs.meta.type_ in
@@ -327,30 +384,52 @@ let rec query_initial_demotable_stmt (in_loop : bool) (acc : string Set.Poly.t)
327384
(extract_nonderived_admatrix_types rhs))
328385
| _ -> false in
329386
(* LHS (3) rhs unsupported function*)
330-
let is_not_supported_func =
387+
let non_supported_func_name =
331388
match rhs.pattern with
332-
| FunApp (UserDefined _, _) -> true
333-
| FunApp (CompilerInternal _, _) -> false
334-
| FunApp (StanLib (name, _, _), exprs) ->
335-
not
336-
(query_stan_math_mem_pattern_support name
337-
(List.map ~f:Expr.Typed.fun_arg exprs))
338-
| _ -> false in
389+
| FunApp (UserDefined (name, _), _) -> Some name
390+
| FunApp (StanLib (name, _, _), exprs)
391+
when not
392+
(query_stan_math_mem_pattern_support name
393+
(List.map ~f:Expr.Typed.fun_arg exprs)) ->
394+
Some name
395+
| _ -> None in
339396
(* LHS (3) all rhs aos*)
340397
let is_all_rhs_aos =
341398
is_nonzero_subset
342399
~subset:(query_var_eigen_names rhs)
343400
~set:rhs_demotable_names in
344401
if
345402
is_all_rhs_aos || is_rhs_not_promoteable_to_soa
346-
|| is_not_supported_func
347-
then
348-
Set.add (Set.union tuple_demotions (query_var_eigen_names rhs)) name
403+
|| Option.is_some non_supported_func_name
404+
then (
405+
let rhs_set = query_var_eigen_names rhs in
406+
let all_rhs_warn =
407+
if is_all_rhs_aos then "Right hand side of assignment is all AoS:"
408+
else "" in
409+
let rhs_not_promotable_to_soa_warn =
410+
if is_rhs_not_promoteable_to_soa then
411+
"The right hand side of the assignment only contains data and \
412+
scalar operations that are not promotable to SoA:"
413+
else "" in
414+
let not_supported_func_warn =
415+
match non_supported_func_name with
416+
| Some fname ->
417+
"Function '" ^ fname
418+
^ "' on right hand side of assignment is not supported by \
419+
SoA:"
420+
| None -> "" in
421+
let rhs_name_set = Set.add rhs_set name in
422+
let rhs_name_set_str = concat_set_str rhs_name_set in
423+
user_warning_op SoA linenum all_rhs_warn rhs_name_set_str;
424+
user_warning_op SoA linenum rhs_not_promotable_to_soa_warn
425+
rhs_name_set_str;
426+
user_warning_op SoA linenum not_supported_func_warn rhs_name_set_str;
427+
Set.add (Set.union tuple_demotions rhs_set) name)
349428
else tuple_demotions
350429
else tuple_demotions in
351430
Set.union acc assign_demotions
352431
| NRFunApp (kind, exprs) ->
353-
query_initial_demotable_funs in_loop acc kind exprs
432+
query_initial_demotable_funs in_loop linenum acc kind exprs
354433
| IfElse (predicate, true_stmt, op_false_stmt) ->
355434
let predicate_acc = query_expr acc predicate in
356435
Set.union acc
@@ -386,7 +465,10 @@ let rec query_initial_demotable_stmt (in_loop : bool) (acc : string Set.Poly.t)
386465
| Decl {decl_type= Type.Sized st; decl_id; initialize; _} ->
387466
let complex_name =
388467
match SizedType.is_complex_type st with
389-
| true -> Set.Poly.singleton decl_id
468+
| true ->
469+
user_warning_op SoA linenum "Complex-valued types cannot be SoA:"
470+
decl_id;
471+
Set.Poly.singleton decl_id
390472
| false -> Set.Poly.empty in
391473
let init_names =
392474
match initialize with
@@ -408,24 +490,44 @@ let rec query_initial_demotable_stmt (in_loop : bool) (acc : string Set.Poly.t)
408490
@param pattern The Stmt pattern to query.
409491
*)
410492
let query_demotable_stmt (aos_exits : string Set.Poly.t)
411-
(pattern : (Expr.Typed.t, int) Stmt.Pattern.t) : string Set.Poly.t =
412-
match pattern with
493+
(stmt : Stmt.Located.Non_recursive.t) : string Set.Poly.t =
494+
let linenum = stmt.meta.end_loc.line_num in
495+
match stmt.pattern with
413496
| Stmt.Pattern.Assignment (lval, (_ : UnsizedType.t), (rhs : Expr.Typed.t))
414497
-> (
415498
let assign_name = Stmt.Helpers.lhs_variable lval in
416499
let all_rhs_eigen_names = query_var_eigen_names rhs in
417-
if Set.mem aos_exits assign_name then
418-
Set.add all_rhs_eigen_names assign_name
500+
if Set.mem aos_exits assign_name then (
501+
user_warning_op SoA linenum
502+
"Right hand side contains only AoS expressions:" assign_name;
503+
Set.add all_rhs_eigen_names assign_name)
419504
else
420505
match is_nonzero_subset ~set:aos_exits ~subset:all_rhs_eigen_names with
421-
| true -> Set.add all_rhs_eigen_names assign_name
506+
| true ->
507+
let warn =
508+
Fmt.(
509+
str "Right hand side contains AoS expressions (%s):"
510+
(concat_set_str (Set.inter aos_exits all_rhs_eigen_names)))
511+
in
512+
user_warning_op SoA linenum warn assign_name;
513+
Set.add all_rhs_eigen_names assign_name
422514
| false -> Set.Poly.empty)
423515
| Decl {decl_id; initialize= Assign e; _} -> (
424516
let all_rhs_eigen_names = query_var_eigen_names e in
425-
if Set.mem aos_exits decl_id then Set.add all_rhs_eigen_names decl_id
517+
if Set.mem aos_exits decl_id then (
518+
user_warning_op SoA linenum
519+
"Right hand side contains only AoS expressions:" decl_id;
520+
Set.add all_rhs_eigen_names decl_id)
426521
else
427522
match is_nonzero_subset ~set:aos_exits ~subset:all_rhs_eigen_names with
428-
| true -> Set.add all_rhs_eigen_names decl_id
523+
| true ->
524+
let warn =
525+
Fmt.(
526+
str "Right hand side contains AoS expressions (%s):"
527+
(concat_set_str (Set.inter aos_exits all_rhs_eigen_names)))
528+
in
529+
user_warning_op SoA linenum warn decl_id;
530+
Set.add all_rhs_eigen_names decl_id
429531
| false -> Set.Poly.empty)
430532
(* All other statements do not need logic here*)
431533
| _ -> Set.Poly.empty

src/analysis_and_optimization/Optimize.ml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1228,8 +1228,7 @@ let optimize_soa (mir : Program.Typed.t) =
12281228
(flowgraph_to_mir : (int, Stmt.Located.Non_recursive.t) Map.Poly.t)
12291229
(l : int) (aos_variables : string Set.Poly.t) =
12301230
let mir_node mir_idx = Map.find_exn flowgraph_to_mir mir_idx in
1231-
match (mir_node l).pattern with
1232-
| stmt -> Memory_patterns.query_demotable_stmt aos_variables stmt in
1231+
Memory_patterns.query_demotable_stmt aos_variables (mir_node l) in
12331232
let initial_variables =
12341233
List.fold ~init:Set.Poly.empty
12351234
~f:(Memory_patterns.query_initial_demotable_stmt false)

src/driver/Entry.ml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,11 @@ let stan2cpp model_name model (flags : Flags.t) (output : other_output -> unit)
127127
tx_mir in
128128
if flags.debug_settings.print_mem_patterns then
129129
output
130-
(Memory_patterns (Fmt.str "%a" Memory_patterns.pp_mem_patterns opt_mir));
130+
(Memory_patterns
131+
(Fmt.str "%a%a@\n" Memory_patterns.pp_mem_patterns opt_mir
132+
(* TODO should be better associated with the names from above? *)
133+
Fmt.(list string)
134+
(Memory_patterns.get_warnings ())));
131135
debug_output_mir output opt_mir flags.debug_settings.print_optimized_mir;
132136
let cpp =
133137
Lower_program.lower_program

0 commit comments

Comments
 (0)