Skip to content

Commit 1359c60

Browse files
Merge pull request #3 from JuliaHealth/add/postcohort
[FEATURE] Post Cohort Feasibility
2 parents ee1080c + e99c878 commit 1359c60

File tree

9 files changed

+940
-89
lines changed

9 files changed

+940
-89
lines changed

Project.toml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
99
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
1010
FunSQL = "cf6cc811-59f4-4a10-b258-a8547a8f6407"
1111
InlineStrings = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48"
12-
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
1312
OMOPCommonDataModel = "ba65db9e-6590-4054-ab8a-101ed9124986"
1413
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
1514

@@ -18,7 +17,6 @@ DBInterface = "2.6.1"
1817
DataFrames = "1.7.0"
1918
FunSQL = "0.10, 0.11, 0.12, 0.13"
2019
InlineStrings = "1.4.4"
21-
JuliaFormatter = "2.1.6"
2220
OMOPCommonDataModel = "0.1"
2321
PrettyTables = "2.4.0"
2422
julia = "1.10"
@@ -31,4 +29,4 @@ SQLite = "0aa819cd-b072-5ff4-a722-6bc24af294d9"
3129
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3230

3331
[targets]
34-
test = ["Test", "HealthSampleData", "SQLite", "OMOPCDMCohortCreator", "DataDeps"]
32+
test = ["Test", "HealthSampleData", "SQLite", "OMOPCDMCohortCreator", "DataDeps"]

src/OMOPCDMFeasibility.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,12 @@ using OMOPCommonDataModel
1919

2020
include("utils.jl")
2121
include("precohort.jl")
22+
include("postcohort.jl")
2223

23-
export analyze_concept_distribution, generate_summary, generate_domain_breakdown
24+
export analyze_concept_distribution,
25+
generate_summary,
26+
generate_domain_breakdown,
27+
create_individual_profiles,
28+
create_cartesian_profiles
2429

2530
end

src/postcohort.jl

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
"""
2+
create_individual_profiles(;
3+
cohort_definition_id::Union{Int, Nothing} = nothing,
4+
cohort_df::Union{DataFrame, Nothing} = nothing,
5+
conn,
6+
covariate_funcs::AbstractVector{<:Function},
7+
schema::String = "dbt_synthea_dev",
8+
dialect::Symbol = :postgresql
9+
)
10+
11+
Creates individual demographic profile tables for a cohort by analyzing each covariate separately.
12+
13+
This function generates separate DataFrames for each demographic covariate (e.g., gender, race, age group),
14+
providing detailed statistics including cohort and database-level percentages for post-cohort feasibility analysis.
15+
Results are sorted alphabetically by covariate values for consistent, readable output.
16+
17+
# Arguments
18+
- `conn` - Database connection using DBInterface
19+
- `covariate_funcs` - Vector of covariate functions from OMOPCDMCohortCreator (e.g., `GetPatientGender`, `GetPatientRace`)
20+
21+
# Keyword Arguments
22+
- `cohort_definition_id` - ID of the cohort definition in the cohort table (or nothing). Either this or `cohort_df` must be provided
23+
- `cohort_df` - DataFrame containing cohort with `person_id` column (or nothing). Either this or `cohort_definition_id` must be provided
24+
- `schema` - Database schema name. Default: `"dbt_synthea_dev"`
25+
- `dialect` - Database dialect. Default: `:postgresql` (for DuckDB compatibility)
26+
27+
# Returns
28+
- `NamedTuple` - Named tuple with keys corresponding to covariate names, each containing a DataFrame with covariate categories and statistics
29+
30+
# Examples
31+
```julia
32+
using OMOPCDMCohortCreator: GetPatientGender, GetPatientRace, GetPatientAgeGroup
33+
34+
individual_profiles = create_individual_profiles(
35+
cohort_df = my_cohort_df,
36+
conn = conn,
37+
covariate_funcs = [GetPatientGender, GetPatientRace, GetPatientAgeGroup]
38+
)
39+
```
40+
"""
41+
function create_individual_profiles(;
42+
cohort_definition_id::Union{Int,Nothing}=nothing,
43+
cohort_df::Union{DataFrame,Nothing}=nothing,
44+
conn,
45+
covariate_funcs::AbstractVector{<:Function},
46+
schema::String="dbt_synthea_dev",
47+
dialect::Symbol=:postgresql,
48+
)
49+
if cohort_definition_id === nothing && cohort_df === nothing
50+
throw(ArgumentError("Must provide either cohort_definition_id or cohort_df"))
51+
end
52+
53+
if isempty(covariate_funcs)
54+
throw(ArgumentError("covariate_funcs cannot be empty"))
55+
end
56+
57+
person_ids = _get_cohort_person_ids(
58+
cohort_definition_id, cohort_df, conn; schema=schema, dialect=dialect
59+
)
60+
cohort_size = length(person_ids)
61+
62+
database_size = _get_database_total_patients(conn; schema=schema, dialect=dialect)
63+
64+
_funcs = [Base.Fix2(fun, conn) for fun in covariate_funcs]
65+
demographics_df = _counter_reducer(person_ids, _funcs)
66+
67+
result_tables = Dict{Symbol,DataFrame}()
68+
69+
for col in names(demographics_df)
70+
if col != "person_id"
71+
covariate_stats = _create_individual_profile_table(
72+
demographics_df,
73+
col,
74+
cohort_size,
75+
database_size,
76+
conn;
77+
schema=schema,
78+
dialect=dialect,
79+
)
80+
covariate_name = Symbol(replace(string(col), "_concept_id" => ""))
81+
result_tables[covariate_name] = covariate_stats
82+
end
83+
end
84+
85+
return NamedTuple(result_tables)
86+
end
87+
88+
"""
89+
create_cartesian_profiles(;
90+
cohort_definition_id::Union{Int, Nothing} = nothing,
91+
cohort_df::Union{DataFrame, Nothing} = nothing,
92+
conn,
93+
covariate_funcs::AbstractVector{<:Function},
94+
schema::String = "dbt_synthea_dev",
95+
dialect::Symbol = :postgresql
96+
)
97+
98+
Creates Cartesian product demographic profiles for a cohort by analyzing all combinations of covariates.
99+
100+
This function generates a single DataFrame containing all possible combinations of demographic
101+
covariates (e.g., gender × race × age_group), providing comprehensive cross-tabulated statistics
102+
for detailed post-cohort feasibility analysis. Column order matches the input `covariate_funcs` order,
103+
and results are sorted by covariate values for interpretable output.
104+
105+
# Arguments
106+
- `conn` - Database connection using DBInterface
107+
- `covariate_funcs` - Vector of covariate functions from OMOPCDMCohortCreator (must contain at least 2 functions)
108+
109+
# Keyword Arguments
110+
- `cohort_definition_id` - ID of the cohort definition in the cohort table (or nothing). Either this or `cohort_df` must be provided
111+
- `cohort_df` - DataFrame containing cohort with `person_id` column (or nothing). Either this or `cohort_definition_id` must be provided
112+
- `schema` - Database schema name. Default: `"dbt_synthea_dev"`
113+
- `dialect` - Database dialect. Default: `:postgresql` (for DuckDB compatibility)
114+
115+
# Returns
116+
- `DataFrame` - Cross-tabulated profile table with all covariate combinations and statistics
117+
118+
# Examples
119+
```julia
120+
using OMOPCDMCohortCreator: GetPatientAgeGroup, GetPatientGender, GetPatientRace
121+
122+
cartesian_profiles = create_cartesian_profiles(
123+
cohort_df = my_cohort_df,
124+
conn = conn,
125+
covariate_funcs = [GetPatientAgeGroup, GetPatientGender, GetPatientRace]
126+
)
127+
```
128+
"""
129+
function create_cartesian_profiles(;
130+
cohort_definition_id::Union{Int,Nothing}=nothing,
131+
cohort_df::Union{DataFrame,Nothing}=nothing,
132+
conn,
133+
covariate_funcs::AbstractVector{<:Function},
134+
schema::String="dbt_synthea_dev",
135+
dialect::Symbol=:postgresql,
136+
)
137+
if cohort_definition_id === nothing && cohort_df === nothing
138+
throw(ArgumentError("Must provide either cohort_definition_id or cohort_df"))
139+
end
140+
141+
if length(covariate_funcs) < 2
142+
throw(
143+
ArgumentError("Need at least 2 covariate functions for Cartesian combinations")
144+
)
145+
end
146+
147+
person_ids = _get_cohort_person_ids(
148+
cohort_definition_id, cohort_df, conn; schema=schema, dialect=dialect
149+
)
150+
cohort_size = length(person_ids)
151+
152+
database_size = _get_database_total_patients(conn; schema=schema, dialect=dialect)
153+
154+
_funcs = [Base.Fix2(fun, conn) for fun in covariate_funcs]
155+
demographics_df = _counter_reducer(person_ids, _funcs)
156+
157+
demographic_cols = names(demographics_df)[names(demographics_df) .!= "person_id"]
158+
ordered_cols = reverse(demographic_cols)
159+
160+
result_df = _create_cartesian_profile_table(
161+
demographics_df,
162+
ordered_cols,
163+
cohort_size,
164+
database_size,
165+
conn;
166+
schema=schema,
167+
dialect=dialect,
168+
)
169+
170+
return result_df
171+
end

src/precohort.jl

Lines changed: 54 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,9 @@ function analyze_concept_distribution(
4343
)
4444
isempty(concept_set) && throw(ArgumentError("concept_set cannot be empty"))
4545

46-
concepts_by_domain = _get_concepts_by_domain(concept_set, conn; schema=schema, dialect=dialect)
46+
concepts_by_domain = _get_concepts_by_domain(
47+
concept_set, conn; schema=schema, dialect=dialect
48+
)
4749

4850
if isempty(concepts_by_domain)
4951
return DataFrame(;
@@ -57,18 +59,26 @@ function analyze_concept_distribution(
5759
try
5860
table_symbol = _domain_id_to_table(domain_id)
5961

60-
setup = _setup_domain_query(conn; domain=table_symbol, schema=schema, dialect=dialect)
62+
setup = _setup_domain_query(
63+
conn; domain=table_symbol, schema=schema, dialect=dialect
64+
)
6165

62-
base = Where(Fun.in(Get(setup.concept_col), domain_concepts...))(Join(
63-
:main_concept => setup.concept_table,
64-
Get(setup.concept_col) .== Get.main_concept.concept_id,
65-
)(From(setup.tbl)))
66+
base = Where(Fun.in(Get(setup.concept_col), domain_concepts...))(
67+
Join(
68+
:main_concept => setup.concept_table,
69+
Get(setup.concept_col) .== Get.main_concept.concept_id,
70+
)(
71+
From(setup.tbl)
72+
),
73+
)
6674

6775
q = Select(
6876
Get(:person_id),
6977
:concept_id => Get(setup.concept_col),
7078
:concept_name => Get.main_concept.concept_name,
71-
)(base)
79+
)(
80+
base
81+
)
7282
base_df = DataFrame(DBInterface.execute(setup.fconn, q))
7383

7484
if !isempty(base_df)
@@ -144,11 +154,13 @@ function generate_summary(
144154
covariate_funcs::AbstractVector{<:Function}=Function[],
145155
schema::String="main",
146156
dialect::Symbol=:postgresql,
147-
raw_values::Bool=false
157+
raw_values::Bool=false,
148158
)
149159
isempty(concept_set) && throw(ArgumentError("concept_set cannot be empty"))
150160

151-
concepts_by_domain = _get_concepts_by_domain(concept_set, conn; schema=schema, dialect=dialect)
161+
concepts_by_domain = _get_concepts_by_domain(
162+
concept_set, conn; schema=schema, dialect=dialect
163+
)
152164

153165
if isempty(concepts_by_domain)
154166
return DataFrame(;
@@ -170,16 +182,22 @@ function generate_summary(
170182
for (domain_id, domain_concepts) in concepts_by_domain
171183
try
172184
table_symbol = _domain_id_to_table(domain_id)
173-
setup = _setup_domain_query(conn; domain=table_symbol, schema=schema, dialect=dialect)
185+
setup = _setup_domain_query(
186+
conn; domain=table_symbol, schema=schema, dialect=dialect
187+
)
174188

175-
concept_records_q = Select(:total_concept_records => Agg.count())(Group()(Where(
176-
Fun.in(Get(setup.concept_col), domain_concepts...)
177-
)(From(setup.tbl))))
189+
concept_records_q = Select(:total_concept_records => Agg.count())(
190+
Group()(
191+
Where(Fun.in(Get(setup.concept_col), domain_concepts...))(
192+
From(setup.tbl)
193+
),
194+
),
195+
)
178196
domain_records = DataFrame(DBInterface.execute(setup.fconn, concept_records_q)).total_concept_records[1]
179197

180-
unique_patients_q = Select(Get(:person_id))(Where(
181-
Fun.in(Get(setup.concept_col), domain_concepts...)
182-
)(From(setup.tbl)))
198+
unique_patients_q = Select(Get(:person_id))(
199+
Where(Fun.in(Get(setup.concept_col), domain_concepts...))(From(setup.tbl))
200+
)
183201
domain_patients_df = DataFrame(
184202
DBInterface.execute(setup.fconn, unique_patients_q)
185203
)
@@ -307,18 +325,17 @@ function generate_domain_breakdown(
307325
covariate_funcs::AbstractVector{<:Function}=Function[],
308326
schema::String="main",
309327
dialect::Symbol=:postgresql,
310-
raw_values::Bool=false
328+
raw_values::Bool=false,
311329
)
312330
isempty(concept_set) && throw(ArgumentError("concept_set cannot be empty"))
313331

314-
concepts_by_domain = _get_concepts_by_domain(concept_set, conn; schema=schema, dialect=dialect)
332+
concepts_by_domain = _get_concepts_by_domain(
333+
concept_set, conn; schema=schema, dialect=dialect
334+
)
315335

316336
if isempty(concepts_by_domain)
317337
return DataFrame(;
318-
metric=String[],
319-
value=String[],
320-
interpretation=String[],
321-
domain=String[],
338+
metric=String[], value=String[], interpretation=String[], domain=String[]
322339
)
323340
end
324341

@@ -332,16 +349,22 @@ function generate_domain_breakdown(
332349
for (domain_id, domain_concepts) in concepts_by_domain
333350
try
334351
table_symbol = _domain_id_to_table(domain_id)
335-
setup = _setup_domain_query(conn; domain=table_symbol, schema=schema, dialect=dialect)
352+
setup = _setup_domain_query(
353+
conn; domain=table_symbol, schema=schema, dialect=dialect
354+
)
336355

337-
concept_records_q = Select(:total_concept_records => Agg.count())(Group()(Where(
338-
Fun.in(Get(setup.concept_col), domain_concepts...)
339-
)(From(setup.tbl))))
356+
concept_records_q = Select(:total_concept_records => Agg.count())(
357+
Group()(
358+
Where(Fun.in(Get(setup.concept_col), domain_concepts...))(
359+
From(setup.tbl)
360+
),
361+
),
362+
)
340363
domain_records = DataFrame(DBInterface.execute(setup.fconn, concept_records_q)).total_concept_records[1]
341364

342-
unique_patients_q = Select(Get(:person_id))(Where(
343-
Fun.in(Get(setup.concept_col), domain_concepts...)
344-
)(From(setup.tbl)))
365+
unique_patients_q = Select(Get(:person_id))(
366+
Where(Fun.in(Get(setup.concept_col), domain_concepts...))(From(setup.tbl))
367+
)
345368
domain_patients_df = DataFrame(
346369
DBInterface.execute(setup.fconn, unique_patients_q)
347370
)
@@ -367,7 +390,7 @@ function generate_domain_breakdown(
367390
domain_breakdown = DataFrame()
368391
for row in eachrow(domain_details)
369392
domain_coverage = round((row.patients / total_patients) * 100; digits=3)
370-
393+
371394
if raw_values
372395
domain_metrics = DataFrame(;
373396
metric=[
@@ -376,12 +399,7 @@ function generate_domain_breakdown(
376399
"$(row.domain) - Records",
377400
"$(row.domain) - Coverage (%)",
378401
],
379-
value=[
380-
row.concepts,
381-
row.patients,
382-
row.records,
383-
domain_coverage,
384-
],
402+
value=[row.concepts, row.patients, row.records, domain_coverage],
385403
interpretation=[
386404
"Number of concepts analyzed in $(row.domain) domain",
387405
"Patients with $(row.domain) concepts",

0 commit comments

Comments
 (0)