@@ -2,6 +2,30 @@ open Core
22open Core.Poly
33open Middle
44
5+ type demotion = int * Mem_pattern .t * string [@@ deriving compare ]
6+
7+ let demotion_reasons = ref []
8+
9+ let get_warnings () =
10+ let mem_name pattern =
11+ match pattern with Mem_pattern. SoA -> " SoA" | AoS -> " AoS" in
12+ ! demotion_reasons
13+ |> List. dedup_and_sort ~compare: compare_demotion
14+ |> List. map ~f: (fun (linenum , pattern , msg ) ->
15+ Printf. sprintf " Optimization hazard warning (Line %i): %s warning: %s"
16+ linenum (mem_name pattern) msg)
17+
18+ let user_warning_op (mem_pattern : Mem_pattern.t ) (linenum : int ) (msg : string )
19+ (names : string ) =
20+ if not (String. is_empty names || String. is_empty msg) then
21+ demotion_reasons :=
22+ (linenum, mem_pattern, msg ^ " " ^ names) :: ! demotion_reasons
23+
24+ let concat_set_str (set : string Set.Poly.t ) =
25+ Set. fold
26+ ~f: (fun acc elem -> if acc = " " then acc ^ elem else acc ^ " , " ^ elem)
27+ ~init: " " set
28+
529(* *
630 Return a Var expression of the name for each type
731 containing an eigen matrix
@@ -98,7 +122,7 @@ let query_stan_math_mem_pattern_support (name : string)
98122 Frontend.SignatureMismatch. check_compatible_arguments_mod_conv x args
99123 |> Result. is_ok)
100124 namematches in
101- let is_soa = function _ , _ , _ , Mem_pattern. SoA -> true | _ -> false in
125+ let is_soa ( _ , _ , _ , p ) = p = Mem_pattern. SoA in
102126 List. exists ~f: is_soa filteredmatches
103127
104128(* Validate whether a function can support SoA matrices*)
@@ -116,13 +140,13 @@ let is_fun_soa_supported name exprs =
116140 will be returned if the matrix or vector is accessed by single
117141 cell indexing.
118142 *)
119- let rec query_initial_demotable_expr (in_loop : bool ) ~( acc : string Set.Poly.t )
120- Expr. {pattern; _} : string Set.Poly.t =
143+ let rec query_initial_demotable_expr (in_loop : bool ) ( stmt_linenum : int )
144+ ~( acc : string Set.Poly.t ) Expr. {pattern; _} : string Set.Poly.t =
121145 let query_expr (accum : string Set.Poly.t ) =
122- query_initial_demotable_expr in_loop ~acc: accum in
146+ query_initial_demotable_expr in_loop stmt_linenum ~acc: accum in
123147 match pattern with
124148 | FunApp (kind , (exprs : Expr.Typed.t list )) ->
125- query_initial_demotable_funs in_loop acc kind exprs
149+ query_initial_demotable_funs in_loop stmt_linenum acc kind exprs
126150 | Indexed ((Expr. {meta = {type_; _} ; _} as expr ), indexed ) ->
127151 let index_set =
128152 Set.Poly. union_list
@@ -132,18 +156,29 @@ let rec query_initial_demotable_expr (in_loop : bool) ~(acc : string Set.Poly.t)
132156 (query_expr acc))
133157 indexed) in
134158 let index_demotes =
135- if is_uni_eigen_loop_indexing in_loop type_ indexed then
136- Set. union (query_var_eigen_names expr) index_set
159+ if is_uni_eigen_loop_indexing in_loop type_ indexed then (
160+ let single_index_set = query_var_eigen_names expr in
161+ let failure_str = concat_set_str (Set. inter acc single_index_set) in
162+ let msg = " Accessed by element in a for loop:" in
163+ user_warning_op SoA stmt_linenum msg failure_str;
164+ Set. union single_index_set index_set)
137165 else Set. union (query_expr acc expr) index_set in
138166 Set. union acc index_demotes
139167 | Var (_ : string ) | Lit ((_ : Expr.Pattern.litType ), (_ : string )) -> acc
140168 | Promotion (expr , _ , _ ) -> query_expr acc expr
141169 | TupleProjection (expr , _ ) -> query_expr acc expr
142170 | TernaryIf (predicate , texpr , fexpr ) ->
143171 let predicate_demotes = query_expr acc predicate in
144- Set. union
145- (Set. union predicate_demotes (query_var_eigen_names texpr))
146- (query_var_eigen_names fexpr)
172+ let full_set =
173+ Set. union
174+ (Set. union predicate_demotes (query_var_eigen_names texpr))
175+ (query_var_eigen_names fexpr) in
176+ if Set. is_empty full_set then full_set
177+ else
178+ let failure_str = concat_set_str (Set. inter acc full_set) in
179+ let msg = " Used in a ternary operator which is not allowed:" in
180+ user_warning_op SoA stmt_linenum msg failure_str;
181+ full_set
147182 | EAnd (lhs , rhs ) | EOr (lhs , rhs ) ->
148183 (* We need to get the demotes from both sides*)
149184 let full_lhs_rhs = Set. union (query_expr acc lhs) (query_expr acc rhs) in
@@ -166,9 +201,11 @@ let rec query_initial_demotable_expr (in_loop : bool) ~(acc : string Set.Poly.t)
166201 to the UDF.
167202 exprs The expression list passed to the functions.
168203 *)
169- and query_initial_demotable_funs (in_loop : bool ) (acc : string Set.Poly.t )
170- (kind : 'a Fun_kind.t ) (exprs : Expr.Typed.t list ) : string Set.Poly.t =
171- let query_expr accum = query_initial_demotable_expr in_loop ~acc: accum in
204+ and query_initial_demotable_funs (in_loop : bool ) (stmt_linenum : int )
205+ (acc : string Set.Poly.t ) (kind : 'a Fun_kind.t ) (exprs : Expr.Typed.t list )
206+ : string Set.Poly.t =
207+ let query_expr accum =
208+ query_initial_demotable_expr in_loop stmt_linenum ~acc: accum in
172209 let top_level_eigen_names =
173210 Set.Poly. union_list (List. map ~f: query_var_eigen_names exprs) in
174211 let demoted_eigen_names = List. fold ~init: acc ~f: query_expr exprs in
@@ -181,11 +218,26 @@ and query_initial_demotable_funs (in_loop : bool) (acc : string Set.Poly.t)
181218 | name -> (
182219 match is_fun_soa_supported name exprs with
183220 | true -> Set. union acc demoted_eigen_names
184- | false -> Set. union acc demoted_and_top_level_names))
221+ | false ->
222+ let fail_names =
223+ concat_set_str (Set. inter acc top_level_eigen_names) in
224+ user_warning_op SoA stmt_linenum
225+ (" Function " ^ name ^ " is not supported:" )
226+ fail_names;
227+ Set. union acc demoted_and_top_level_names))
185228 | CompilerInternal (Internal_fun. FnMakeArray | FnMakeRowVec | FnMakeTuple ) ->
229+ let fail_names =
230+ concat_set_str (Set. inter acc demoted_and_top_level_names) in
231+ user_warning_op SoA stmt_linenum
232+ " Used in {} make array or make row vector compiler functions:"
233+ fail_names;
186234 Set. union acc demoted_and_top_level_names
187235 | CompilerInternal (_ : 'a Internal_fun.t ) -> acc
188236 | UserDefined ((_ : string ), (_ : bool Fun_kind.suffix )) ->
237+ let fail_names =
238+ concat_set_str (Set. inter acc demoted_and_top_level_names) in
239+ user_warning_op SoA stmt_linenum " Used in user defined function:"
240+ fail_names;
189241 Set. union acc demoted_and_top_level_names
190242
191243(* *
@@ -283,9 +335,10 @@ let contains_at_least_one_ad_matrix_or_all_data
283335 [query_initial_demotable_expr] for an explanation of the logic.
284336 *)
285337let rec query_initial_demotable_stmt (in_loop : bool ) (acc : string Set.Poly.t )
286- (Stmt. {pattern; _} : Stmt.Located.t ) : string Set.Poly.t =
338+ (Stmt. {pattern; meta} : Stmt.Located.t ) : string Set.Poly.t =
339+ let linenum = meta.end_loc.line_num in
287340 let query_expr (accum : string Set.Poly.t ) =
288- query_initial_demotable_expr in_loop ~acc: accum in
341+ query_initial_demotable_expr in_loop linenum ~acc: accum in
289342 match pattern with
290343 | Stmt.Pattern. Assignment
291344 ( lval
@@ -299,21 +352,25 @@ let rec query_initial_demotable_stmt (in_loop : bool) (acc : string Set.Poly.t)
299352 List. fold ~init: acc
300353 ~f: (fun accum x ->
301354 Index. folder accum
302- (fun acc -> query_initial_demotable_expr in_loop ~acc )
355+ (fun acc -> query_initial_demotable_expr in_loop linenum ~acc )
303356 x)
304357 idx in
305358 match is_uni_eigen_loop_indexing in_loop ut idx with
306- | true -> Set. add idx_list name
359+ | true ->
360+ user_warning_op SoA linenum " Accessed by element in a for loop:"
361+ (if Set. mem acc name then " " else name);
362+ Set. add idx_list name
307363 | false -> idx_list in
308364 let rhs_demotable_names = query_expr acc rhs in
309365 let rhs_and_idx_demotions = Set. union idx_demotable rhs_demotable_names in
310366 (* RHS (1)*)
311367 let tuple_demotions =
312368 match lval with
313369 | LTupleProjection _ , _ ->
314- Set. add
315- (Set. union rhs_and_idx_demotions (query_var_eigen_names rhs))
316- name
370+ let tuple_set = query_var_eigen_names rhs in
371+ let fail_set = concat_set_str tuple_set in
372+ user_warning_op SoA linenum " Used in tuple:" fail_set;
373+ Set. add (Set. union rhs_and_idx_demotions tuple_set) name
317374 | _ -> rhs_and_idx_demotions in
318375 let assign_demotions =
319376 let is_eigen_stmt = UnsizedType. contains_eigen_type rhs.meta.type_ in
@@ -327,30 +384,52 @@ let rec query_initial_demotable_stmt (in_loop : bool) (acc : string Set.Poly.t)
327384 (extract_nonderived_admatrix_types rhs))
328385 | _ -> false in
329386 (* LHS (3) rhs unsupported function*)
330- let is_not_supported_func =
387+ let non_supported_func_name =
331388 match rhs.pattern with
332- | FunApp (UserDefined _ , _ ) -> true
333- | FunApp (CompilerInternal _ , _ ) -> false
334- | FunApp (StanLib ( name , _ , _ ), exprs ) ->
335- not
336- (query_stan_math_mem_pattern_support name
337- ( List. map ~f: Expr.Typed. fun_arg exprs))
338- | _ -> false in
389+ | FunApp (UserDefined ( name , _ ) , _ ) -> Some name
390+ | FunApp (StanLib (name, _, _), exprs)
391+ when not
392+ (query_stan_math_mem_pattern_support name
393+ ( List. map ~f: Expr.Typed. fun_arg exprs)) ->
394+ Some name
395+ | _ -> None in
339396 (* LHS (3) all rhs aos*)
340397 let is_all_rhs_aos =
341398 is_nonzero_subset
342399 ~subset: (query_var_eigen_names rhs)
343400 ~set: rhs_demotable_names in
344401 if
345402 is_all_rhs_aos || is_rhs_not_promoteable_to_soa
346- || is_not_supported_func
347- then
348- Set. add (Set. union tuple_demotions (query_var_eigen_names rhs)) name
403+ || Option. is_some non_supported_func_name
404+ then (
405+ let rhs_set = query_var_eigen_names rhs in
406+ let all_rhs_warn =
407+ if is_all_rhs_aos then " Right hand side of assignment is all AoS:"
408+ else " " in
409+ let rhs_not_promotable_to_soa_warn =
410+ if is_rhs_not_promoteable_to_soa then
411+ " The right hand side of the assignment only contains data and \
412+ scalar operations that are not promotable to SoA:"
413+ else " " in
414+ let not_supported_func_warn =
415+ match non_supported_func_name with
416+ | Some fname ->
417+ " Function '" ^ fname
418+ ^ " ' on right hand side of assignment is not supported by \
419+ SoA:"
420+ | None -> " " in
421+ let rhs_name_set = Set. add rhs_set name in
422+ let rhs_name_set_str = concat_set_str rhs_name_set in
423+ user_warning_op SoA linenum all_rhs_warn rhs_name_set_str;
424+ user_warning_op SoA linenum rhs_not_promotable_to_soa_warn
425+ rhs_name_set_str;
426+ user_warning_op SoA linenum not_supported_func_warn rhs_name_set_str;
427+ Set. add (Set. union tuple_demotions rhs_set) name)
349428 else tuple_demotions
350429 else tuple_demotions in
351430 Set. union acc assign_demotions
352431 | NRFunApp (kind , exprs ) ->
353- query_initial_demotable_funs in_loop acc kind exprs
432+ query_initial_demotable_funs in_loop linenum acc kind exprs
354433 | IfElse (predicate , true_stmt , op_false_stmt ) ->
355434 let predicate_acc = query_expr acc predicate in
356435 Set. union acc
@@ -386,7 +465,10 @@ let rec query_initial_demotable_stmt (in_loop : bool) (acc : string Set.Poly.t)
386465 | Decl {decl_type = Type. Sized st ; decl_id; initialize; _} ->
387466 let complex_name =
388467 match SizedType. is_complex_type st with
389- | true -> Set.Poly. singleton decl_id
468+ | true ->
469+ user_warning_op SoA linenum " Complex-valued types cannot be SoA:"
470+ decl_id;
471+ Set.Poly. singleton decl_id
390472 | false -> Set.Poly. empty in
391473 let init_names =
392474 match initialize with
@@ -408,24 +490,44 @@ let rec query_initial_demotable_stmt (in_loop : bool) (acc : string Set.Poly.t)
408490 @param pattern The Stmt pattern to query.
409491 *)
410492let query_demotable_stmt (aos_exits : string Set.Poly.t )
411- (pattern : (Expr.Typed.t, int) Stmt.Pattern.t ) : string Set.Poly.t =
412- match pattern with
493+ (stmt : Stmt.Located.Non_recursive.t ) : string Set.Poly.t =
494+ let linenum = stmt.meta.end_loc.line_num in
495+ match stmt.pattern with
413496 | Stmt.Pattern. Assignment (lval, (_ : UnsizedType.t ), (rhs : Expr.Typed.t ))
414497 -> (
415498 let assign_name = Stmt.Helpers. lhs_variable lval in
416499 let all_rhs_eigen_names = query_var_eigen_names rhs in
417- if Set. mem aos_exits assign_name then
418- Set. add all_rhs_eigen_names assign_name
500+ if Set. mem aos_exits assign_name then (
501+ user_warning_op SoA linenum
502+ " Right hand side contains only AoS expressions:" assign_name;
503+ Set. add all_rhs_eigen_names assign_name)
419504 else
420505 match is_nonzero_subset ~set: aos_exits ~subset: all_rhs_eigen_names with
421- | true -> Set. add all_rhs_eigen_names assign_name
506+ | true ->
507+ let warn =
508+ Fmt. (
509+ str " Right hand side contains AoS expressions (%s):"
510+ (concat_set_str (Set. inter aos_exits all_rhs_eigen_names)))
511+ in
512+ user_warning_op SoA linenum warn assign_name;
513+ Set. add all_rhs_eigen_names assign_name
422514 | false -> Set.Poly. empty)
423515 | Decl {decl_id; initialize = Assign e ; _} -> (
424516 let all_rhs_eigen_names = query_var_eigen_names e in
425- if Set. mem aos_exits decl_id then Set. add all_rhs_eigen_names decl_id
517+ if Set. mem aos_exits decl_id then (
518+ user_warning_op SoA linenum
519+ " Right hand side contains only AoS expressions:" decl_id;
520+ Set. add all_rhs_eigen_names decl_id)
426521 else
427522 match is_nonzero_subset ~set: aos_exits ~subset: all_rhs_eigen_names with
428- | true -> Set. add all_rhs_eigen_names decl_id
523+ | true ->
524+ let warn =
525+ Fmt. (
526+ str " Right hand side contains AoS expressions (%s):"
527+ (concat_set_str (Set. inter aos_exits all_rhs_eigen_names)))
528+ in
529+ user_warning_op SoA linenum warn decl_id;
530+ Set. add all_rhs_eigen_names decl_id
429531 | false -> Set.Poly. empty)
430532 (* All other statements do not need logic here*)
431533 | _ -> Set.Poly. empty
0 commit comments