@@ -589,16 +589,84 @@ pub fn to_arrow_type(fluss_type: &DataType) -> ArrowDataType {
589589 DataType :: Double ( _) => ArrowDataType :: Float64 ,
590590 DataType :: Char ( _) => ArrowDataType :: Utf8 ,
591591 DataType :: String ( _) => ArrowDataType :: Utf8 ,
592- DataType :: Decimal ( _) => todo ! ( ) ,
592+ DataType :: Decimal ( decimal_type) => ArrowDataType :: Decimal128 (
593+ decimal_type
594+ . precision ( )
595+ . try_into ( )
596+ . expect ( "precision exceeds u8::MAX" ) ,
597+ decimal_type
598+ . scale ( )
599+ . try_into ( )
600+ . expect ( "scale exceeds i8::MAX" ) ,
601+ ) ,
593602 DataType :: Date ( _) => ArrowDataType :: Date32 ,
594- DataType :: Time ( _) => todo ! ( ) ,
595- DataType :: Timestamp ( _) => todo ! ( ) ,
596- DataType :: TimestampLTz ( _) => todo ! ( ) ,
597- DataType :: Bytes ( _) => todo ! ( ) ,
598- DataType :: Binary ( _) => todo ! ( ) ,
599- DataType :: Array ( _data_type) => todo ! ( ) ,
600- DataType :: Map ( _data_type) => todo ! ( ) ,
601- DataType :: Row ( _data_fields) => todo ! ( ) ,
603+ DataType :: Time ( time_type) => match time_type. precision ( ) {
604+ 0 => ArrowDataType :: Time32 ( arrow_schema:: TimeUnit :: Second ) ,
605+ 1 ..=3 => ArrowDataType :: Time32 ( arrow_schema:: TimeUnit :: Millisecond ) ,
606+ 4 ..=6 => ArrowDataType :: Time64 ( arrow_schema:: TimeUnit :: Microsecond ) ,
607+ 7 ..=9 => ArrowDataType :: Time64 ( arrow_schema:: TimeUnit :: Nanosecond ) ,
608+ // This arm should never be reached due to validation in TimeType.
609+ invalid => panic ! ( "Invalid precision value for TimeType: {invalid}" ) ,
610+ } ,
611+ DataType :: Timestamp ( timestamp_type) => match timestamp_type. precision ( ) {
612+ 0 => ArrowDataType :: Timestamp ( arrow_schema:: TimeUnit :: Second , None ) ,
613+ 1 ..=3 => ArrowDataType :: Timestamp ( arrow_schema:: TimeUnit :: Millisecond , None ) ,
614+ 4 ..=6 => ArrowDataType :: Timestamp ( arrow_schema:: TimeUnit :: Microsecond , None ) ,
615+ 7 ..=9 => ArrowDataType :: Timestamp ( arrow_schema:: TimeUnit :: Nanosecond , None ) ,
616+ // This arm should never be reached due to validation in Timestamp.
617+ invalid => panic ! ( "Invalid precision value for TimestampType: {invalid}" ) ,
618+ } ,
619+ DataType :: TimestampLTz ( timestamp_ltz_type) => match timestamp_ltz_type. precision ( ) {
620+ 0 => ArrowDataType :: Timestamp ( arrow_schema:: TimeUnit :: Second , None ) ,
621+ 1 ..=3 => ArrowDataType :: Timestamp ( arrow_schema:: TimeUnit :: Millisecond , None ) ,
622+ 4 ..=6 => ArrowDataType :: Timestamp ( arrow_schema:: TimeUnit :: Microsecond , None ) ,
623+ 7 ..=9 => ArrowDataType :: Timestamp ( arrow_schema:: TimeUnit :: Nanosecond , None ) ,
624+ // This arm should never be reached due to validation in TimestampLTz.
625+ invalid => panic ! ( "Invalid precision value for TimestampLTzType: {invalid}" ) ,
626+ } ,
627+ DataType :: Bytes ( _) => ArrowDataType :: Binary ,
628+ DataType :: Binary ( binary_type) => ArrowDataType :: FixedSizeBinary (
629+ binary_type
630+ . length ( )
631+ . try_into ( )
632+ . expect ( "length exceeds i32::MAX" ) ,
633+ ) ,
634+ DataType :: Array ( array_type) => ArrowDataType :: List (
635+ Field :: new_list_field (
636+ to_arrow_type ( array_type. get_element_type ( ) ) ,
637+ fluss_type. is_nullable ( ) ,
638+ )
639+ . into ( ) ,
640+ ) ,
641+ DataType :: Map ( map_type) => {
642+ let key_type = to_arrow_type ( map_type. key_type ( ) ) ;
643+ let value_type = to_arrow_type ( map_type. value_type ( ) ) ;
644+ let entry_fields = vec ! [
645+ Field :: new( "key" , key_type, map_type. key_type( ) . is_nullable( ) ) ,
646+ Field :: new( "value" , value_type, map_type. value_type( ) . is_nullable( ) ) ,
647+ ] ;
648+ ArrowDataType :: Map (
649+ Arc :: new ( Field :: new (
650+ "entries" ,
651+ ArrowDataType :: Struct ( arrow_schema:: Fields :: from ( entry_fields) ) ,
652+ fluss_type. is_nullable ( ) ,
653+ ) ) ,
654+ false ,
655+ )
656+ }
657+ DataType :: Row ( row_type) => ArrowDataType :: Struct ( arrow_schema:: Fields :: from (
658+ row_type
659+ . fields ( )
660+ . iter ( )
661+ . map ( |f| {
662+ Field :: new (
663+ f. name ( ) ,
664+ to_arrow_type ( f. data_type ( ) ) ,
665+ f. data_type ( ) . is_nullable ( ) ,
666+ )
667+ } )
668+ . collect :: < Vec < Field > > ( ) ,
669+ ) ) ,
602670 }
603671}
604672
@@ -820,3 +888,129 @@ impl ArrowReader {
820888 }
821889}
/// Generic tuple struct wrapping a single `StreamReader<T>`; the wrapped reader is
/// exposed as the public field `.0`.
822890pub struct MyVec < T > ( pub StreamReader < T > ) ;
891+
#[cfg(test)]
mod tests {
    //! Unit tests for `to_arrow_type`.
    //!
    //! Each test covers one family of Fluss types so that a failure points at the
    //! mapping that regressed instead of aborting a single catch-all test.

    use super::*;
    use crate::metadata::DataTypes;

    /// Fixed-width numeric, boolean, character and date types map onto a single
    /// parameterless Arrow type.
    #[test]
    fn test_to_arrow_type_scalars() {
        assert_eq!(to_arrow_type(&DataTypes::boolean()), ArrowDataType::Boolean);
        assert_eq!(to_arrow_type(&DataTypes::tinyint()), ArrowDataType::Int8);
        assert_eq!(to_arrow_type(&DataTypes::smallint()), ArrowDataType::Int16);
        assert_eq!(to_arrow_type(&DataTypes::bigint()), ArrowDataType::Int64);
        assert_eq!(to_arrow_type(&DataTypes::int()), ArrowDataType::Int32);
        assert_eq!(to_arrow_type(&DataTypes::float()), ArrowDataType::Float32);
        assert_eq!(to_arrow_type(&DataTypes::double()), ArrowDataType::Float64);
        // Both fixed-length CHAR and variable-length STRING become Utf8.
        assert_eq!(to_arrow_type(&DataTypes::char(16)), ArrowDataType::Utf8);
        assert_eq!(to_arrow_type(&DataTypes::string()), ArrowDataType::Utf8);
        assert_eq!(to_arrow_type(&DataTypes::date()), ArrowDataType::Date32);
    }

    /// Decimal precision and scale are carried over to `Decimal128`.
    #[test]
    fn test_to_arrow_type_decimal() {
        assert_eq!(
            to_arrow_type(&DataTypes::decimal(10, 2)),
            ArrowDataType::Decimal128(10, 2)
        );
    }

    /// TIME precision selects both the Arrow time width and its unit.
    #[test]
    fn test_to_arrow_type_time() {
        // `DataTypes::time()` is expected to yield second resolution
        // (presumably a default precision of 0 -- confirm against `DataTypes`).
        assert_eq!(
            to_arrow_type(&DataTypes::time()),
            ArrowDataType::Time32(arrow_schema::TimeUnit::Second)
        );
        assert_eq!(
            to_arrow_type(&DataTypes::time_with_precision(3)),
            ArrowDataType::Time32(arrow_schema::TimeUnit::Millisecond)
        );
        assert_eq!(
            to_arrow_type(&DataTypes::time_with_precision(6)),
            ArrowDataType::Time64(arrow_schema::TimeUnit::Microsecond)
        );
        assert_eq!(
            to_arrow_type(&DataTypes::time_with_precision(9)),
            ArrowDataType::Time64(arrow_schema::TimeUnit::Nanosecond)
        );
    }

    /// TIMESTAMP precision selects the Arrow time unit; no timezone is attached.
    #[test]
    fn test_to_arrow_type_timestamp() {
        assert_eq!(
            to_arrow_type(&DataTypes::timestamp_with_precision(0)),
            ArrowDataType::Timestamp(arrow_schema::TimeUnit::Second, None)
        );
        assert_eq!(
            to_arrow_type(&DataTypes::timestamp_with_precision(3)),
            ArrowDataType::Timestamp(arrow_schema::TimeUnit::Millisecond, None)
        );
        assert_eq!(
            to_arrow_type(&DataTypes::timestamp_with_precision(6)),
            ArrowDataType::Timestamp(arrow_schema::TimeUnit::Microsecond, None)
        );
        assert_eq!(
            to_arrow_type(&DataTypes::timestamp_with_precision(9)),
            ArrowDataType::Timestamp(arrow_schema::TimeUnit::Nanosecond, None)
        );
    }

    /// TIMESTAMP_LTZ is asserted to produce the same Arrow type as TIMESTAMP,
    /// i.e. a timestamp whose timezone component is `None`.
    #[test]
    fn test_to_arrow_type_timestamp_ltz() {
        assert_eq!(
            to_arrow_type(&DataTypes::timestamp_ltz_with_precision(0)),
            ArrowDataType::Timestamp(arrow_schema::TimeUnit::Second, None)
        );
        assert_eq!(
            to_arrow_type(&DataTypes::timestamp_ltz_with_precision(3)),
            ArrowDataType::Timestamp(arrow_schema::TimeUnit::Millisecond, None)
        );
        assert_eq!(
            to_arrow_type(&DataTypes::timestamp_ltz_with_precision(6)),
            ArrowDataType::Timestamp(arrow_schema::TimeUnit::Microsecond, None)
        );
        assert_eq!(
            to_arrow_type(&DataTypes::timestamp_ltz_with_precision(9)),
            ArrowDataType::Timestamp(arrow_schema::TimeUnit::Nanosecond, None)
        );
    }

    /// Variable-length BYTES becomes `Binary`; fixed-length BINARY(n) becomes
    /// `FixedSizeBinary(n)`.
    #[test]
    fn test_to_arrow_type_binary() {
        assert_eq!(to_arrow_type(&DataTypes::bytes()), ArrowDataType::Binary);
        assert_eq!(
            to_arrow_type(&DataTypes::binary(16)),
            ArrowDataType::FixedSizeBinary(16)
        );
    }

    /// ARRAY becomes an Arrow `List` whose item field is expected to be nullable
    /// for the factory-built types used here.
    #[test]
    fn test_to_arrow_type_array() {
        assert_eq!(
            to_arrow_type(&DataTypes::array(DataTypes::int())),
            ArrowDataType::List(Field::new_list_field(ArrowDataType::Int32, true).into())
        );
    }

    /// MAP becomes an unsorted Arrow `Map` over an "entries" struct with "key"
    /// and "value" fields.
    #[test]
    fn test_to_arrow_type_map() {
        assert_eq!(
            to_arrow_type(&DataTypes::map(DataTypes::string(), DataTypes::int())),
            ArrowDataType::Map(
                Arc::new(Field::new(
                    "entries",
                    ArrowDataType::Struct(arrow_schema::Fields::from(vec![
                        Field::new("key", ArrowDataType::Utf8, true),
                        Field::new("value", ArrowDataType::Int32, true),
                    ])),
                    true,
                )),
                // Second `Map` argument: keys are not declared sorted.
                false,
            )
        );
    }

    /// ROW becomes an Arrow `Struct` that keeps field names and order.
    #[test]
    fn test_to_arrow_type_row() {
        assert_eq!(
            to_arrow_type(&DataTypes::row(vec![
                DataTypes::field("f1".to_string(), DataTypes::int()),
                DataTypes::field("f2".to_string(), DataTypes::string()),
            ])),
            ArrowDataType::Struct(arrow_schema::Fields::from(vec![
                Field::new("f1", ArrowDataType::Int32, true),
                Field::new("f2", ArrowDataType::Utf8, true),
            ]))
        );
    }

    /// A TIME precision of 10 must panic with the documented message.
    #[test]
    #[should_panic(expected = "Invalid precision value for TimeType: 10")]
    fn test_time_invalid_precision() {
        to_arrow_type(&DataTypes::time_with_precision(10));
    }

    /// A TIMESTAMP precision of 10 must panic with the documented message.
    #[test]
    #[should_panic(expected = "Invalid precision value for TimestampType: 10")]
    fn test_timestamp_invalid_precision() {
        to_arrow_type(&DataTypes::timestamp_with_precision(10));
    }

    /// A TIMESTAMP_LTZ precision of 10 must panic with the documented message.
    #[test]
    #[should_panic(expected = "Invalid precision value for TimestampLTzType: 10")]
    fn test_timestamp_ltz_invalid_precision() {
        to_arrow_type(&DataTypes::timestamp_ltz_with_precision(10));
    }
}
0 commit comments