@@ -13,13 +13,27 @@ use super::{LGBMResult, Dataset, LGBMError};
1313/// Core model in LightGBM, containing functions for training, evaluating and predicting.
1414pub struct Booster {
1515 pub ( super ) handle : lightgbm_sys:: BoosterHandle ,
16- num_class : i64
1716}
1817
1918
2019impl Booster {
21- fn new ( handle : lightgbm_sys:: BoosterHandle , num_class : i64 ) -> LGBMResult < Self > {
22- Ok ( Booster { handle, num_class} )
20+ fn new ( handle : lightgbm_sys:: BoosterHandle ) -> LGBMResult < Self > {
21+ Ok ( Booster { handle} )
22+ }
23+
24+ /// Init from model file.
25+ pub fn from_file ( filename : String ) -> LGBMResult < Self > {
26+ let filename_str = CString :: new ( filename) . unwrap ( ) ;
27+ let mut out_num_iterations = 0 ;
28+ let mut handle = std:: ptr:: null_mut ( ) ;
29+ lgbm_call ! (
30+ lightgbm_sys:: LGBM_BoosterCreateFromModelfile (
31+ filename_str. as_ptr( ) as * const c_char,
32+ & mut out_num_iterations,
33+ & mut handle
34+ )
35+ ) . unwrap ( ) ;
36+ Ok ( Booster :: new ( handle) ?)
2337 }
2438
2539 /// Create a new Booster model with given Dataset and parameters.
@@ -56,14 +70,6 @@ impl Booster {
5670 num_iterations = parameter[ "num_iterations" ] . as_i64 ( ) . unwrap ( ) ;
5771 }
5872
59- // get num_class
60- let num_class: i64 ;
61- if parameter[ "num_class" ] . is_null ( ) {
62- num_class = 1 ;
63- } else {
64- num_class = parameter[ "num_class" ] . as_i64 ( ) . unwrap ( ) ;
65- }
66-
6773 // exchange params {"x": "y", "z": 1} => "x=y z=1"
6874 let params_string = parameter. as_object ( ) . unwrap ( ) . iter ( ) . map ( |( k, v) | format ! ( "{}={}" , k, v) ) . collect :: < Vec < _ > > ( ) . join ( " " ) ;
6975 let params_cstring = CString :: new ( params_string) . unwrap ( ) ;
@@ -81,7 +87,7 @@ impl Booster {
8187 for _ in 1 ..num_iterations {
8288 lgbm_call ! ( lightgbm_sys:: LGBM_BoosterUpdateOneIter ( handle, & mut is_finished) ) ?;
8389 }
84- Ok ( Booster :: new ( handle, num_class ) ?)
90+ Ok ( Booster :: new ( handle) ?)
8591 }
8692
8793 /// Predict results for given data.
@@ -102,9 +108,19 @@ impl Booster {
102108 let feature_length = data[ 0 ] . len ( ) ;
103109 let params = CString :: new ( "" ) . unwrap ( ) ;
104110 let mut out_length: c_long = 0 ;
105- let out_result: Vec < f64 > = vec ! [ Default :: default ( ) ; data. len( ) * self . num_class as usize ] ;
106111 let flat_data = data. into_iter ( ) . flatten ( ) . collect :: < Vec < _ > > ( ) ;
107112
113+ // get num_class
114+ let mut num_class = 0 ;
115+ lgbm_call ! (
116+ lightgbm_sys:: LGBM_BoosterGetNumClasses (
117+ self . handle,
118+ & mut num_class
119+ )
120+ ) ?;
121+
122+ let out_result: Vec < f64 > = vec ! [ Default :: default ( ) ; data_length * num_class as usize ] ;
123+
108124 lgbm_call ! (
109125 lightgbm_sys:: LGBM_BoosterPredictForMat (
110126 self . handle,
@@ -124,13 +140,28 @@ impl Booster {
124140
125141 // reshape for multiclass [1,2,3,4,5,6] -> [[1,2,3], [4,5,6]] # 3 class
126142 let reshaped_output;
127- if self . num_class > 1 {
128- reshaped_output = out_result. chunks ( self . num_class as usize ) . map ( |x| x. to_vec ( ) ) . collect ( ) ;
143+ if num_class > 1 {
144+ reshaped_output = out_result. chunks ( num_class as usize ) . map ( |x| x. to_vec ( ) ) . collect ( ) ;
129145 } else {
130146 reshaped_output = vec ! [ out_result] ;
131147 }
132148 Ok ( reshaped_output)
133149 }
150+
151+
152+ /// Save model to file.
153+ pub fn save_file ( & self , filename : String ) {
154+ let filename_str = CString :: new ( filename) . unwrap ( ) ;
155+ lgbm_call ! (
156+ lightgbm_sys:: LGBM_BoosterSaveModel (
157+ self . handle,
158+ 0 as i32 ,
159+ -1 as i32 ,
160+ 0 as i32 ,
161+ filename_str. as_ptr( ) as * const c_char
162+ )
163+ ) . unwrap ( ) ;
164+ }
134165}
135166
136167
@@ -145,6 +176,9 @@ impl Drop for Booster {
145176mod tests {
146177 use super :: * ;
147178 use serde_json:: json;
179+ use std:: path:: Path ;
180+ use std:: fs;
181+
148182 fn read_train_file ( ) -> LGBMResult < Dataset > {
149183 Dataset :: from_file ( "lightgbm-sys/lightgbm/examples/binary_classification/binary.train" . to_string ( ) )
150184 }
@@ -173,4 +207,26 @@ mod tests {
173207 }
174208 assert_eq ! ( normalized_result, vec![ 0 , 0 , 1 ] ) ;
175209 }
210+
211+ #[ test]
212+ fn save_file ( ) {
213+ let dataset = read_train_file ( ) . unwrap ( ) ;
214+ let params = json ! {
215+ {
216+ "num_iterations" : 1 ,
217+ "objective" : "binary" ,
218+ "metric" : "auc" ,
219+ "data_random_seed" : 0
220+ }
221+ } ;
222+ let bst = Booster :: train ( dataset, & params) . unwrap ( ) ;
223+ bst. save_file ( "./test/test_save_file.output" . to_string ( ) ) ;
224+ assert ! ( Path :: new( "./test/test_save_file.output" ) . exists( ) ) ;
225+ fs:: remove_file ( "./test/test_save_file.output" ) ;
226+ }
227+
228+ #[ test]
229+ fn from_file ( ) {
230+ let bst = Booster :: from_file ( "./test/test_from_file.input" . to_string ( ) ) ;
231+ }
176232}
0 commit comments