
Commit 5b65761

Add a test case for the new TypedRow encoder
Implemented the proposal.
1 parent f58ccd8 commit 5b65761

7 files changed: +218 −51 lines

dataset/src/main/scala/frameless/RecordEncoder.scala

Lines changed: 89 additions & 30 deletions
@@ -143,20 +143,19 @@ object DropUnitValues {
   }
 }
 
-class RecordEncoder[F, G <: HList, H <: HList](
+abstract class RecordEncoder[F, G <: HList, H <: HList](
     implicit
-    i0: LabelledGeneric.Aux[F, G],
-    i1: DropUnitValues.Aux[G, H],
-    i2: IsHCons[H],
-    fields: Lazy[RecordEncoderFields[H]],
-    newInstanceExprs: Lazy[NewInstanceExprs[G]],
+    stage1: RecordEncoderStage1[G, H],
     classTag: ClassTag[F])
     extends TypedEncoder[F] {
+
+  import stage1._
+
   def nullable: Boolean = false
 
-  def jvmRepr: DataType = FramelessInternals.objectTypeFor[F]
+  lazy val jvmRepr: DataType = FramelessInternals.objectTypeFor[F]
 
-  def catalystRepr: DataType = {
+  lazy val catalystRepr: DataType = {
     val structFields = fields.value.value.map { field =>
       StructField(
         name = field.name,
@@ -169,39 +168,99 @@ class RecordEncoder[F, G <: HList, H <: HList](
     StructType(structFields)
   }
 
-  def toCatalyst(path: Expression): Expression = {
-    val nameExprs = fields.value.value.map { field => Literal(field.name) }
+}
 
-    val valueExprs = fields.value.value.map { field =>
-      val fieldPath = Invoke(path, field.name, field.encoder.jvmRepr, Nil)
-      field.encoder.toCatalyst(fieldPath)
-    }
+object RecordEncoder {
+
+  case class ForGeneric[F, G <: HList, H <: HList](
+    )(implicit
+      stage1: RecordEncoderStage1[G, H],
+      classTag: ClassTag[F])
+      extends RecordEncoder[F, G, H] {
+
+    import stage1._
+
+    def toCatalyst(path: Expression): Expression = {
+
+      val valueExprs = fields.value.value.map { field =>
+        val fieldPath = Invoke(path, field.name, field.encoder.jvmRepr, Nil)
+        field.encoder.toCatalyst(fieldPath)
+      }
+
+      val createExpr = stage1.cellsToCatalyst(valueExprs)
 
-    // the way exprs are encoded in CreateNamedStruct
-    val exprs = nameExprs.zip(valueExprs).flatMap {
-      case (nameExpr, valueExpr) => nameExpr :: valueExpr :: Nil
+      val nullExpr = Literal.create(null, createExpr.dataType)
+
+      If(IsNull(path), nullExpr, createExpr)
     }
 
-    val createExpr = CreateNamedStruct(exprs)
-    val nullExpr = Literal.create(null, createExpr.dataType)
+    def fromCatalyst(path: Expression): Expression = {
+
+      val newArgs = stage1.fromCatalystToCells(path)
+
+      val newExpr =
+        NewInstance(
+          classTag.runtimeClass,
+          newArgs,
+          jvmRepr,
+          propagateNull = true
+        )
+
+      val nullExpr = Literal.create(null, jvmRepr)
 
-    If(IsNull(path), nullExpr, createExpr)
+      If(IsNull(path), nullExpr, newExpr)
+    }
   }
 
-  def fromCatalyst(path: Expression): Expression = {
-    val exprs = fields.value.value.map { field =>
-      field.encoder.fromCatalyst(
-        GetStructField(path, field.ordinal, Some(field.name))
-      )
+  case class ForTypedRow[G <: HList, H <: HList](
+    )(implicit
+      stage1: RecordEncoderStage1[G, H],
+      classTag: ClassTag[TypedRow[G]])
+      extends RecordEncoder[TypedRow[G], G, H] {
+
+    import stage1._
+
+    private final val _apply = "apply"
+    private final val _fromInternalRow = "fromInternalRow"
+
+    def toCatalyst(path: Expression): Expression = {
+
+      val valueExprs = fields.value.value.zipWithIndex.map {
+        case (field, i) =>
+          val fieldPath = Invoke(
+            path,
+            _apply,
+            field.encoder.jvmRepr,
+            Seq(Literal.create(i, IntegerType))
+          )
+          field.encoder.toCatalyst(fieldPath)
+      }
+
+      val createExpr = stage1.cellsToCatalyst(valueExprs)
+
+      val nullExpr = Literal.create(null, createExpr.dataType)
+
+      If(IsNull(path), nullExpr, createExpr)
     }
 
-    val newArgs = newInstanceExprs.value.from(exprs)
-    val newExpr =
-      NewInstance(classTag.runtimeClass, newArgs, jvmRepr, propagateNull = true)
+    def fromCatalyst(path: Expression): Expression = {
 
-    val nullExpr = Literal.create(null, jvmRepr)
+      val newArgs = stage1.fromCatalystToCells(path)
+      val aggregated = CreateStruct(newArgs)
 
-    If(IsNull(path), nullExpr, newExpr)
+      val partial = TypedRow.WithCatalystTypes(newArgs.map(_.dataType))
+
+      val newExpr = Invoke(
+        Literal.fromObject(partial),
+        _fromInternalRow,
+        TypedRow.catalystType,
+        Seq(aggregated)
+      )
+
+      val nullExpr = Literal.create(null, jvmRepr)
+
+      If(IsNull(path), nullExpr, newExpr)
    }
  }
}
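Both concrete encoders delegate struct construction to RecordEncoderStage1 (introduced in the next file) and differ only in how each cell is read from the JVM object: ForGeneric invokes the case-class accessor by name, while ForTypedRow invokes TypedRow.apply with the cell's position. A minimal sketch of the two access patterns; the path stand-in, the accessor name "name", and the String field type are hypothetical, not part of this commit:

import org.apache.spark.sql.catalyst.expressions.{ BoundReference, Literal }
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.types.{ IntegerType, ObjectType }

object CellAccessSketch {

  // Stand-in for the `path` expression an encoder receives (illustrative only).
  val path = BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true)

  // Stand-in for field.encoder.jvmRepr of a String-typed field.
  val fieldJvmRepr = ObjectType(classOf[String])

  // ForGeneric: read a cell through the named accessor of the case class.
  val byName = Invoke(path, "name", fieldJvmRepr, Nil)

  // ForTypedRow: read cell 0 positionally through TypedRow.apply(i).
  val byIndex = Invoke(path, "apply", fieldJvmRepr, Seq(Literal.create(0, IntegerType)))
}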

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+package frameless
+
+import org.apache.spark.sql.catalyst.expressions.{
+  CreateNamedStruct,
+  Expression,
+  GetStructField,
+  Literal
+}
+import shapeless.{ HList, Lazy }
+
+case class RecordEncoderStage1[G <: HList, H <: HList](
+  )(implicit
+    // i1: DropUnitValues.Aux[G, H],
+    // i2: IsHCons[H],
+    val fields: Lazy[RecordEncoderFields[H]],
+    val newInstanceExprs: Lazy[NewInstanceExprs[G]]) {
+
+  def cellsToCatalyst(valueExprs: Seq[Expression]): Expression = {
+    val nameExprs = fields.value.value.map { field => Literal(field.name) }
+
+    // the way exprs are encoded in CreateNamedStruct
+    val exprs = nameExprs.zip(valueExprs).flatMap {
+      case (nameExpr, valueExpr) => nameExpr :: valueExpr :: Nil
+    }
+
+    val createExpr = CreateNamedStruct(exprs)
+    createExpr
+  }
+
+  def fromCatalystToCells(path: Expression): Seq[Expression] = {
+    val exprs = fields.value.value.map { field =>
+      field.encoder.fromCatalyst(
+        GetStructField(path, field.ordinal, Some(field.name))
+      )
+    }
+
+    val newArgs = newInstanceExprs.value.from(exprs)
+    newArgs
+  }
+}
+
+object RecordEncoderStage1 {
+
+  implicit def usingDerivation[G <: HList, H <: HList](
+      implicit
+      i3: Lazy[RecordEncoderFields[H]],
+      i4: Lazy[NewInstanceExprs[G]]
+    ): RecordEncoderStage1[G, H] = RecordEncoderStage1[G, H]()
+}
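Both encoders hand their cell expressions to cellsToCatalyst above, which hides one Catalyst detail: CreateNamedStruct expects its children interleaved as name, value, name, value, and so on. A small self-contained illustration of that layout; the field names and literal values are made up:

import org.apache.spark.sql.catalyst.expressions.{ CreateNamedStruct, Literal }

object CreateNamedStructSketch {

  // Children are interleaved as name1, value1, name2, value2, ...
  val struct = CreateNamedStruct(Seq(
    Literal("x"), Literal(1),
    Literal("y"), Literal("abc")
  ))

  // struct.dataType is a StructType with fields x: Int and y: String,
  // which is what RecordEncoder.catalystRepr describes for a record.
}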

dataset/src/main/scala/frameless/TypedEncoder.scala

Lines changed: 10 additions & 2 deletions
@@ -727,15 +727,23 @@ object TypedEncoder {
   }
 
   /** Encodes things as records if there is no Injection defined */
-  implicit def usingDerivation[F, G <: HList, H <: HList](
+  implicit def deriveForGeneric[F, G <: HList, H <: HList](
      implicit
      i0: LabelledGeneric.Aux[F, G],
      i1: DropUnitValues.Aux[G, H],
      i2: IsHCons[H],
      i3: Lazy[RecordEncoderFields[H]],
      i4: Lazy[NewInstanceExprs[G]],
      i5: ClassTag[F]
-    ): TypedEncoder[F] = new RecordEncoder[F, G, H]
+    ): TypedEncoder[F] = RecordEncoder.ForGeneric[F, G, H]()
+
+  implicit def deriveForTypedRow[G <: HList, H <: HList](
+      implicit
+      i1: DropUnitValues.Aux[G, H],
+      i2: IsHCons[H],
+      i3: Lazy[RecordEncoderFields[H]],
+      i4: Lazy[NewInstanceExprs[G]]
+    ): TypedEncoder[TypedRow[G]] = RecordEncoder.ForTypedRow[G, H]()
 
   /** Encodes things using a Spark SQL's User Defined Type (UDT) if there is one defined in implicit */
   implicit def usingUserDefinedType[
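With deriveForTypedRow in implicit scope, an encoder for a TypedRow over a shapeless record resolves the same way deriveForGeneric resolves for case classes. A minimal resolution sketch, mirroring the RR record alias added in RecordEncoderTests; the object and value names here are illustrative:

import frameless.{ TypedEncoder, TypedRow }
import shapeless.record.Record

object TypedRowEncoderResolution {

  // Record type mirroring the RR alias used in the new test.
  val R = Record.`'x -> Int, 'y -> String`
  type R = R.T

  // Should now be summoned through deriveForTypedRow.
  val enc: TypedEncoder[TypedRow[R]] = implicitly[TypedEncoder[TypedRow[R]]]
}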
Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
+package frameless
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.types.{ DataType, ObjectType }
+import shapeless.HList
+
+case class TypedRow[T <: HList](row: Row) {
+
+  def apply(i: Int): Any = row.apply(i)
+}
+
+object TypedRow {
+
+  def apply(values: Any*): TypedRow[HList] = {
+
+    val row = Row.fromSeq(values)
+    TypedRow(row)
+  }
+
+  case class WithCatalystTypes(schema: Seq[DataType]) {
+
+    def fromInternalRow(row: InternalRow): TypedRow[HList] = {
+      val data = row.toSeq(schema).toArray
+
+      apply(data: _*)
+    }
+
+  }
+
+  object WithCatalystTypes {}
+
+  def fromHList[T <: HList](
+      hlist: T
+    ): TypedRow[T] = {
+
+    val cells = hlist.runtimeList
+
+    val row = Row.fromSeq(cells)
+    TypedRow(row)
+  }
+
+  lazy val catalystType: ObjectType = ObjectType(classOf[TypedRow[_]])
+
+}
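TypedRow keeps the record type T at the type level while storing the cells in an untyped Spark Row, and cells are read back positionally. A short usage sketch along the lines of the new test case; the record shape and its values are illustrative:

import frameless.TypedRow
import shapeless.record.Record

object TypedRowUsageSketch {

  val RR = Record.`'x -> Int, 'y -> String`
  type RR = RR.T

  // Build a record and wrap it; the static type remembers field names and types.
  val r: RR = Record(x = 1, y = "abc")
  val typed: TypedRow[RR] = TypedRow.fromHList(r)

  // Cells are read positionally, which is also how ForTypedRow.toCatalyst
  // accesses them via Invoke(path, "apply", ...).
  val x: Any = typed(0) // 1
  val y: Any = typed(1) // "abc"
}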

dataset/src/test/scala/frameless/InjectionTests.scala

Lines changed: 1 addition & 1 deletion
@@ -202,7 +202,7 @@ class InjectionTests extends TypedDatasetSuite {
   }
 
   test("Resolve ambiguity by importing usingDerivation") {
-    import TypedEncoder.usingDerivation
+    import TypedEncoder.deriveForGeneric
     assert(
       implicitly[TypedEncoder[Person]].isInstanceOf[RecordEncoder[Person, _, _]]
     )

dataset/src/test/scala/frameless/RecordEncoderTests.scala

Lines changed: 22 additions & 16 deletions
@@ -1,23 +1,12 @@
 package frameless
 
+import frameless.RecordEncoderTests.{ A, B, E }
+import org.apache.spark.sql.types._
 import org.apache.spark.sql.{ Row, functions => F }
-import org.apache.spark.sql.types.{
-  ArrayType,
-  BinaryType,
-  DecimalType,
-  IntegerType,
-  LongType,
-  MapType,
-  ObjectType,
-  StringType,
-  StructField,
-  StructType
-}
-
-import shapeless.{ HList, LabelledGeneric }
-import shapeless.test.illTyped
-
 import org.scalatest.matchers.should.Matchers
+import shapeless.record.Record
+import shapeless.test.illTyped
+import shapeless.{ HList, LabelledGeneric }
 
 final class RecordEncoderTests extends TypedDatasetSuite with Matchers {
   test("Unable to encode products made from units only") {
@@ -101,6 +90,20 @@ final class RecordEncoderTests extends TypedDatasetSuite with Matchers {
     ds.collect.head shouldBe obj
   }
 
+  test("shapeless Record") {
+
+    val r1: RecordEncoderTests.RR = Record(x = 1, y = "abc")
+    val r2: TypedRow[RecordEncoderTests.RR] = TypedRow.fromHList(r1)
+
+    val rdd = sc.parallelize(Seq(r2))
+    val ds =
+      session.createDataset(rdd)(
+        TypedExpressionEncoder[TypedRow[RecordEncoderTests.RR]]
+      )
+
+    ds.collect.head shouldBe r2
+  }
+
   test("Scalar value class") {
     import RecordEncoderTests._
 
@@ -632,6 +635,9 @@ object RecordEncoderTests {
   case class D(m: Map[String, Int])
   case class E(b: Set[B])
 
+  val RR = Record.`'x -> Int, 'y -> String`
+  type RR = RR.T
+
   final class Subject(val name: String) extends AnyVal with Serializable
 
   final class Grade(val value: BigDecimal) extends AnyVal with Serializable

refined/src/test/scala/frameless/RefinedFieldEncoderTests.scala

Lines changed: 2 additions & 2 deletions
@@ -127,7 +127,7 @@ object RefinedTypesTests {
 
   import frameless.refined._ // implicit instances for refined
 
-  implicit val encoderA: TypedEncoder[A] = TypedEncoder.usingDerivation
+  implicit val encoderA: TypedEncoder[A] = TypedEncoder.deriveForGeneric
 
-  implicit val encoderB: TypedEncoder[B] = TypedEncoder.usingDerivation
+  implicit val encoderB: TypedEncoder[B] = TypedEncoder.deriveForGeneric
 }
