Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions ibis/backends/sql/compilers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.backends.sql.rewrites import (
FirstValue,
LastValue,
add_one_to_nth_value_input,
add_order_by_to_empty_ranking_window_functions,
empty_in_values_right_side,
Expand Down Expand Up @@ -337,13 +335,11 @@ class SQLGlotCompiler(abc.ABC):
ops.Degrees: "degrees",
ops.DenseRank: "dense_rank",
ops.Exp: "exp",
FirstValue: "first_value",
ops.GroupConcat: "group_concat",
ops.IfElse: "if",
ops.IsInf: "isinf",
ops.IsNan: "isnan",
ops.JSONGetItem: "json_extract",
LastValue: "last_value",
ops.Levenshtein: "levenshtein",
ops.Ln: "ln",
ops.Log10: "log",
Expand Down Expand Up @@ -1244,6 +1240,12 @@ def visit_RowID(self, op, *, table):
op.name, table=table.alias_or_name, quoted=self.quoted, copy=False
)

def visit_FirstValue(self, op, *, arg, include_null):
    """Compile ``FirstValue``, honoring ``include_null``.

    Emit an explicit RESPECT NULLS / IGNORE NULLS modifier in both cases.
    Handling ``include_null`` here is opt-out rather than opt-in: a backend
    that cannot express either modifier will fail loudly at compile time and
    must override this method, instead of silently producing wrong results
    by dropping the modifier.
    """
    if include_null:
        return sge.RespectNulls(this=self.f.first_value(arg))
    return sge.IgnoreNulls(this=self.f.first_value(arg))

def visit_LastValue(self, op, *, arg, include_null):
    """Compile ``LastValue``, honoring ``include_null``.

    Mirrors ``visit_FirstValue``: the null-handling modifier is always
    emitted explicitly so non-supporting backends error out rather than
    returning incorrect results.
    """
    if include_null:
        return sge.RespectNulls(this=self.f.last_value(arg))
    return sge.IgnoreNulls(this=self.f.last_value(arg))

# TODO(kszucs): this should be renamed to something UDF related
def __sql_name__(self, op: ops.ScalarUDF | ops.AggUDF) -> str:
# for builtin functions use the exact function name, otherwise use the
Expand Down
14 changes: 14 additions & 0 deletions ibis/backends/sql/compilers/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,5 +712,19 @@ def visit_TableUnnest(
def visit_StringToTime(self, op, *, arg, format_str):
return self.cast(self.f.str_to_time(arg, format_str), to=dt.time)

def visit_LastValue(self, op, *, arg, include_null):
    """Compile ``LastValue`` for DuckDB.

    DuckDB's bare ``last_value`` includes NULLs, so the IGNORE NULLS
    modifier is attached only when NULLs must be skipped.
    """
    expr = self.f.last_value(arg)
    if include_null:
        return expr
    return sge.IgnoreNulls(this=expr)

def visit_FirstValue(self, op, *, arg, include_null):
    """Compile ``FirstValue`` for DuckDB.

    DuckDB's bare ``first_value`` includes NULLs, so the IGNORE NULLS
    modifier is attached only when NULLs must be skipped.
    """
    expr = self.f.first_value(arg)
    if include_null:
        return expr
    return sge.IgnoreNulls(this=expr)


compiler = DuckDBCompiler()
4 changes: 2 additions & 2 deletions ibis/backends/sql/compilers/pyspark.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,10 +243,10 @@ def visit_CountDistinctStar(self, op, *, arg, where):
]
return self.f.count(sge.Distinct(expressions=cols))

def visit_FirstValue(self, op, *, arg):
def visit_FirstValue(self, op, *, arg, include_null):
    """Compile ``FirstValue`` for PySpark, honoring ``include_null``.

    Bug fix: the previous body ignored ``include_null`` and always emitted
    IGNORE NULLS, returning incorrect results for
    ``first(include_null=True)``. Spark's ``first`` includes NULLs by
    default, so the modifier is added only when NULLs must be skipped.
    """
    out = self.f.first(arg)
    return out if include_null else sge.IgnoreNulls(this=out)

def visit_LastValue(self, op, *, arg):
def visit_LastValue(self, op, *, arg, include_null):
    """Compile ``LastValue`` for PySpark, honoring ``include_null``.

    Bug fix: the previous body ignored ``include_null`` and always emitted
    IGNORE NULLS, returning incorrect results for
    ``last(include_null=True)``. Spark's ``last`` includes NULLs by
    default, so the modifier is added only when NULLs must be skipped.
    """
    out = self.f.last(arg)
    return out if include_null else sge.IgnoreNulls(this=out)

def visit_First(self, op, *, arg, where, order_by, include_null):
Expand Down
4 changes: 3 additions & 1 deletion ibis/backends/sql/rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ class FirstValue(ops.Analytic):
"""Retrieve the first element."""

arg: ops.Column[dt.Any]
include_null: bool = False

@attribute
def dtype(self):
Expand All @@ -84,6 +85,7 @@ class LastValue(ops.Analytic):
"""Retrieve the last element."""

arg: ops.Column[dt.Any]
include_null: bool = False

@attribute
def dtype(self):
Expand Down Expand Up @@ -204,7 +206,7 @@ def first_to_firstvalue(_, **kwargs):
"in a window function"
)
klass = FirstValue if isinstance(_.func, ops.First) else LastValue
return _.copy(func=klass(_.func.arg))
return _.copy(func=klass(_.func.arg, include_null=_.func.include_null))


@replace(p.Alias)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ WITH "t5" AS (
"t2"."field_of_study",
"t2"."years",
"t2"."degrees",
FIRST_VALUE("t2"."degrees") OVER (
FIRST_VALUE("t2"."degrees" IGNORE NULLS) OVER (
PARTITION BY "t2"."field_of_study"
ORDER BY "t2"."years" ASC
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
) AS "earliest_degrees",
LAST_VALUE("t2"."degrees") OVER (
LAST_VALUE("t2"."degrees" IGNORE NULLS) OVER (
PARTITION BY "t2"."field_of_study"
ORDER BY "t2"."years" ASC
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
Expand Down
59 changes: 59 additions & 0 deletions ibis/backends/tests/test_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -1109,6 +1109,65 @@ def test_first_last(backend):
backend.assert_frame_equal(result, expected)


# NOTE(review): these backends currently produce wrong values (they appear to
# always ignore NULLs) rather than erroring -- hence AssertionError; confirm
# whether per-backend fixes are planned.
@pytest.mark.notimpl(
    [
        "risingwave",
        "clickhouse",
        "bigquery",
        "oracle",
        "snowflake",
        "databricks",
        "pyspark",
    ],
    raises=AssertionError,
)
@pytest.mark.notyet(
    ["polars"],
    raises=com.OperationNotDefinedError,
)
@pytest.mark.notyet(
    ["flink"],
    raises=NotImplementedError,
)
# Backends that reject the generated RESPECT NULLS / IGNORE NULLS SQL outright;
# the raised exception type varies per driver, so catch the broad Exception.
@pytest.mark.notyet(
    [
        "mysql",
        "sqlite",
        "postgres",
        "datafusion",
        "druid",
        "athena",
        "impala",
        "mssql",
        "trino",
        "exasol",
    ],
    raises=Exception,
)
def test_first_last_include_nulls(backend):
    """first/last with include_null=True must keep NULLs; False must skip them."""
    # Two groups (a=1, a=2), each with one NULL in `b`; `c` fixes the order
    # within each group so first/last are deterministic.
    t = ibis.memtable({"a": (2, 2, 1, 1), "b": (None, 3, 5, None), "c": list(range(4))})
    w = ibis.window(group_by=t.a, order_by=t.c)
    expr = t.select(
        "a",
        b_first_null=t.b.first(include_null=True).over(w),
        b_last_null=t.b.last(include_null=True).over(w),
        b_first=t.b.first(include_null=False).over(w),
        b_last=t.b.last(include_null=False).over(w),
    )
    con = backend.connection
    # execute the expr, and ensure the columns are sorted by column "a"
    result = con.execute(expr).sort_values("a").set_index("a").reset_index(drop=True)
    expected = pd.DataFrame(
        {
            # group a=1 has b=(5, None); group a=2 has b=(None, 3)
            "b_first_null": [5, 5, None, None],
            "b_last_null": [None, None, 3, 3],
            "b_first": [5, 5, 3, 3],
            "b_last": [5, 5, 3, 3],
        }
    )
    # check_dtype=False: nullable columns may come back as float or object
    # depending on the backend driver
    backend.assert_frame_equal(result, expected, check_dtype=False)


@pytest.mark.notyet(
["bigquery"], raises=GoogleBadRequest, reason="not supported by BigQuery"
)
Expand Down