Skip to content

Commit 573565b

Browse files
authored
Merge pull request #9 from vaaaaanquish/add_regression
add regression examples
2 parents 5cfaa71 + 6e6be3b commit 573565b

File tree

3 files changed

+67
-0
lines changed

3 files changed

+67
-0
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,4 @@ lightgbm-sys/target
1515
# example
1616
examples/binary_classification/target/
1717
examples/multiclass_classification/target/
18+
examples/regression/target/

examples/regression/Cargo.toml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
[package]
2+
name = "lightgbm-example-regression"
3+
version = "0.1.0"
4+
authors = ["vaaaaanquish <[email protected]>"]
5+
publish = false
6+
7+
[dependencies]
8+
lightgbm = { path = "../../" }
9+
csv = "1.1.5"
10+
itertools = "0.9.0"
11+
serde_json = "1.0.59"

examples/regression/src/main.rs

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
extern crate lightgbm;
2+
extern crate csv;
3+
extern crate serde_json;
4+
extern crate itertools;
5+
6+
7+
use itertools::zip;
8+
use lightgbm::{Dataset, Booster};
9+
use serde_json::json;
10+
11+
12+
fn load_file(file_path: &str) -> (Vec<Vec<f64>>, Vec<f32>) {
13+
let rdr = csv::ReaderBuilder::new().has_headers(false).delimiter(b'\t').from_path(file_path);
14+
let mut labels: Vec<f32> = Vec::new();
15+
let mut features: Vec<Vec<f64>> = Vec::new();
16+
for result in rdr.unwrap().records() {
17+
let record = result.unwrap();
18+
let label = record[0].parse::<f32>().unwrap();
19+
let feature: Vec<f64> = record.iter().map(|x| x.parse::<f64>().unwrap()).collect::<Vec<f64>>()[1..].to_vec();
20+
labels.push(label);
21+
features.push(feature);
22+
}
23+
(features, labels)
24+
}
25+
26+
27+
fn main() -> std::io::Result<()> {
28+
let (train_features, train_labels) = load_file("../../lightgbm-sys/lightgbm/examples/regression/regression.train");
29+
let (test_features, test_labels) = load_file("../../lightgbm-sys/lightgbm/examples/regression/regression.test");
30+
let train_dataset = Dataset::from_mat(train_features, train_labels).unwrap();
31+
32+
let params = json!{
33+
{
34+
"num_iterations": 100,
35+
"objective": "regression",
36+
"metric": "l2"
37+
}
38+
};
39+
40+
let booster = Booster::train(train_dataset, &params).unwrap();
41+
let result = booster.predict(test_features).unwrap();
42+
43+
44+
let mut tp = 0;
45+
for (label, pred) in zip(&test_labels, &result[0]){
46+
if label == &(1 as f32) && pred > &(0.5 as f64) {
47+
tp = tp + 1;
48+
} else if label == &(0 as f32) && pred <= &(0.5 as f64) {
49+
tp = tp + 1;
50+
}
51+
println!("{}, {}", label, pred)
52+
}
53+
println!("{} / {}", &tp, result[0].len());
54+
Ok(())
55+
}

0 commit comments

Comments
 (0)