Skip to content

Commit 3037192

Browse files
authored
chore: Implement datatype conversion for all types in arrow.rs (#81)
1 parent c644023 commit 3037192

File tree

2 files changed

+223
-29
lines changed

2 files changed

+223
-29
lines changed

crates/fluss/src/metadata/datatype.rs

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -96,25 +96,25 @@ impl DataType {
9696
/// Renders a `DataType` by writing the payload of whichever variant is active.
///
/// Every arm forwards the wrapped value straight into the formatter, so the
/// text produced here is whatever that inner value's `Display` output is.
impl Display for DataType {
9797
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
9898
match self {
// The removed (`-`) and added (`+`) arms below differ only in how the format
// argument is spelled: positional `"{}", v` versus the inlined `"{v}"`.
99-
DataType::Boolean(v) => write!(f, "{}", v),
100-
DataType::TinyInt(v) => write!(f, "{}", v),
101-
DataType::SmallInt(v) => write!(f, "{}", v),
102-
DataType::Int(v) => write!(f, "{}", v),
103-
DataType::BigInt(v) => write!(f, "{}", v),
104-
DataType::Float(v) => write!(f, "{}", v),
105-
DataType::Double(v) => write!(f, "{}", v),
106-
DataType::Char(v) => write!(f, "{}", v),
107-
DataType::String(v) => write!(f, "{}", v),
108-
DataType::Decimal(v) => write!(f, "{}", v),
109-
DataType::Date(v) => write!(f, "{}", v),
110-
DataType::Time(v) => write!(f, "{}", v),
111-
DataType::Timestamp(v) => write!(f, "{}", v),
112-
DataType::TimestampLTz(v) => write!(f, "{}", v),
113-
DataType::Bytes(v) => write!(f, "{}", v),
114-
DataType::Binary(v) => write!(f, "{}", v),
115-
DataType::Array(v) => write!(f, "{}", v),
116-
DataType::Map(v) => write!(f, "{}", v),
117-
DataType::Row(v) => write!(f, "{}", v),
99+
DataType::Boolean(v) => write!(f, "{v}"),
100+
DataType::TinyInt(v) => write!(f, "{v}"),
101+
DataType::SmallInt(v) => write!(f, "{v}"),
102+
DataType::Int(v) => write!(f, "{v}"),
103+
DataType::BigInt(v) => write!(f, "{v}"),
104+
DataType::Float(v) => write!(f, "{v}"),
105+
DataType::Double(v) => write!(f, "{v}"),
106+
DataType::Char(v) => write!(f, "{v}"),
107+
DataType::String(v) => write!(f, "{v}"),
108+
DataType::Decimal(v) => write!(f, "{v}"),
109+
DataType::Date(v) => write!(f, "{v}"),
110+
DataType::Time(v) => write!(f, "{v}"),
111+
DataType::Timestamp(v) => write!(f, "{v}"),
112+
DataType::TimestampLTz(v) => write!(f, "{v}"),
113+
DataType::Bytes(v) => write!(f, "{v}"),
114+
DataType::Binary(v) => write!(f, "{v}"),
115+
DataType::Array(v) => write!(f, "{v}"),
116+
DataType::Map(v) => write!(f, "{v}"),
117+
DataType::Row(v) => write!(f, "{v}"),
118118
}
119119
}
120120
}
@@ -861,7 +861,7 @@ impl Display for RowType {
861861
if i > 0 {
862862
write!(f, ", ")?;
863863
}
864-
write!(f, "{}", field)?;
864+
write!(f, "{field}")?;
865865
}
866866
write!(f, ">")?;
867867
if !self.nullable {

crates/fluss/src/record/arrow.rs

Lines changed: 203 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -589,16 +589,84 @@ pub fn to_arrow_type(fluss_type: &DataType) -> ArrowDataType {
589589
DataType::Double(_) => ArrowDataType::Float64,
590590
DataType::Char(_) => ArrowDataType::Utf8,
591591
DataType::String(_) => ArrowDataType::Utf8,
592-
DataType::Decimal(_) => todo!(),
592+
DataType::Decimal(decimal_type) => ArrowDataType::Decimal128(
593+
decimal_type
594+
.precision()
595+
.try_into()
596+
.expect("precision exceeds u8::MAX"),
597+
decimal_type
598+
.scale()
599+
.try_into()
600+
.expect("scale exceeds i8::MAX"),
601+
),
593602
DataType::Date(_) => ArrowDataType::Date32,
594-
DataType::Time(_) => todo!(),
595-
DataType::Timestamp(_) => todo!(),
596-
DataType::TimestampLTz(_) => todo!(),
597-
DataType::Bytes(_) => todo!(),
598-
DataType::Binary(_) => todo!(),
599-
DataType::Array(_data_type) => todo!(),
600-
DataType::Map(_data_type) => todo!(),
601-
DataType::Row(_data_fields) => todo!(),
603+
DataType::Time(time_type) => match time_type.precision() {
604+
0 => ArrowDataType::Time32(arrow_schema::TimeUnit::Second),
605+
1..=3 => ArrowDataType::Time32(arrow_schema::TimeUnit::Millisecond),
606+
4..=6 => ArrowDataType::Time64(arrow_schema::TimeUnit::Microsecond),
607+
7..=9 => ArrowDataType::Time64(arrow_schema::TimeUnit::Nanosecond),
608+
// This arm should never be reached due to validation in TimeType.
609+
invalid => panic!("Invalid precision value for TimeType: {invalid}"),
610+
},
611+
DataType::Timestamp(timestamp_type) => match timestamp_type.precision() {
612+
0 => ArrowDataType::Timestamp(arrow_schema::TimeUnit::Second, None),
613+
1..=3 => ArrowDataType::Timestamp(arrow_schema::TimeUnit::Millisecond, None),
614+
4..=6 => ArrowDataType::Timestamp(arrow_schema::TimeUnit::Microsecond, None),
615+
7..=9 => ArrowDataType::Timestamp(arrow_schema::TimeUnit::Nanosecond, None),
616+
// This arm should never be reached due to validation in Timestamp.
617+
invalid => panic!("Invalid precision value for TimestampType: {invalid}"),
618+
},
619+
DataType::TimestampLTz(timestamp_ltz_type) => match timestamp_ltz_type.precision() {
620+
0 => ArrowDataType::Timestamp(arrow_schema::TimeUnit::Second, None),
621+
1..=3 => ArrowDataType::Timestamp(arrow_schema::TimeUnit::Millisecond, None),
622+
4..=6 => ArrowDataType::Timestamp(arrow_schema::TimeUnit::Microsecond, None),
623+
7..=9 => ArrowDataType::Timestamp(arrow_schema::TimeUnit::Nanosecond, None),
624+
// This arm should never be reached due to validation in TimestampLTz.
625+
invalid => panic!("Invalid precision value for TimestampLTzType: {invalid}"),
626+
},
627+
DataType::Bytes(_) => ArrowDataType::Binary,
628+
DataType::Binary(binary_type) => ArrowDataType::FixedSizeBinary(
629+
binary_type
630+
.length()
631+
.try_into()
632+
.expect("length exceeds i32::MAX"),
633+
),
634+
DataType::Array(array_type) => ArrowDataType::List(
635+
Field::new_list_field(
636+
to_arrow_type(array_type.get_element_type()),
637+
fluss_type.is_nullable(),
638+
)
639+
.into(),
640+
),
641+
DataType::Map(map_type) => {
642+
let key_type = to_arrow_type(map_type.key_type());
643+
let value_type = to_arrow_type(map_type.value_type());
644+
let entry_fields = vec![
645+
Field::new("key", key_type, map_type.key_type().is_nullable()),
646+
Field::new("value", value_type, map_type.value_type().is_nullable()),
647+
];
648+
ArrowDataType::Map(
649+
Arc::new(Field::new(
650+
"entries",
651+
ArrowDataType::Struct(arrow_schema::Fields::from(entry_fields)),
652+
fluss_type.is_nullable(),
653+
)),
654+
false,
655+
)
656+
}
657+
DataType::Row(row_type) => ArrowDataType::Struct(arrow_schema::Fields::from(
658+
row_type
659+
.fields()
660+
.iter()
661+
.map(|f| {
662+
Field::new(
663+
f.name(),
664+
to_arrow_type(f.data_type()),
665+
f.data_type().is_nullable(),
666+
)
667+
})
668+
.collect::<Vec<Field>>(),
669+
)),
602670
}
603671
}
604672

@@ -820,3 +888,129 @@ impl ArrowReader {
820888
}
821889
}
822890
/// Tuple-struct wrapper whose single public field is a `StreamReader<T>`.
pub struct MyVec<T>(pub StreamReader<T>);
891+
892+
#[cfg(test)]
mod tests {
    //! Unit tests for `to_arrow_type`, grouped by type family so that a failing
    //! mapping in one family does not hide the results of the others.

    use super::*;
    use crate::metadata::DataTypes;

    /// Boolean, integer, floating-point, character, decimal and date types.
    #[test]
    fn test_to_arrow_type_scalars() {
        assert_eq!(to_arrow_type(&DataTypes::boolean()), ArrowDataType::Boolean);
        assert_eq!(to_arrow_type(&DataTypes::tinyint()), ArrowDataType::Int8);
        assert_eq!(to_arrow_type(&DataTypes::smallint()), ArrowDataType::Int16);
        assert_eq!(to_arrow_type(&DataTypes::bigint()), ArrowDataType::Int64);
        assert_eq!(to_arrow_type(&DataTypes::int()), ArrowDataType::Int32);
        assert_eq!(to_arrow_type(&DataTypes::float()), ArrowDataType::Float32);
        assert_eq!(to_arrow_type(&DataTypes::double()), ArrowDataType::Float64);
        assert_eq!(to_arrow_type(&DataTypes::char(16)), ArrowDataType::Utf8);
        assert_eq!(to_arrow_type(&DataTypes::string()), ArrowDataType::Utf8);
        assert_eq!(
            to_arrow_type(&DataTypes::decimal(10, 2)),
            ArrowDataType::Decimal128(10, 2)
        );
        assert_eq!(to_arrow_type(&DataTypes::date()), ArrowDataType::Date32);
    }

    /// Time and timestamp types: the precision selects the Arrow time unit.
    #[test]
    fn test_to_arrow_type_temporal() {
        assert_eq!(
            to_arrow_type(&DataTypes::time()),
            ArrowDataType::Time32(arrow_schema::TimeUnit::Second)
        );
        assert_eq!(
            to_arrow_type(&DataTypes::time_with_precision(3)),
            ArrowDataType::Time32(arrow_schema::TimeUnit::Millisecond)
        );
        assert_eq!(
            to_arrow_type(&DataTypes::time_with_precision(6)),
            ArrowDataType::Time64(arrow_schema::TimeUnit::Microsecond)
        );
        assert_eq!(
            to_arrow_type(&DataTypes::time_with_precision(9)),
            ArrowDataType::Time64(arrow_schema::TimeUnit::Nanosecond)
        );

        assert_eq!(
            to_arrow_type(&DataTypes::timestamp_with_precision(0)),
            ArrowDataType::Timestamp(arrow_schema::TimeUnit::Second, None)
        );
        assert_eq!(
            to_arrow_type(&DataTypes::timestamp_with_precision(3)),
            ArrowDataType::Timestamp(arrow_schema::TimeUnit::Millisecond, None)
        );
        assert_eq!(
            to_arrow_type(&DataTypes::timestamp_with_precision(6)),
            ArrowDataType::Timestamp(arrow_schema::TimeUnit::Microsecond, None)
        );
        assert_eq!(
            to_arrow_type(&DataTypes::timestamp_with_precision(9)),
            ArrowDataType::Timestamp(arrow_schema::TimeUnit::Nanosecond, None)
        );

        assert_eq!(
            to_arrow_type(&DataTypes::timestamp_ltz_with_precision(0)),
            ArrowDataType::Timestamp(arrow_schema::TimeUnit::Second, None)
        );
        assert_eq!(
            to_arrow_type(&DataTypes::timestamp_ltz_with_precision(3)),
            ArrowDataType::Timestamp(arrow_schema::TimeUnit::Millisecond, None)
        );
        assert_eq!(
            to_arrow_type(&DataTypes::timestamp_ltz_with_precision(6)),
            ArrowDataType::Timestamp(arrow_schema::TimeUnit::Microsecond, None)
        );
        assert_eq!(
            to_arrow_type(&DataTypes::timestamp_ltz_with_precision(9)),
            ArrowDataType::Timestamp(arrow_schema::TimeUnit::Nanosecond, None)
        );
    }

    /// Variable-length and fixed-length byte types.
    #[test]
    fn test_to_arrow_type_binary() {
        assert_eq!(to_arrow_type(&DataTypes::bytes()), ArrowDataType::Binary);
        assert_eq!(
            to_arrow_type(&DataTypes::binary(16)),
            ArrowDataType::FixedSizeBinary(16)
        );
    }

    /// Array, map and row types, which recurse into their element types.
    ///
    /// NOTE(review): every expected nullability flag below is `true`, so these
    /// assertions cannot tell which flag ends up on a nested field.
    /// `to_arrow_type` takes the list item flag and the map "entries" flag from
    /// the container type (`fluss_type.is_nullable()`), not from the element
    /// type, and the Arrow format describes the map "entries" and "key" fields
    /// as non-nullable. Confirm the intended behaviour and add non-nullable
    /// cases once it is settled.
    #[test]
    fn test_to_arrow_type_nested() {
        assert_eq!(
            to_arrow_type(&DataTypes::array(DataTypes::int())),
            ArrowDataType::List(Field::new_list_field(ArrowDataType::Int32, true).into())
        );

        assert_eq!(
            to_arrow_type(&DataTypes::map(DataTypes::string(), DataTypes::int())),
            ArrowDataType::Map(
                Arc::new(Field::new(
                    "entries",
                    ArrowDataType::Struct(arrow_schema::Fields::from(vec![
                        Field::new("key", ArrowDataType::Utf8, true),
                        Field::new("value", ArrowDataType::Int32, true),
                    ])),
                    true,
                )),
                false,
            )
        );

        assert_eq!(
            to_arrow_type(&DataTypes::row(vec![
                DataTypes::field("f1".to_string(), DataTypes::int()),
                DataTypes::field("f2".to_string(), DataTypes::string()),
            ])),
            ArrowDataType::Struct(arrow_schema::Fields::from(vec![
                Field::new("f1", ArrowDataType::Int32, true),
                Field::new("f2", ArrowDataType::Utf8, true),
            ]))
        );
    }

    /// A time precision outside the handled ranges hits the panicking arm.
    #[test]
    #[should_panic(expected = "Invalid precision value for TimeType: 10")]
    fn test_time_invalid_precision() {
        to_arrow_type(&DataTypes::time_with_precision(10));
    }

    /// A timestamp precision outside the handled ranges hits the panicking arm.
    #[test]
    #[should_panic(expected = "Invalid precision value for TimestampType: 10")]
    fn test_timestamp_invalid_precision() {
        to_arrow_type(&DataTypes::timestamp_with_precision(10));
    }

    /// A local-zoned timestamp precision outside the handled ranges hits the
    /// panicking arm.
    #[test]
    #[should_panic(expected = "Invalid precision value for TimestampLTzType: 10")]
    fn test_timestamp_ltz_invalid_precision() {
        to_arrow_type(&DataTypes::timestamp_ltz_with_precision(10));
    }
}

0 commit comments

Comments
 (0)