11use lightgbm_sys;
22
3- use libc:: { c_char, c_int , c_float , c_double, c_long , c_void } ;
3+ use libc:: { c_char, c_double, c_void , c_long } ;
44use std:: ffi:: CString ;
5- use std:: convert:: TryInto ;
65use std;
76
87use super :: { LGBMResult , Dataset } ;
@@ -18,49 +17,49 @@ impl Booster {
1817 }
1918
2019 pub fn train ( dataset : Dataset ) -> LGBMResult < Self > {
20+ let params = CString :: new ( "objective=binary metric=auc" ) . unwrap ( ) ;
2121 let mut handle = std:: ptr:: null_mut ( ) ;
22- let mut params = CString :: new ( "app=binary metric=auc num_leaves=31" ) . unwrap ( ) ;
2322 unsafe {
2423 lightgbm_sys:: LGBM_BoosterCreate (
2524 dataset. handle ,
2625 params. as_ptr ( ) as * const c_char ,
27- & mut handle) ;
26+ & mut handle
27+ ) ;
2828 }
2929
3030 // train
3131 let mut is_finished: i32 = 0 ;
3232 unsafe {
33- for n in 1 ..50 {
34- let ret = lightgbm_sys:: LGBM_BoosterUpdateOneIter ( handle, & mut is_finished) ;
33+ for _ in 1 ..100 {
34+ lightgbm_sys:: LGBM_BoosterUpdateOneIter ( handle, & mut is_finished) ;
3535 }
3636 }
3737 Ok ( Booster :: new ( handle) ?)
3838 }
3939
40- pub fn predict ( & self , data : Vec < Vec < f32 > > ) -> LGBMResult < Vec < f64 > > {
41- let data_length = data. len ( ) as i32 ;
42- let feature_length = data[ 0 ] . len ( ) as i32 ;
43- let mut params = CString :: new ( "" ) . unwrap ( ) ;
44- let mut out_len: c_long = 0 ;
45- // let mut out_result = Vec::with_capacity(data_length.try_into().unwrap());
46- let data_size = data_length. try_into ( ) . unwrap ( ) ;
47- let mut out_result: Vec < f64 > = vec ! [ Default :: default ( ) ; data_size] ;
40+ pub fn predict ( & self , data : Vec < Vec < f64 > > ) -> LGBMResult < Vec < f64 > > {
41+ let data_length = data. len ( ) ;
42+ let feature_length = data[ 0 ] . len ( ) ;
43+ let params = CString :: new ( "" ) . unwrap ( ) ;
44+ let mut out_length: c_long = 0 ;
45+ let out_result: Vec < f64 > = vec ! [ Default :: default ( ) ; data. len( ) ] ;
46+ let flat_data = data. into_iter ( ) . flatten ( ) . collect :: < Vec < _ > > ( ) ;
4847
4948 unsafe {
5049 lightgbm_sys:: LGBM_BoosterPredictForMat (
5150 self . handle ,
52- data . as_ptr ( ) as * mut c_void ,
53- lightgbm_sys:: C_API_DTYPE_FLOAT32 . try_into ( ) . unwrap ( ) ,
54- data_length,
55- feature_length,
56- 0 ,
57- 0 ,
58- 0 ,
59- 0 ,
51+ flat_data . as_ptr ( ) as * const c_void ,
52+ lightgbm_sys:: C_API_DTYPE_FLOAT64 as i32 ,
53+ data_length as i32 ,
54+ feature_length as i32 ,
55+ 1 as i32 ,
56+ 0 as i32 ,
57+ 0 as i32 ,
58+ - 1 as i32 ,
6059 params. as_ptr ( ) as * const c_char ,
61- & mut out_len ,
60+ & mut out_length ,
6261 out_result. as_ptr ( ) as * mut c_double
63- ) ;
62+ ) ;
6463 }
6564 Ok ( out_result)
6665 }
0 commit comments