diff --git a/cpp-benchmark/image_classification.cpp b/cpp-benchmark/image_classification.cpp new file mode 100644 index 0000000..610c713 --- /dev/null +++ b/cpp-benchmark/image_classification.cpp @@ -0,0 +1,459 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * This example demonstrates image classification workflow with pre-trained models using MXNet C++ API. + * The example performs following tasks. + * 1. Load the pre-trained model. + * 2. Load the parameters of pre-trained model. + * 3. Load the image to be classified in to NDArray. + * 4. Normalize the image using the mean of images that were used for training. + * 5. Run the forward pass and predict the input image. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include "mxnet-cpp/MxNetCpp.h" +#include + + +using namespace mxnet::cpp; + +static mx_float DEFAULT_MEAN_R = 123.675; +static mx_float DEFAULT_MEAN_G = 116.28; +static mx_float DEFAULT_MEAN_B = 103.53; +/* + * class Predictor + * + * This class encapsulates the functionality to load the model, process input image and run the forward pass. + */ + +class Predictor { + public: + Predictor() {} + Predictor(const std::string& model_json_file, + const std::string& model_params_file, + const Shape& input_shape, + bool gpu_context_type = false, + const std::string& synset_file = "", + const std::string& mean_image_file = ""); + void WarmUp(const std::string& image_file, int count); + void MeasurePredictions(const std::string& image_file, int count); + NDArray LoadInputImage(const std::string& image_file); + void Calculate(); + ~Predictor(); + + private: + void LoadModel(const std::string& model_json_file); + void LoadParameters(const std::string& model_parameters_file); + void LoadSynset(const std::string& synset_file); + void LoadMeanImageData(); + void LoadDefaultMeanImageData(); + void NormalizeInput(const std::string& mean_image_file); + inline bool FileExists(const std::string& name) { + struct stat buffer; + return (stat(name.c_str(), &buffer) == 0); + } + NDArray mean_img; + std::map args_map; + std::map aux_map; + std::vector output_labels; + Symbol net; + Executor *executor; + Shape input_shape; + NDArray mean_image_data; + NDArray std_dev_image_data; + Context global_ctx = Context::cpu(); + std::string mean_image_file; + std::vector measurements; +}; + + +/* + * The constructor takes following parameters as input: + * 1. model_json_file: The model in json formatted file. + * 2. model_params_file: File containing model parameters + * 3. synset_file: File containing the list of image labels + * 4. input_shape: Shape of input data to the model. Since this class will be running one inference at a time, + * the input shape is required to be in format Shape(1, number_of_channels, height, width) + * The input image will be resized to (height x width) size before running the inference. + * The constructor will: + * 1. Load the model and parameter files. + * 2. Load the synset file. + * 3. Invoke the SimpleBind to bind the input argument to the model and create an executor. + * + * The SimpleBind is expected to be invoked only once. + */ +Predictor::Predictor(const std::string& model_json_file, + const std::string& model_params_file, + const Shape& input_shape, + bool gpu_context_type, + const std::string& synset_file, + const std::string& mean_image_file): + input_shape(input_shape), + mean_image_file(mean_image_file) { + if (gpu_context_type) { + global_ctx = Context::gpu(); + } + // Load the model + LoadModel(model_json_file); + + // Load the model parameters. + LoadParameters(model_params_file); + + /* + * The data will be used to output the exact label that matches highest output of the model. + */ + LoadSynset(synset_file); + + /* + * Load the mean image data if specified. + */ + if (!mean_image_file.empty()) { + LoadMeanImageData(); + } else { + LG << "Mean image file for normalizing the input is not provide." + << " We will use the default mean values for R,G and B channels."; + LoadDefaultMeanImageData(); + } + + // Create an executor after binding the model to input parameters. + args_map["data"] = NDArray(input_shape, global_ctx, false); + executor = net.SimpleBind(global_ctx, args_map, std::map(), + std::map(), aux_map); +} + +/* + * The following function loads the model from json file. + */ +void Predictor::LoadModel(const std::string& model_json_file) { + if (!FileExists(model_json_file)) { + LG << "Model file " << model_json_file << " does not exist"; + throw std::runtime_error("Model file does not exist"); + } + LG << "Loading the model from " << model_json_file << std::endl; + net = Symbol::Load(model_json_file); +} + + +/* + * The following function loads the model parameters. + */ +void Predictor::LoadParameters(const std::string& model_parameters_file) { + if (!FileExists(model_parameters_file)) { + LG << "Parameter file " << model_parameters_file << " does not exist"; + throw std::runtime_error("Model parameters does not exist"); + } + LG << "Loading the model parameters from " << model_parameters_file << std::endl; + std::map parameters; + NDArray::Load(model_parameters_file, 0, ¶meters); + for (const auto &k : parameters) { + if (k.first.substr(0, 4) == "aux:") { + auto name = k.first.substr(4, k.first.size() - 4); + aux_map[name] = k.second.Copy(global_ctx); + } + if (k.first.substr(0, 4) == "arg:") { + auto name = k.first.substr(4, k.first.size() - 4); + args_map[name] = k.second.Copy(global_ctx); + } + } + /*WaitAll is need when we copy data between GPU and the main memory*/ + NDArray::WaitAll(); +} + + +/* + * The following function loads the synset file. + * This information will be used later to report the label of input image. + */ +void Predictor::LoadSynset(const std::string& synset_file) { + if (!FileExists(synset_file)) { + LG << "Synset file " << synset_file << " does not exist"; + throw std::runtime_error("Synset file does not exist"); + } + LG << "Loading the synset file."; + std::ifstream fi(synset_file.c_str()); + if (!fi.is_open()) { + std::cerr << "Error opening synset file " << synset_file << std::endl; + throw std::runtime_error("Error in opening the synset file."); + } + std::string synset, lemma; + while (fi >> synset) { + getline(fi, lemma); + output_labels.push_back(lemma); + } + fi.close(); +} + + +/* + * The following function loads the mean data from mean image file. + * This data will be used for normalizing the image before running the forward + * pass. + * The output data has the same shape as that of the input image data. + */ +void Predictor::LoadMeanImageData() { + LG << "Load the mean image data that will be used to normalize " + << "the image before running forward pass."; + mean_image_data = NDArray(input_shape, global_ctx, false); + mean_image_data.SyncCopyFromCPU( + NDArray::LoadToMap(mean_image_file)["mean_img"].GetData(), + input_shape.Size()); +} + + +/* + * The following function loads the default mean values for + * R, G and B channels into NDArray that has the same shape as that of + * input image. + */ +void Predictor::LoadDefaultMeanImageData() { + LG << "Loading the default mean image data"; + std::vector array; + /*resize pictures to (224, 224) according to the pretrained model*/ + int height = input_shape[2]; + int width = input_shape[3]; + int channels = input_shape[1]; + std::vector default_means; + default_means.push_back(DEFAULT_MEAN_R); + default_means.push_back(DEFAULT_MEAN_G); + default_means.push_back(DEFAULT_MEAN_B); + for (int c = 0; c < channels; ++c) { + for (int i = 0; i < height; ++i) { + for (int j = 0; j < width; ++j) { + array.push_back(default_means[c]); + } + } + } + mean_image_data = NDArray(input_shape, global_ctx, false); + mean_image_data.SyncCopyFromCPU(array.data(), input_shape.Size()); +} + + +/* + * The following function loads the input image into NDArray. + */ +NDArray Predictor::LoadInputImage(const std::string& image_file) { + if (!FileExists(image_file)) { + LG << "Image file " << image_file << " does not exist"; + throw std::runtime_error("Image file does not exist"); + } + LG << "Loading the image " << image_file << std::endl; + std::vector array; + cv::Mat mat = cv::imread(image_file); + /*resize pictures to (224, 224) according to the pretrained model*/ + int height = input_shape[2]; + int width = input_shape[3]; + int channels = input_shape[1]; + cv::resize(mat, mat, cv::Size(height, width)); + for (int c = 0; c < channels; ++c) { + for (int i = 0; i < height; ++i) { + for (int j = 0; j < width; ++j) { + array.push_back(static_cast(mat.data[(i * height + j) * 3 + c])); + } + } + } + NDArray image_data = NDArray(input_shape, global_ctx, false); + image_data.SyncCopyFromCPU(array.data(), input_shape.Size()); + return image_data; +} + + +void Predictor::WarmUp(const std::string& image_file, int count) +{ + NDArray image_data = LoadInputImage(image_file); + image_data.Slice(0, 1) -= mean_image_data; + + for (int i = 0; i < count; i++) { + image_data.CopyTo(&(executor->arg_dict()["data"])); + executor->Forward(false); + } + return; +} + + +void Predictor::Calculate() +{ + std::sort(measurements.begin(), measurements.end()); + int count = measurements.size(); + int Index_50 = static_cast(ceil((count - 1) * 50 /100)); + LG << "InferenceTime_P50(uSecs): " << measurements[Index_50].count(); + int Index_90 = static_cast(ceil((count - 1) * 90 /100)); + LG << "InferenceTime_P90(uSecs): " << measurements[Index_90].count(); + int Index_99 = static_cast(ceil((count - 1) * 99 /100)); + LG << "InferenceTime_P99(uSecs): " << measurements[Index_99].count(); +} + + +void Predictor::MeasurePredictions(const std::string& image_file, int count) +{ + NDArray image_data = LoadInputImage(image_file); + image_data.Slice(0, 1) -= mean_image_data; + for (int i = 0; i < count; i++) { + image_data.CopyTo(&(executor->arg_dict()["data"])); + + std::chrono::high_resolution_clock::time_point begin = + std::chrono::high_resolution_clock::now(); + executor->Forward(false); + auto array = executor->outputs[0]; + array.WaitToRead(); + std::chrono::high_resolution_clock::time_point end = std::chrono::high_resolution_clock::now(); + std::chrono::microseconds period = + std::chrono::duration_cast(end - begin); + + measurements.push_back(period); + + LG << "Begin = " << begin.time_since_epoch().count() << " End = " + << end.time_since_epoch().count() << " Period = " << period.count(); + } +} + + +Predictor::~Predictor() { + if (executor) { + delete executor; + } + MXNotifyShutdown(); +} + + +/* + * Convert the input string of number of hidden units into the vector of integers. + */ +std::vector getShapeDimensions(const std::string& hidden_units_string) { + std::vector dimensions; + char *p_next; + int num_unit = strtol(hidden_units_string.c_str(), &p_next, 10); + dimensions.push_back(num_unit); + while (*p_next) { + num_unit = strtol(p_next, &p_next, 10); + dimensions.push_back(num_unit); + } + return dimensions; +} + +void printUsage() { + std::cout << "Usage:" << std::endl; + std::cout << "inception_inference --symbol " << std::endl + << "--params " << std::endl + << "--image " << std::endl + << "--synset " << std::endl + << "[--input_shape ] " << std::endl + << "[--mean ] " + << std::endl + << "[--gpu ]" + << std::endl; +} + +int main(int argc, char** argv) { + std::string model_file_json; + std::string model_file_params; + std::string synset_file = ""; + std::string mean_image = ""; + std::string input_image = ""; + bool gpu_context_type = false; + int warmupIterations = 10; + int predictionIterations = 1000; + + std::string input_shape = "3 224 224"; + int index = 1; + while (index < argc) { + if (strcmp("--symbol", argv[index]) == 0) { + index++; + model_file_json = (index < argc ? argv[index]:""); + } else if (strcmp("--params", argv[index]) == 0) { + index++; + model_file_params = (index < argc ? argv[index]:""); + } else if (strcmp("--synset", argv[index]) == 0) { + index++; + synset_file = (index < argc ? argv[index]:""); + } else if (strcmp("--mean", argv[index]) == 0) { + index++; + mean_image = (index < argc ? argv[index]:""); + } else if (strcmp("--image", argv[index]) == 0) { + index++; + input_image = (index < argc ? argv[index]:""); + } else if (strcmp("--input_shape", argv[index]) == 0) { + index++; + input_shape = (index < argc ? argv[index]:input_shape); + } else if (strcmp("--warmup", argv[index]) == 0) { + index++; + warmupIterations = (index < argc ? strtol(argv[index], NULL, 10):10); + } else if (strcmp("--predict", argv[index]) == 0) { + index++; + predictionIterations = (index < argc ? strtol(argv[index], NULL, 10):100); + } else if (strcmp("--gpu", argv[index]) == 0) { + gpu_context_type = true; + } else if (strcmp("--help", argv[index]) == 0) { + printUsage(); + return 0; + } + index++; + } + + if (model_file_json.empty() || model_file_params.empty() || synset_file.empty()) { + LG << "ERROR: Model details such as symbol, param and/or synset files are not specified"; + printUsage(); + return 1; + } + + if (input_image.empty()) { + LG << "ERROR: Path to the input image is not specified."; + printUsage(); + return 1; + } + + std::vector input_dimensions = getShapeDimensions(input_shape); + + /* + * Since we are running inference for 1 image, add 1 to the input_dimensions so that + * the shape of input data for the model will be + * {no. of images, channels, height, width} + */ + input_dimensions.insert(input_dimensions.begin(), 1); + + Shape input_data_shape(input_dimensions); + + try { + // Initialize the predictor object + Predictor predict(model_file_json, model_file_params, input_data_shape, gpu_context_type, + synset_file, mean_image); + + predict.WarmUp(input_image, warmupIterations); + predict.MeasurePredictions(input_image, predictionIterations); + predict.Calculate(); + + } catch (std::runtime_error &error) { + LG << "Execution failed with ERROR: " << error.what(); + } catch (...) { + /* + * If underlying MXNet code has thrown an exception the error message is + * accessible through MXGetLastError() function. + */ + LG << "Execution failed with following MXNet error"; + LG << MXGetLastError(); + } + return 0; +} diff --git a/cpp-benchmark/mxnet-cpu-build.sh b/cpp-benchmark/mxnet-cpu-build.sh new file mode 100755 index 0000000..c256d84 --- /dev/null +++ b/cpp-benchmark/mxnet-cpu-build.sh @@ -0,0 +1,34 @@ +#!/bin/bash + +echo "Downloading the build essentials for mxnet build." +sudo apt-get update +sudo apt-get install -y build-essential git +sudo apt-get install -y libopenblas-dev liblapack-dev +sudo apt-get install -y libopencv-dev +sudo apt-get install -y python-dev python-setuptools python-pip libgfortran3 + +export PERF_HOME=`pwd` + + +cd ${HOME} +git clone --recursive https://github.com/apache/incubator-mxnet.git incubator-mxnet + +export MXNET_HOME=${HOME}/incubator-mxnet +export CPP_INFERENCE_EXAMPLE=${MXNET_HOME}/cpp-package/example/inference + +echo "Copying the C++ performance program to ${MXNET_HOME}" +cp ${PERF_HOME}/image_classification.cpp ${CPP_INFERENCE_EXAMPLE}/. +cp ${PERF_HOME}/unit_test_image_classification_cpu.sh ${CPP_INFERENCE_EXAMPLE}/. + +echo "Installing the build dependencies" +cd ${MXNET_HOME}/ci/docker/install +sudo ./ubuntu_base.sh +sudo ./ubuntu_core.sh +sudo ./ubuntu_mklmk.sh +sudo ./ubuntu_mkl.sh + +echo "Building the mxnet at ${MXNET_HOME}" +cd ${MXNET_HOME} +make USE_CPP_PACKAGE=1 USE_OPENCV=1 USE_CUDA=0 USE_CUDNN=0 USE_LAPACK=0 2>&1 | tee buildLog.txt +cd ${HOME}/benchmarkai + diff --git a/cpp-benchmark/mxnet-gpu-build.sh b/cpp-benchmark/mxnet-gpu-build.sh new file mode 100755 index 0000000..119ef02 --- /dev/null +++ b/cpp-benchmark/mxnet-gpu-build.sh @@ -0,0 +1,34 @@ +#!/bin/bash + +echo "Downloading the build essentials for mxnet build." +sudo apt-get update +sudo apt-get install -y build-essential git +sudo apt-get install -y libopenblas-dev liblapack-dev +sudo apt-get install -y libopencv-dev +sudo apt-get install -y python-dev python-setuptools python-pip libgfortran3 + +export PERF_HOME=`pwd` + + +cd ${HOME} +git clone --recursive https://github.com/apache/incubator-mxnet.git incubator-mxnet + +export MXNET_HOME=${HOME}/incubator-mxnet +export CPP_INFERENCE_EXAMPLE=${MXNET_HOME}/cpp-package/example/inference + +echo "Copying the C++ performance program to ${MXNET_HOME}" +cp ${PERF_HOME}/image_classification.cpp ${CPP_INFERENCE_EXAMPLE}/. +cp ${PERF_HOME}/unit_test_image_classification_cpu.sh ${CPP_INFERENCE_EXAMPLE}/. + +echo "Installing the build dependencies" +cd ${MXNET_HOME}/ci/docker/install +sudo ./ubuntu_base.sh +sudo ./ubuntu_core.sh +sudo ./ubuntu_mklml.sh +sudo ./ubuntu_mkl.sh + +echo "Building the mxnet at ${MXNET_HOME}" +cd ${MXNET_HOME} +make USE_CPP_PACKAGE=1 USE_MKLDNN=1 USE_OPENCV=1 USE_CUDA=1 USE_CUDNN=1 USE_CUDA_PATH=/usr/local/cuda USE_LAPACK=0 2>&1 | tee buildLog.txt +cd ${HOME}/benchmarkai + diff --git a/cpp-benchmark/unit_test_image_classification.sh b/cpp-benchmark/unit_test_image_classification.sh new file mode 100755 index 0000000..67df9cd --- /dev/null +++ b/cpp-benchmark/unit_test_image_classification.sh @@ -0,0 +1,34 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Downloading the data and model +mkdir -p model +cd model +wget -nc https://s3.amazonaws.com/model-server/models/resnet50_ssd/resnet50_ssd_model-symbol.json +wget -nc https://s3.amazonaws.com/model-server/models/resnet50_ssd/resnet50_ssd_model-0000.params +wget -nc https://s3.amazonaws.com/model-server/models/resnet50_ssd/synset.txt +wget -nc -O dog.jpg https://github.com/dmlc/web-data/blob/master/mxnet/doc/tutorials/python/predict_image/dog.jpg?raw=true +wget -nc -O mean_224.nd https://github.com/dmlc/web-data/raw/master/mxnet/example/feature_extract/mean_224.nd +cd .. + + +# Running the example with dog image. +if [ "$(uname)" == "Darwin" ]; then + DYLD_LIBRARY_PATH=${DYLD_LIBRARY_PATH}:../../../lib ./image_classification --symbol "./model/resnet50_ssd_model-symbol.json" --params "./model/resnet50_ssd_model-0000.params" --synset "./model/synset.txt" --mean "./model/mean_224.nd" --image "./model/dog.jpg" --warmup 1 --predict 100 2&> image_classification.log +else + LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:../../../lib ./image_classification --symbol "./model/resnet50_ssd_model-symbol.json" --params "./model/resnet50_ssd_model-0000.params" --synset "./model/synset.txt" --mean "./model/mean_224.nd" --image "./model/dog.jpg" --warmup 1 --predict 100 2&> image_classification.log +fi diff --git a/cpp-benchmark/unit_test_image_classification_cpu.sh b/cpp-benchmark/unit_test_image_classification_cpu.sh new file mode 100755 index 0000000..9d1455a --- /dev/null +++ b/cpp-benchmark/unit_test_image_classification_cpu.sh @@ -0,0 +1,39 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +# Downloading the data and model +source ${HOME}/.dlamirc +export MXNET_HOME=${HOME}/incubator-mxnet +export CPP_INFERENCE_EXAMPLE=image_classification +export CPP_INFERENCE_EXAMPLE_FOLDER=${MXNET_HOME}/cpp-package/example/inference +export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:${MXNET_HOME}/lib + +mkdir -p ${CPP_INFERENCE_EXAMPLE_FOLDER}/model +cd ${CPP_INFERENCE_EXAMPLE_FOLDER}/model +wget -nc -O ${CPP_INFERENCE_EXAMPLE_FOLDER}/model/resnet50_ssd_model-symbol.json https://s3.amazonaws.com/model-server/models/resnet50_ssd/resnet50_ssd_model-symbol.json +wget -nc -O ${CPP_INFERENCE_EXAMPLE_FOLDER}/model/resnet50_ssd_model-0000.params https://s3.amazonaws.com/model-server/models/resnet50_ssd/resnet50_ssd_model-0000.params +wget -nc -O ${CPP_INFERENCE_EXAMPLE_FOLDER}/model/synset.txt https://s3.amazonaws.com/model-server/models/resnet50_ssd/synset.txt +wget -nc -O ${CPP_INFERENCE_EXAMPLE_FOLDER}/model/dog.jpg https://github.com/dmlc/web-data/blob/master/mxnet/doc/tutorials/python/predict_image/dog.jpg?raw=true +wget -nc -O ${CPP_INFERENCE_EXAMPLE_FOLDER}/model/mean_224.nd https://github.com/dmlc/web-data/raw/master/mxnet/example/feature_extract/mean_224.nd +cd ${CPP_INFERENCE_EXAMPLE_FOLDER} +cp ${MXNET_HOME}/build/cpp-package/example/${CPP_INFERENCE_EXAMPLE} . + + +# Running the example with dog image. +LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:${MXNET_HOME}/lib ${CPP_INFERENCE_EXAMPLE_FOLDER}/${CPP_INFERENCE_EXAMPLE} --symbol "${CPP_INFERENCE_EXAMPLE_FOLDER}/model/resnet50_ssd_model-symbol.json" --params "${CPP_INFERENCE_EXAMPLE_FOLDER}/model/resnet50_ssd_model-0000.params" --synset "${CPP_INFERENCE_EXAMPLE_FOLDER}/model/synset.txt" --mean "${CPP_INFERENCE_EXAMPLE_FOLDER}/model/mean_224.nd" --image "${CPP_INFERENCE_EXAMPLE_FOLDER}/model/dog.jpg" --warmup 10 --predict 2000 diff --git a/cpp-benchmark/unit_test_image_classification_gpu.sh b/cpp-benchmark/unit_test_image_classification_gpu.sh new file mode 100755 index 0000000..6515a1b --- /dev/null +++ b/cpp-benchmark/unit_test_image_classification_gpu.sh @@ -0,0 +1,39 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +# Downloading the data and model +source ${HOME}/.dlamirc +export MXNET_HOME=${HOME}/incubator-mxnet +export CPP_INFERENCE_EXAMPLE=image_classification +export CPP_INFERENCE_EXAMPLE_FOLDER=${MXNET_HOME}/cpp-package/example/inference +export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:${MXNET_HOME}/lib:/usr/local/cuda/lib64:/usr/local/cuda/lib:/usr/lib:/usr/local/lib + +mkdir -p ${CPP_INFERENCE_EXAMPLE_FOLDER}/model +cd ${CPP_INFERENCE_EXAMPLE_FOLDER}/model +wget -nc -O ${CPP_INFERENCE_EXAMPLE_FOLDER}/model/resnet50_ssd_model-symbol.json https://s3.amazonaws.com/model-server/models/resnet50_ssd/resnet50_ssd_model-symbol.json +wget -nc -O ${CPP_INFERENCE_EXAMPLE_FOLDER}/model/resnet50_ssd_model-0000.params https://s3.amazonaws.com/model-server/models/resnet50_ssd/resnet50_ssd_model-0000.params +wget -nc -O ${CPP_INFERENCE_EXAMPLE_FOLDER}/model/synset.txt https://s3.amazonaws.com/model-server/models/resnet50_ssd/synset.txt +wget -nc -O ${CPP_INFERENCE_EXAMPLE_FOLDER}/model/dog.jpg https://github.com/dmlc/web-data/blob/master/mxnet/doc/tutorials/python/predict_image/dog.jpg?raw=true +wget -nc -O ${CPP_INFERENCE_EXAMPLE_FOLDER}/model/mean_224.nd https://github.com/dmlc/web-data/raw/master/mxnet/example/feature_extract/mean_224.nd +cd ${CPP_INFERENCE_EXAMPLE_FOLDER} +cp ${MXNET_HOME}/build/cpp-package/example/${CPP_INFERENCE_EXAMPLE} . + + +# Running the example with dog image. +${CPP_INFERENCE_EXAMPLE_FOLDER}/${CPP_INFERENCE_EXAMPLE} --symbol "${CPP_INFERENCE_EXAMPLE_FOLDER}/model/resnet50_ssd_model-symbol.json" --params "${CPP_INFERENCE_EXAMPLE_FOLDER}/model/resnet50_ssd_model-0000.params" --synset "${CPP_INFERENCE_EXAMPLE_FOLDER}/model/synset.txt" --mean "${CPP_INFERENCE_EXAMPLE_FOLDER}/model/mean_224.nd" --image "${CPP_INFERENCE_EXAMPLE_FOLDER}/model/dog.jpg" --warmup 10 --predict 2000 --gpu diff --git a/task_config_template.cfg b/task_config_template.cfg index 40c000e..de9b68f 100755 --- a/task_config_template.cfg +++ b/task_config_template.cfg @@ -28,6 +28,18 @@ patterns = ['Speed: (\d+\.\d+|\d+) samples/sec', 'Speed: (\d+\.\d+|\d+) samples/ metrics = ['speed', 'speed-p90', 'speed-p50'] compute_method = ['average', 'p90', 'p50'] +[metrics_cpp_inference] +patterns = ['InferenceTime_P50\(uSecs\): (\d+\.\d+|\d+)', 'InferenceTime_P90\(uSecs\): (\d+\.\d+|\d+)', 'InferenceTime_P99\(uSecs\): (\d+\.\d+|\d+)'] +metrics = ['InferenceTime_P50_uSecs', 'InferenceTime_P90_uSecs', 'InferenceTime_P99_uSecs'] +compute_method = ['last', 'last', 'last'] + +[resnet50_cpp_inference_cpu] +patterns = ['InferenceTime_P50\(uSecs\): (\d+\.\d+|\d+)', 'InferenceTime_P90\(uSecs\): (\d+\.\d+|\d+)', 'InferenceTime_P99\(uSecs\): (\d+\.\d+|\d+)'] +metrics = ['InferenceTime_P50_uSecs', 'InferenceTime_P90_uSecs', 'InferenceTime_P99_uSecs'] +compute_method = ['last', 'last', 'last'] +command_to_execute = bash cpp-benchmark/unit_test_image_classification_cpu.sh +num_gpus = 0 + [mkl_resnet18_cifar10_symbolic] patterns = ['Speed: (\d+\.\d+|\d+) samples/sec', 'Train-accuracy=(\d+\.\d+|\d+)', 'Time cost=(\d+\.\d+|\d+)', 'Validation-accuracy=(\d+\.\d+|\d+)' ] metrics = ['speed', 'training_acc', 'total_training_time', 'validation_acc']