Skip to content

Commit e2663f9

Browse files
authored
Merge pull request #7 from vaaaaanquish/fix_train_example
fix train example
2 parents b24c4d7 + eefac57 commit e2663f9

File tree

6 files changed

+70
-51
lines changed

6 files changed

+70
-51
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "lightgbm"
3-
version = "0.1.0"
3+
version = "0.1.1"
44
authors = ["vaaaaanquish <[email protected]>"]
55
license = "MIT"
66
repository = "https://github.com/vaaaaanquish/LightGBM"

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
LightGBM Rust binding
33

44

5+
Now: Done is better than perfect.
6+
7+
58
# develop
69

710
```

examples/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,6 @@ authors = ["vaaaaanquish <[email protected]>"]
55
publish = false
66

77
[dependencies]
8-
lightgbm = "0.1.0"
8+
lightgbm = "0.1.1"
99
csv = "1.1.5"
1010
itertools = "0.9.0"

examples/src/main.rs

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,27 @@ fn main() -> std::io::Result<()> {
1414
// let label = vec![0.0, 0.0, 0.0, 1.0, 1.0];
1515
// let train_dataset = Dataset::from_mat(feature, label).unwrap();
1616

17-
let train_dataset = Dataset::from_file("../lightgbm-sys/lightgbm/examples/binary_classification/binary.train".to_string()).unwrap();
17+
// let train_dataset = Dataset::from_file("../lightgbm-sys/lightgbm/examples/binary_classification/binary.train".to_string()).unwrap();
18+
19+
let mut train_rdr = csv::ReaderBuilder::new().has_headers(false).delimiter(b'\t').from_path("../lightgbm-sys/lightgbm/examples/binary_classification/binary.train")?;
20+
let mut train_labels: Vec<f32> = Vec::new();
21+
let mut train_feature: Vec<Vec<f64>> = Vec::new();
22+
for result in train_rdr.records() {
23+
let record = result?;
24+
let label = record[0].parse::<f32>().unwrap();
25+
let feature: Vec<f64> = record.iter().map(|x| x.parse::<f64>().unwrap()).collect::<Vec<f64>>()[1..].to_vec();
26+
train_labels.push(label);
27+
train_feature.push(feature);
28+
}
29+
let train_dataset = Dataset::from_mat(train_feature, train_labels).unwrap();
1830

1931
let mut rdr = csv::ReaderBuilder::new().has_headers(false).delimiter(b'\t').from_path("../lightgbm-sys/lightgbm/examples/binary_classification/binary.test")?;
20-
let mut test_labels: Vec<i8> = Vec::new();
21-
let mut test_feature: Vec<Vec<f32>> = Vec::new();
32+
let mut test_labels: Vec<f32> = Vec::new();
33+
let mut test_feature: Vec<Vec<f64>> = Vec::new();
2234
for result in rdr.records() {
2335
let record = result?;
24-
let label = record[0].parse::<i8>().unwrap();
25-
let feature: Vec<f32> = record.iter().map(|x| x.parse::<f32>().unwrap()).collect::<Vec<f32>>()[1..].to_vec();
36+
let label = record[0].parse::<f32>().unwrap();
37+
let feature: Vec<f64> = record.iter().map(|x| x.parse::<f64>().unwrap()).collect::<Vec<f64>>()[1..].to_vec();
2638
test_labels.push(label);
2739
test_feature.push(feature);
2840
}
@@ -32,11 +44,12 @@ fn main() -> std::io::Result<()> {
3244

3345
let mut tp = 0;
3446
for (label, pred) in zip(&test_labels, &result){
35-
if label == &(1 as i8) && pred > &(0.5 as f64) {
47+
if label == &(1 as f32) && pred > &(0.5 as f64) {
3648
tp = tp + 1;
37-
} else if label == &(0 as i8) && pred <= &(0.5 as f64) {
49+
} else if label == &(0 as f32) && pred <= &(0.5 as f64) {
3850
tp = tp + 1;
3951
}
52+
println!("{}, {}", label, pred)
4053
}
4154
println!("{} / {}", &tp, result.len());
4255
Ok(())

src/booster.rs

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
use lightgbm_sys;
22

3-
use libc::{c_char, c_int, c_float, c_double, c_long, c_void};
3+
use libc::{c_char, c_double, c_void, c_long};
44
use std::ffi::CString;
5-
use std::convert::TryInto;
65
use std;
76

87
use super::{LGBMResult, Dataset};
@@ -18,49 +17,49 @@ impl Booster {
1817
}
1918

2019
pub fn train(dataset: Dataset) -> LGBMResult<Self> {
20+
let params = CString::new("objective=binary metric=auc").unwrap();
2121
let mut handle = std::ptr::null_mut();
22-
let mut params = CString::new("app=binary metric=auc num_leaves=31").unwrap();
2322
unsafe {
2423
lightgbm_sys::LGBM_BoosterCreate(
2524
dataset.handle,
2625
params.as_ptr() as *const c_char,
27-
&mut handle);
26+
&mut handle
27+
);
2828
}
2929

3030
// train
3131
let mut is_finished: i32 = 0;
3232
unsafe{
33-
for n in 1..50 {
34-
let ret = lightgbm_sys::LGBM_BoosterUpdateOneIter(handle, &mut is_finished);
33+
for _ in 1..100 {
34+
lightgbm_sys::LGBM_BoosterUpdateOneIter(handle, &mut is_finished);
3535
}
3636
}
3737
Ok(Booster::new(handle)?)
3838
}
3939

40-
pub fn predict(&self, data: Vec<Vec<f32>>) -> LGBMResult<Vec<f64>> {
41-
let data_length = data.len() as i32;
42-
let feature_length = data[0].len() as i32;
43-
let mut params = CString::new("").unwrap();
44-
let mut out_len: c_long = 0;
45-
// let mut out_result = Vec::with_capacity(data_length.try_into().unwrap());
46-
let data_size = data_length.try_into().unwrap();
47-
let mut out_result: Vec<f64> = vec![Default::default(); data_size];
40+
pub fn predict(&self, data: Vec<Vec<f64>>) -> LGBMResult<Vec<f64>> {
41+
let data_length = data.len();
42+
let feature_length = data[0].len();
43+
let params = CString::new("").unwrap();
44+
let mut out_length: c_long = 0;
45+
let out_result: Vec<f64> = vec![Default::default(); data.len()];
46+
let flat_data = data.into_iter().flatten().collect::<Vec<_>>();
4847

4948
unsafe {
5049
lightgbm_sys::LGBM_BoosterPredictForMat(
5150
self.handle,
52-
data.as_ptr() as * mut c_void,
53-
lightgbm_sys::C_API_DTYPE_FLOAT32.try_into().unwrap(),
54-
data_length,
55-
feature_length,
56-
0,
57-
0,
58-
0,
59-
0,
51+
flat_data.as_ptr() as *const c_void,
52+
lightgbm_sys::C_API_DTYPE_FLOAT64 as i32,
53+
data_length as i32,
54+
feature_length as i32,
55+
1 as i32,
56+
0 as i32,
57+
0 as i32,
58+
-1 as i32,
6059
params.as_ptr() as *const c_char,
61-
&mut out_len,
60+
&mut out_length,
6261
out_result.as_ptr() as *mut c_double
63-
);
62+
);
6463
}
6564
Ok(out_result)
6665
}

src/dataset.rs

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
use libc::{c_void,c_char};
22

33
use std;
4-
use std::convert::TryInto;
54
use std::ffi::CString;
65
use lightgbm_sys;
76

@@ -17,45 +16,50 @@ impl Dataset {
1716
Ok(Dataset{handle})
1817
}
1918

20-
pub fn from_mat(data: Vec<Vec<f32>>, label: Vec<f32>) -> LGBMResult<Self> {
21-
let mut handle = std::ptr::null_mut();
22-
let data_length = data.len() as i32;
23-
let feature_length = data[0].len() as i32;
19+
pub fn from_mat(data: Vec<Vec<f64>>, label: Vec<f32>) -> LGBMResult<Self> {
20+
let data_length = data.len();
21+
let feature_length = data[0].len();
2422
let params = CString::new("").unwrap();
2523
let label_str = CString::new("label").unwrap();
24+
let reference = std::ptr::null_mut(); // not use
25+
let mut handle = std::ptr::null_mut();
26+
let flat_data = data.into_iter().flatten().collect::<Vec<_>>();
2627

2728
unsafe{
2829
lightgbm_sys::LGBM_DatasetCreateFromMat(
29-
data.as_ptr() as * mut c_void,
30-
lightgbm_sys::C_API_DTYPE_FLOAT32.try_into().unwrap(),
31-
data_length,
32-
feature_length,
33-
1,
30+
flat_data.as_ptr() as *const c_void,
31+
lightgbm_sys::C_API_DTYPE_FLOAT64 as i32,
32+
data_length as i32,
33+
feature_length as i32,
34+
1 as i32,
3435
params.as_ptr() as *const c_char,
35-
std::ptr::null_mut(),
36-
&mut handle);
36+
reference,
37+
&mut handle
38+
);
3739

3840
lightgbm_sys::LGBM_DatasetSetField(
3941
handle,
4042
label_str.as_ptr() as *const c_char,
41-
label.as_ptr() as * mut c_void,
42-
data_length,
43-
lightgbm_sys::C_API_DTYPE_FLOAT32.try_into().unwrap());
43+
label.as_ptr() as *const c_void,
44+
data_length as i32,
45+
lightgbm_sys::C_API_DTYPE_FLOAT32 as i32
46+
);
4447
}
4548
Ok(Dataset::new(handle)?)
4649
}
4750

4851
pub fn from_file(file_path: String) -> LGBMResult<Self> {
49-
let mut handle = std::ptr::null_mut();
5052
let file_path_str = CString::new(file_path).unwrap();
5153
let params = CString::new("").unwrap();
54+
let mut handle = std::ptr::null_mut();
5255

5356
unsafe {
5457
lightgbm_sys::LGBM_DatasetCreateFromFile(
55-
file_path_str.as_ptr() as * const c_char,
58+
file_path_str.as_ptr() as *const c_char,
5659
params.as_ptr() as *const c_char,
5760
std::ptr::null_mut(),
58-
&mut handle);
61+
&mut handle
62+
);
5963
}
6064
Ok(Dataset::new(handle)?)
6165
}

0 commit comments

Comments
 (0)