Skip to content

Commit 361a380

Browse files
authored
Merge pull request #8 from vaaaaanquish/add_multi_classification
Add multi classification examples
2 parents a7b3551 + 573565b commit 361a380

File tree

13 files changed

+521
-93
lines changed

13 files changed

+521
-93
lines changed

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,6 @@ Cargo.lock
1313
lightgbm-sys/target
1414

1515
# example
16-
examples/target
16+
examples/binary_classification/target/
17+
examples/multiclass_classification/target/
18+
examples/regression/target/

Cargo.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "lightgbm"
3-
version = "0.1.1"
3+
version = "0.1.2"
44
authors = ["vaaaaanquish <[email protected]>"]
55
license = "MIT"
66
repository = "https://github.com/vaaaaanquish/LightGBM"
@@ -11,3 +11,5 @@ exclude = [".gitignore", ".gitmodules", "examples", "lightgbm-sys"]
1111
[dependencies]
1212
lightgbm-sys = "0.1.0"
1313
libc = "0.2.81"
14+
derive_builder = "0.5.1"
15+
serde_json = "1.0.59"
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
[package]
2+
name = "lightgbm-example-binary-classification"
3+
version = "0.1.0"
4+
authors = ["vaaaaanquish <[email protected]>"]
5+
publish = false
6+
7+
[dependencies]
8+
lightgbm = { path = "../../" }
9+
csv = "1.1.5"
10+
itertools = "0.9.0"
11+
serde_json = "1.0.59"
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
extern crate lightgbm;
2+
extern crate csv;
3+
extern crate serde_json;
4+
extern crate itertools;
5+
6+
7+
use itertools::zip;
8+
use lightgbm::{Dataset, Booster};
9+
use serde_json::json;
10+
11+
12+
fn load_file(file_path: &str) -> (Vec<Vec<f64>>, Vec<f32>) {
13+
let rdr = csv::ReaderBuilder::new().has_headers(false).delimiter(b'\t').from_path(file_path);
14+
let mut labels: Vec<f32> = Vec::new();
15+
let mut features: Vec<Vec<f64>> = Vec::new();
16+
for result in rdr.unwrap().records() {
17+
let record = result.unwrap();
18+
let label = record[0].parse::<f32>().unwrap();
19+
let feature: Vec<f64> = record.iter().map(|x| x.parse::<f64>().unwrap()).collect::<Vec<f64>>()[1..].to_vec();
20+
labels.push(label);
21+
features.push(feature);
22+
}
23+
(features, labels)
24+
}
25+
26+
27+
fn main() -> std::io::Result<()> {
28+
let (train_features, train_labels) = load_file("../../lightgbm-sys/lightgbm/examples/binary_classification/binary.train");
29+
let (test_features, test_labels) = load_file("../../lightgbm-sys/lightgbm/examples/binary_classification/binary.test");
30+
let train_dataset = Dataset::from_mat(train_features, train_labels).unwrap();
31+
32+
let params = json!{
33+
{
34+
"num_iterations": 100,
35+
"objective": "binary",
36+
"metric": "auc"
37+
}
38+
};
39+
40+
let booster = Booster::train(train_dataset, &params).unwrap();
41+
let result = booster.predict(test_features).unwrap();
42+
43+
44+
let mut tp = 0;
45+
for (label, pred) in zip(&test_labels, &result[0]){
46+
if label == &(1 as f32) && pred > &(0.5 as f64) {
47+
tp = tp + 1;
48+
} else if label == &(0 as f32) && pred <= &(0.5 as f64) {
49+
tp = tp + 1;
50+
}
51+
println!("{}, {}", label, pred)
52+
}
53+
println!("{} / {}", &tp, result[0].len());
54+
Ok(())
55+
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
[package]
2+
name = "lightgbm-example-multiclass-classification"
3+
version = "0.1.0"
4+
authors = ["vaaaaanquish <[email protected]>"]
5+
publish = false
6+
7+
[dependencies]
8+
lightgbm = { path = "../../" }
9+
csv = "1.1.5"
10+
itertools = "0.9.0"
11+
serde_json = "1.0.59"
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
extern crate lightgbm;
2+
extern crate csv;
3+
extern crate serde_json;
4+
extern crate itertools;
5+
6+
7+
use itertools::zip;
8+
use lightgbm::{Dataset, Booster};
9+
use serde_json::json;
10+
11+
12+
fn load_file(file_path: &str) -> (Vec<Vec<f64>>, Vec<f32>) {
13+
let rdr = csv::ReaderBuilder::new().has_headers(false).delimiter(b'\t').from_path(file_path);
14+
let mut labels: Vec<f32> = Vec::new();
15+
let mut features: Vec<Vec<f64>> = Vec::new();
16+
for result in rdr.unwrap().records() {
17+
let record = result.unwrap();
18+
let label = record[0].parse::<f32>().unwrap();
19+
let feature: Vec<f64> = record.iter().map(|x| x.parse::<f64>().unwrap()).collect::<Vec<f64>>()[1..].to_vec();
20+
labels.push(label);
21+
features.push(feature);
22+
}
23+
(features, labels)
24+
}
25+
26+
/// Returns the index of the largest element of `xs`.
///
/// When several elements tie for the maximum, the lowest index wins.
/// Elements that do not compare greater than the running maximum
/// (including incomparable values such as NaN) never replace it.
///
/// # Panics
///
/// Panics if `xs` is empty, since an empty slice has no maximum.
fn argmax<T: PartialOrd>(xs: &[T]) -> usize {
    assert!(!xs.is_empty(), "argmax called on an empty slice");

    // Only the first index of the maximum is ever reported, so a single
    // running index is enough; no list of tied positions is needed.
    let mut best = 0;
    for (i, x) in xs.iter().enumerate().skip(1) {
        if *x > xs[best] {
            best = i;
        }
    }
    best
}
43+
44+
fn main() -> std::io::Result<()> {
45+
let (train_features, train_labels) = load_file("../../lightgbm-sys/lightgbm/examples/multiclass_classification/multiclass.train");
46+
let (test_features, test_labels) = load_file("../../lightgbm-sys/lightgbm/examples/multiclass_classification/multiclass.test");
47+
let train_dataset = Dataset::from_mat(train_features, train_labels).unwrap();
48+
49+
let params = json!{
50+
{
51+
"num_iterations": 100,
52+
"objective": "multiclass",
53+
"metric": "multi_logloss",
54+
"num_class": 5,
55+
}
56+
};
57+
58+
let booster = Booster::train(train_dataset, &params).unwrap();
59+
let result = booster.predict(test_features).unwrap();
60+
61+
62+
let mut tp = 0;
63+
for (label, pred) in zip(&test_labels, &result){
64+
let argmax_pred = argmax(&pred);
65+
if *label == argmax_pred as f32 {
66+
tp = tp + 1;
67+
}
68+
println!("{}, {}, {:?}", label, argmax_pred, &pred);
69+
}
70+
println!("{} / {}", &tp, result.len());
71+
Ok(())
72+
}
Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
[package]
2-
name = "lightgbm-example"
2+
name = "lightgbm-example-regression"
33
version = "0.1.0"
44
authors = ["vaaaaanquish <[email protected]>"]
55
publish = false
66

77
[dependencies]
8-
lightgbm = "0.1.1"
8+
lightgbm = { path = "../../" }
99
csv = "1.1.5"
1010
itertools = "0.9.0"
11+
serde_json = "1.0.59"

examples/regression/src/main.rs

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
extern crate lightgbm;
2+
extern crate csv;
3+
extern crate serde_json;
4+
extern crate itertools;
5+
6+
7+
use itertools::zip;
8+
use lightgbm::{Dataset, Booster};
9+
use serde_json::json;
10+
11+
12+
fn load_file(file_path: &str) -> (Vec<Vec<f64>>, Vec<f32>) {
13+
let rdr = csv::ReaderBuilder::new().has_headers(false).delimiter(b'\t').from_path(file_path);
14+
let mut labels: Vec<f32> = Vec::new();
15+
let mut features: Vec<Vec<f64>> = Vec::new();
16+
for result in rdr.unwrap().records() {
17+
let record = result.unwrap();
18+
let label = record[0].parse::<f32>().unwrap();
19+
let feature: Vec<f64> = record.iter().map(|x| x.parse::<f64>().unwrap()).collect::<Vec<f64>>()[1..].to_vec();
20+
labels.push(label);
21+
features.push(feature);
22+
}
23+
(features, labels)
24+
}
25+
26+
27+
fn main() -> std::io::Result<()> {
28+
let (train_features, train_labels) = load_file("../../lightgbm-sys/lightgbm/examples/regression/regression.train");
29+
let (test_features, test_labels) = load_file("../../lightgbm-sys/lightgbm/examples/regression/regression.test");
30+
let train_dataset = Dataset::from_mat(train_features, train_labels).unwrap();
31+
32+
let params = json!{
33+
{
34+
"num_iterations": 100,
35+
"objective": "regression",
36+
"metric": "l2"
37+
}
38+
};
39+
40+
let booster = Booster::train(train_dataset, &params).unwrap();
41+
let result = booster.predict(test_features).unwrap();
42+
43+
44+
let mut tp = 0;
45+
for (label, pred) in zip(&test_labels, &result[0]){
46+
if label == &(1 as f32) && pred > &(0.5 as f64) {
47+
tp = tp + 1;
48+
} else if label == &(0 as f32) && pred <= &(0.5 as f64) {
49+
tp = tp + 1;
50+
}
51+
println!("{}, {}", label, pred)
52+
}
53+
println!("{} / {}", &tp, result[0].len());
54+
Ok(())
55+
}

examples/src/main.rs

Lines changed: 0 additions & 56 deletions
This file was deleted.

0 commit comments

Comments
 (0)