@@ -6,7 +6,7 @@ use serde_json::Value;
66
77use lightgbm_sys;
88
9- use super :: { Dataset , LGBMError , LGBMResult } ;
9+ use super :: { Dataset , Error , Result } ;
1010
1111/// Core model in LightGBM, containing functions for training, evaluating and predicting.
1212pub struct Booster {
@@ -19,7 +19,7 @@ impl Booster {
1919 }
2020
2121 /// Init from model file.
22- pub fn from_file ( filename : String ) -> LGBMResult < Self > {
22+ pub fn from_file ( filename : String ) -> Result < Self > {
2323 let filename_str = CString :: new ( filename) . unwrap ( ) ;
2424 let mut out_num_iterations = 0 ;
2525 let mut handle = std:: ptr:: null_mut ( ) ;
@@ -56,7 +56,7 @@ impl Booster {
5656 /// };
5757 /// let bst = Booster::train(dataset, ¶ms).unwrap();
5858 /// ```
59- pub fn train ( dataset : Dataset , parameter : & Value ) -> LGBMResult < Self > {
59+ pub fn train ( dataset : Dataset , parameter : & Value ) -> Result < Self > {
6060 // get num_iterations
6161 let num_iterations: i64 = if parameter[ "num_iterations" ] . is_null ( ) {
6262 100
@@ -104,7 +104,7 @@ impl Booster {
104104 /// ```
105105 /// let output = vec![vec![1.0, 0.109, 0.433]];
106106 /// ```
107- pub fn predict ( & self , data : Vec < Vec < f64 > > ) -> LGBMResult < Vec < Vec < f64 > > > {
107+ pub fn predict ( & self , data : Vec < Vec < f64 > > ) -> Result < Vec < Vec < f64 > > > {
108108 let data_length = data. len ( ) ;
109109 let feature_length = data[ 0 ] . len ( ) ;
110110 let params = CString :: new ( "" ) . unwrap ( ) ;
@@ -148,16 +148,16 @@ impl Booster {
148148 }
149149
150150 /// Save model to file.
151- pub fn save_file ( & self , filename : String ) {
151+ pub fn save_file ( & self , filename : String ) -> Result < ( ) > {
152152 let filename_str = CString :: new ( filename) . unwrap ( ) ;
153153 lgbm_call ! ( lightgbm_sys:: LGBM_BoosterSaveModel (
154154 self . handle,
155155 0_i32 ,
156156 -1_i32 ,
157157 0_i32 ,
158158 filename_str. as_ptr( ) as * const c_char
159- ) )
160- . unwrap ( ) ;
159+ ) ) ? ;
160+ Ok ( ( ) )
161161 }
162162}
163163
@@ -174,7 +174,7 @@ mod tests {
174174 use std:: fs;
175175 use std:: path:: Path ;
176176
177- fn read_train_file ( ) -> LGBMResult < Dataset > {
177+ fn read_train_file ( ) -> Result < Dataset > {
178178 Dataset :: from_file (
179179 "lightgbm-sys/lightgbm/examples/binary_classification/binary.train" . to_string ( ) ,
180180 )
@@ -213,7 +213,10 @@ mod tests {
213213 }
214214 } ;
215215 let bst = Booster :: train ( dataset, & params) . unwrap ( ) ;
216- bst. save_file ( "./test/test_save_file.output" . to_string ( ) ) ;
216+ assert_eq ! (
217+ bst. save_file( "./test/test_save_file.output" . to_string( ) ) ,
218+ Ok ( ( ) )
219+ ) ;
217220 assert ! ( Path :: new( "./test/test_save_file.output" ) . exists( ) ) ;
218221 let _ = fs:: remove_file ( "./test/test_save_file.output" ) ;
219222 }
0 commit comments