Skip to content

Commit bf0dea7

Browse files
authored
Merge pull request #13 from vaaaaanquish/save_file
Save file/Load file
2 parents ae7df90 + dbb0e0f commit bf0dea7

File tree

4 files changed

+200
-16
lines changed

4 files changed

+200
-16
lines changed

.github/workflows/test.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ jobs:
3232
- name: Build for ubuntu
3333
if: matrix.os == 'ubuntu-latest'
3434
run: |
35+
sudo apt-get update
3536
sudo apt-get install -y cmake libclang-dev libc++-dev gcc-multilib
3637
cargo build
3738
- name: Run tests

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "lightgbm"
3-
version = "0.1.2"
3+
version = "0.1.3"
44
authors = ["vaaaaanquish <[email protected]>"]
55
license = "MIT"
66
repository = "https://github.com/vaaaaanquish/LightGBM"

src/booster.rs

Lines changed: 71 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,27 @@ use super::{LGBMResult, Dataset, LGBMError};
1313
/// Core model in LightGBM, containing functions for training, evaluating and predicting.
1414
pub struct Booster {
1515
pub(super) handle: lightgbm_sys::BoosterHandle,
16-
num_class: i64
1716
}
1817

1918

2019
impl Booster {
21-
fn new(handle: lightgbm_sys::BoosterHandle, num_class: i64) -> LGBMResult<Self> {
22-
Ok(Booster{handle, num_class})
20+
fn new(handle: lightgbm_sys::BoosterHandle) -> LGBMResult<Self> {
21+
Ok(Booster{handle})
22+
}
23+
24+
/// Init from model file.
25+
pub fn from_file(filename: String) -> LGBMResult<Self>{
26+
let filename_str = CString::new(filename).unwrap();
27+
let mut out_num_iterations = 0;
28+
let mut handle = std::ptr::null_mut();
29+
lgbm_call!(
30+
lightgbm_sys::LGBM_BoosterCreateFromModelfile(
31+
filename_str.as_ptr() as *const c_char,
32+
&mut out_num_iterations,
33+
&mut handle
34+
)
35+
).unwrap();
36+
Ok(Booster::new(handle)?)
2337
}
2438

2539
/// Create a new Booster model with given Dataset and parameters.
@@ -56,14 +70,6 @@ impl Booster {
5670
num_iterations = parameter["num_iterations"].as_i64().unwrap();
5771
}
5872

59-
// get num_class
60-
let num_class: i64;
61-
if parameter["num_class"].is_null(){
62-
num_class = 1;
63-
} else {
64-
num_class = parameter["num_class"].as_i64().unwrap();
65-
}
66-
6773
// exchange params {"x": "y", "z": 1} => "x=y z=1"
6874
let params_string = parameter.as_object().unwrap().iter().map(|(k, v)| format!("{}={}", k, v)).collect::<Vec<_>>().join(" ");
6975
let params_cstring = CString::new(params_string).unwrap();
@@ -81,7 +87,7 @@ impl Booster {
8187
for _ in 1..num_iterations {
8288
lgbm_call!(lightgbm_sys::LGBM_BoosterUpdateOneIter(handle, &mut is_finished))?;
8389
}
84-
Ok(Booster::new(handle, num_class)?)
90+
Ok(Booster::new(handle)?)
8591
}
8692

8793
/// Predict results for given data.
@@ -102,9 +108,19 @@ impl Booster {
102108
let feature_length = data[0].len();
103109
let params = CString::new("").unwrap();
104110
let mut out_length: c_long = 0;
105-
let out_result: Vec<f64> = vec![Default::default(); data.len() * self.num_class as usize];
106111
let flat_data = data.into_iter().flatten().collect::<Vec<_>>();
107112

113+
// get num_class
114+
let mut num_class = 0;
115+
lgbm_call!(
116+
lightgbm_sys::LGBM_BoosterGetNumClasses(
117+
self.handle,
118+
&mut num_class
119+
)
120+
)?;
121+
122+
let out_result: Vec<f64> = vec![Default::default(); data_length * num_class as usize];
123+
108124
lgbm_call!(
109125
lightgbm_sys::LGBM_BoosterPredictForMat(
110126
self.handle,
@@ -124,13 +140,28 @@ impl Booster {
124140

125141
// reshape for multiclass [1,2,3,4,5,6] -> [[1,2,3], [4,5,6]] # 3 class
126142
let reshaped_output;
127-
if self.num_class > 1{
128-
reshaped_output = out_result.chunks(self.num_class as usize).map(|x| x.to_vec()).collect();
143+
if num_class > 1{
144+
reshaped_output = out_result.chunks(num_class as usize).map(|x| x.to_vec()).collect();
129145
} else {
130146
reshaped_output = vec![out_result];
131147
}
132148
Ok(reshaped_output)
133149
}
150+
151+
152+
/// Save model to file.
153+
pub fn save_file(&self, filename: String){
154+
let filename_str = CString::new(filename).unwrap();
155+
lgbm_call!(
156+
lightgbm_sys::LGBM_BoosterSaveModel(
157+
self.handle,
158+
0 as i32,
159+
-1 as i32,
160+
0 as i32,
161+
filename_str.as_ptr() as *const c_char
162+
)
163+
).unwrap();
164+
}
134165
}
135166

136167

@@ -145,6 +176,9 @@ impl Drop for Booster {
145176
mod tests {
146177
use super::*;
147178
use serde_json::json;
179+
use std::path::Path;
180+
use std::fs;
181+
148182
fn read_train_file() -> LGBMResult<Dataset> {
149183
Dataset::from_file("lightgbm-sys/lightgbm/examples/binary_classification/binary.train".to_string())
150184
}
@@ -173,4 +207,26 @@ mod tests {
173207
}
174208
assert_eq!(normalized_result, vec![0, 0, 1]);
175209
}
210+
211+
/// Training a one-iteration model and saving it must create the output file.
#[test]
fn save_file() {
    const OUTPUT_PATH: &str = "./test/test_save_file.output";

    // Remove any leftover from an earlier run so the existence check below
    // really proves that `save_file` wrote the file. A missing file is fine.
    let _ = fs::remove_file(OUTPUT_PATH);

    let dataset = read_train_file().unwrap();
    let params = json!{
        {
            "num_iterations": 1,
            "objective": "binary",
            "metric": "auc",
            "data_random_seed": 0
        }
    };
    let bst = Booster::train(dataset, &params).unwrap();
    bst.save_file(OUTPUT_PATH.to_string());
    assert!(Path::new(OUTPUT_PATH).exists());

    // Clean up, and fail loudly if the file cannot be removed instead of
    // silently dropping the `io::Result`.
    fs::remove_file(OUTPUT_PATH).unwrap();
}
227+
228+
/// Loading the checked-in model file must succeed.
#[test]
fn from_file() {
    // Unwrap the result so a failed load fails the test; the booster itself
    // is not used further.
    let _bst = Booster::from_file("./test/test_from_file.input".to_string()).unwrap();
}
176232
}

test/test_from_file.input

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
tree
2+
version=v3
3+
num_class=1
4+
num_tree_per_iteration=1
5+
label_index=0
6+
max_feature_idx=27
7+
objective=binary sigmoid:1
8+
feature_names=Column_0 Column_1 Column_2 Column_3 Column_4 Column_5 Column_6 Column_7 Column_8 Column_9 Column_10 Column_11 Column_12 Column_13 Column_14 Column_15 Column_16 Column_17 Column_18 Column_19 Column_20 Column_21 Column_22 Column_23 Column_24 Column_25 Column_26 Column_27
9+
feature_infos=[0.27500000000000002:6.6950000000000003] [-2.4169999999999998:2.4300000000000002] [-1.7429999999999999:1.7429999999999999] [0.019:5.7000000000000002] [-1.7429999999999999:1.7429999999999999] [0.159:4.1900000000000004] [-2.9409999999999998:2.9699999999999998] [-1.7410000000000001:1.7410000000000001] [0:2.173] [0.19:5.1929999999999996] [-2.9039999999999999:2.9089999999999998] [-1.742:1.7429999999999999] [0:2.2149999999999999] [0.26400000000000001:6.5229999999999997] [-2.7279999999999998:2.7269999999999999] [-1.742:1.742] [0:2.548] [0.36499999999999999:6.0679999999999996] [-2.4950000000000001:2.496] [-1.74:1.7429999999999999] [0:3.1019999999999999] [0.17199999999999999:13.098000000000001] [0.41899999999999998:7.3920000000000003] [0.46100000000000002:3.6819999999999999] [0.38400000000000001:6.5830000000000002] [0.092999999999999999:7.8600000000000003] [0.38900000000000001:4.5430000000000001] [0.48899999999999999:4.3159999999999998]
10+
tree_sizes=
11+
12+
end of trees
13+
14+
feature_importances:
15+
16+
parameters:
17+
[boosting: gbdt]
18+
[objective: binary]
19+
[metric: auc]
20+
[tree_learner: serial]
21+
[device_type: cpu]
22+
[linear_tree: 0]
23+
[data: ]
24+
[valid: ]
25+
[num_iterations: 1]
26+
[learning_rate: 0.1]
27+
[num_leaves: 31]
28+
[num_threads: 0]
29+
[deterministic: 0]
30+
[force_col_wise: 0]
31+
[force_row_wise: 0]
32+
[histogram_pool_size: -1]
33+
[max_depth: -1]
34+
[min_data_in_leaf: 20]
35+
[min_sum_hessian_in_leaf: 0.001]
36+
[bagging_fraction: 1]
37+
[pos_bagging_fraction: 1]
38+
[neg_bagging_fraction: 1]
39+
[bagging_freq: 0]
40+
[bagging_seed: 3]
41+
[feature_fraction: 1]
42+
[feature_fraction_bynode: 1]
43+
[feature_fraction_seed: 2]
44+
[extra_trees: 0]
45+
[extra_seed: 6]
46+
[early_stopping_round: 0]
47+
[first_metric_only: 0]
48+
[max_delta_step: 0]
49+
[lambda_l1: 0]
50+
[lambda_l2: 0]
51+
[linear_lambda: 0]
52+
[min_gain_to_split: 0]
53+
[drop_rate: 0.1]
54+
[max_drop: 50]
55+
[skip_drop: 0.5]
56+
[xgboost_dart_mode: 0]
57+
[uniform_drop: 0]
58+
[drop_seed: 4]
59+
[top_rate: 0.2]
60+
[other_rate: 0.1]
61+
[min_data_per_group: 100]
62+
[max_cat_threshold: 32]
63+
[cat_l2: 10]
64+
[cat_smooth: 10]
65+
[max_cat_to_onehot: 4]
66+
[top_k: 20]
67+
[monotone_constraints: ]
68+
[monotone_constraints_method: basic]
69+
[monotone_penalty: 0]
70+
[feature_contri: ]
71+
[forcedsplits_filename: ]
72+
[refit_decay_rate: 0.9]
73+
[cegb_tradeoff: 1]
74+
[cegb_penalty_split: 0]
75+
[cegb_penalty_feature_lazy: ]
76+
[cegb_penalty_feature_coupled: ]
77+
[path_smooth: 0]
78+
[interaction_constraints: ]
79+
[verbosity: 1]
80+
[saved_feature_importance_type: 0]
81+
[max_bin: 255]
82+
[max_bin_by_feature: ]
83+
[min_data_in_bin: 3]
84+
[bin_construct_sample_cnt: 200000]
85+
[data_random_seed: 0]
86+
[is_enable_sparse: 1]
87+
[enable_bundle: 1]
88+
[use_missing: 1]
89+
[zero_as_missing: 0]
90+
[feature_pre_filter: 1]
91+
[pre_partition: 0]
92+
[two_round: 0]
93+
[header: 0]
94+
[label_column: ]
95+
[weight_column: ]
96+
[group_column: ]
97+
[ignore_column: ]
98+
[categorical_feature: ]
99+
[forcedbins_filename: ]
100+
[objective_seed: 5]
101+
[num_class: 1]
102+
[is_unbalance: 0]
103+
[scale_pos_weight: 1]
104+
[sigmoid: 1]
105+
[boost_from_average: 1]
106+
[reg_sqrt: 0]
107+
[alpha: 0.9]
108+
[fair_c: 1]
109+
[poisson_max_delta_step: 0.7]
110+
[tweedie_variance_power: 1.5]
111+
[lambdarank_truncation_level: 30]
112+
[lambdarank_norm: 1]
113+
[label_gain: ]
114+
[eval_at: ]
115+
[multi_error_top_k: 1]
116+
[auc_mu_weights: ]
117+
[num_machines: 1]
118+
[local_listen_port: 12400]
119+
[time_out: 120]
120+
[machine_list_filename: ]
121+
[machines: ]
122+
[gpu_platform_id: -1]
123+
[gpu_device_id: -1]
124+
[gpu_use_dp: 0]
125+
[num_gpu: 1]
126+
127+
end of parameters

0 commit comments

Comments
 (0)