/*
 * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "BatchStream.h"
#include "EntropyCalibrator.h"
#include "argsParser.h"
#include "buffers.h"
#include "common.h"
#include "logger.h"
#include "parserOnnxConfig.h"
#include "kernel.h"

#include "NvInfer.h"

#include <iostream>
#include <random>

using namespace nvonnxparser;
using namespace nvinfer1;

const std::string gSampleName = "TensorRT.platerecg";
static const int INPUT_H = 24;
static const int INPUT_W = 94;

/*
Developed by Nyan.

Plugin creation requirements:
(1) Try to understand the plugin's processing by looking at the TensorFlow code. For example, for the
    SparseToDense conversion you need to look at TF's CTC decoder code
    (/home/xavier/Desktop/tensorflow-master/tensorflow/core/util/ctc/ctc_decoder.h).
(2) Then implement the base methods of IPluginV2DynamicExt and IPluginCreator, as shown in the sample below.
(3) The important thing is to understand how to do the processing in CUDA (some steps may not need to run on CUDA).
(4) Make sure the numbers of inputs and outputs match the TensorFlow API's requirements.
(5) Check the model graph with Netron.

SparseToDense has four key inputs; they are the TensorFlow API's arguments.
Inputs to the CTCGreedyDecoder node (can be checked from the PB or ONNX graph using Netron):
    const typename CTCDecoder::SequenceLength& seq_len     sequence_length (Fill)
    const std::vector<CTCDecoder::Input>& input             inputs (Transpose)
Outputs are
    scores
    std::vector<CTCDecoder::Output>* output

Input/output shapes can be seen from the PB graph by printing all layer inputs and outputs:
    output = graph.get_tensor_by_name("import/d_predictions:0")
    x = sess.run(output, feed_dict={input: [img]})
    print(x)
    print(x.shape)
as shown in Input_output_layers_size.

For graph modification, ONNX-GraphSurgeon can be used. In the latest graph, the SparseToDense node was removed
using graphsurgeon.
*/
class CTCGreedyDecoder : public IPluginV2DynamicExt
{
public:
    // CTCGreedyDecoder has two inputs (import/transpose:0 (88, 1, 43), import/Fill:0 (1,))
    // and two outputs (import/CTCGreedyDecoder:0 (7, 2), import/ToInt32:0 (7,)).
    CTCGreedyDecoder(const PluginFieldCollection& fc)
    {
        (void) fc;
    }

    // data represents the class member variables serialized by
    // void serialize(void* buffer) const override;
    // length is the length of that serialized data.
    // They must be deserialized in this constructor.
    CTCGreedyDecoder(const void* data, size_t length)
    {
        const char* d = static_cast<const char*>(data);
        const char* const a = d;
        mDataType = static_cast<DataType>(read<int>(d));
        assert(d == a + length);
    }
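
    // A minimal sketch of the buffer layout this constructor assumes, mirroring
    // getSerializationSize()/serialize() below:
    //
    //   offset 0 .. sizeof(int)-1 : mDataType, stored as an int
    //
    // Note that getSerializationSize() also reserves room for two INT8 scale floats,
    // but serialize() only writes the data type; the two agree on the FP32/FP16 path
    // exercised by this sample.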
    // It makes no sense to construct CTCGreedyDecoder without arguments.
    CTCGreedyDecoder() = delete;

    virtual ~CTCGreedyDecoder() {}

public:
    int getNbOutputs() const override
    {
        return 1; // it has one output
    }

    DimsExprs getOutputDimensions(
        int outputIndex, const DimsExprs* inputs, int nbInputs, IExprBuilder& exprBuilder) override
    {
        nvinfer1::DimsExprs output;
        output.nbDims = 2;
        output.d[0] = exprBuilder.constant(1);
        output.d[1] = exprBuilder.constant(20);
        return output;
    }

    int initialize() override
    {
        return 0;
    }

    void terminate() override
    {
        // Release here any memory acquired in initialize().
    }

    size_t getWorkspaceSize(const PluginTensorDesc* inputs, int nbInputs, const PluginTensorDesc* outputs,
        int nbOutputs) const override
    {
        return 0;
    }

    int enqueue(const PluginTensorDesc* inputDesc, const PluginTensorDesc* outputDesc, const void* const* inputs,
        void* const* outputs, void* workspace, cudaStream_t stream) override
    {
        Dims inputDims = inputDesc[0].dims;
        int rows = inputDims.d[0];
        int batch = inputDims.d[1];
        int widths = inputDims.d[2];
        float* output = reinterpret_cast<float*>(outputs[0]);
        interface(stream, inputs[0], output, rows, batch, widths);
        return 0;
    }

    // All private member variables are serialized in the following two methods,
    // one after another. serializationSize is the total size of all member variables.
    size_t getSerializationSize() const override
    {
        size_t serializationSize = 0;
        serializationSize += sizeof(static_cast<int>(mDataType));
        if (mDataType == DataType::kINT8)
        {
            serializationSize += sizeof(float) * 2;
        }
        return serializationSize;
    }

    // Serialize into the char buffer.
    void serialize(void* buffer) const override
    {
        char* d = static_cast<char*>(buffer);
        const char* const a = d;
        write(d, static_cast<int>(mDataType));
        assert(d == a + getSerializationSize());
    }

    // Plugin configuration for the input/output types, formats and sizes.
    // PluginTensorDesc describes the fields a plugin sees for an input or output:
    // it has four attributes (Dims, DataType, TensorFormat, float scale).
    // You can assert that all of them match the requirements, i.e. that every
    // input/output type meets the expectations.
    // For this CTCGreedyDecoder the first input is a 3D tensor and the second is a 1D tensor.
    void configurePlugin(const DynamicPluginTensorDesc* in, int nbInput, const DynamicPluginTensorDesc* out,
        int nbOutput) override
    {
        std::cout << "nbInput " << nbInput << " nbOutput " << nbOutput << std::endl;
        assert(in && nbInput == 2);
        assert(out && nbOutput == 1);
    }
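
    // For reference, with the graph described at the top of this class, configurePlugin() is
    // expected to see something like the following (a sketch based on those shapes, not checked here):
    //   in[0].desc.dims  = (88, 1, 43)   // import/transpose:0
    //   in[1].desc.dims  = (1)           // import/Fill:0
    //   out[0].desc.dims = (1, 20)       // as reported by getOutputDimensions()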
    //! The combination of kLINEAR + kINT8/kHALF/kFLOAT is supported.
    bool supportsFormatCombination(int pos, const PluginTensorDesc* inOut, int nbInputs, int nbOutputs) override
    {
        std::cout << "nbInputs_here " << nbInputs << " nbOutputs_here " << nbOutputs << std::endl;
        assert(nbInputs == 2 && nbOutputs == 1 && pos < nbInputs + nbOutputs);
        bool condition = inOut[pos].format == TensorFormat::kLINEAR;
        condition &= inOut[pos].type != DataType::kINT32;
        condition &= inOut[pos].type == inOut[0].type;
        return condition;
    }

    DataType getOutputDataType(int index, const DataType* inputTypes, int nbInputs) const override
    {
        std::cout << "inputTypes " << inputTypes << " nbInputs " << nbInputs << std::endl;
        assert(inputTypes && nbInputs == 2);
        (void) index;
        return inputTypes[0];
    }

    const char* getPluginType() const override
    {
        return "CTCGreedyDecoder";
    }

    const char* getPluginVersion() const override
    {
        return "1";
    }

    void destroy() override
    {
        delete this;
    }

    IPluginV2DynamicExt* clone() const override
    {
        auto* plugin = new CTCGreedyDecoder(*this);
        return plugin;
    }

    void setPluginNamespace(const char* libNamespace) override
    {
        mNamespace = libNamespace;
    }

    const char* getPluginNamespace() const override
    {
        return mNamespace.data();
    }

private:
    template <typename T>
    void write(char*& buffer, const T& val) const
    {
        *reinterpret_cast<T*>(buffer) = val;
        buffer += sizeof(T);
    }

    template <typename T>
    T read(const char*& buffer) const
    {
        T val = *reinterpret_cast<const T*>(buffer);
        buffer += sizeof(T);
        return val;
    }

private:
    // This plugin needs almost no private members: all inputs/outputs arrive in the stream
    // from the previous layer's outputs, so only the data type and namespace are stored.
    DataType mDataType;
    std::string mNamespace;
};

class CTCGreedyDecoderCreator : public IPluginCreator
{
public:
    const char* getPluginName() const override
    {
        return "CTCGreedyDecoder";
    }

    const char* getPluginVersion() const override
    {
        return "1";
    }

    const PluginFieldCollection* getFieldNames() override
    {
        return &mFieldCollection;
    }

    IPluginV2* createPlugin(const char* name, const PluginFieldCollection* fc) override
    {
        auto plugin = new CTCGreedyDecoder(*fc);
        mFieldCollection = *fc;
        mPluginName = name;
        return plugin;
    }

    // serialData is everything serialized by void serialize(void* buffer) const override;
    // serialLength is the length of the data written in serialize(void* buffer).
    IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) override
    {
        auto plugin = new CTCGreedyDecoder(serialData, serialLength);
        mPluginName = name;
        return plugin;
    }

    void setPluginNamespace(const char* libNamespace) override
    {
        mNamespace = libNamespace;
    }

    const char* getPluginNamespace() const override
    {
        return mNamespace.c_str();
    }

private:
    std::string mNamespace;
    std::string mPluginName;
    PluginFieldCollection mFieldCollection{0, nullptr};
};

REGISTER_TENSORRT_PLUGIN(CTCGreedyDecoderCreator);

class Recognition
{
    template <typename T>
    using SampleUniquePtr = std::unique_ptr<T, samplesCommon::InferDeleter>;

public:
    Recognition(const samplesCommon::OnnxSampleParams& params)
        : mParams(params)
    {
    }

    //!
    //! \brief Builds both engines.
    //!
    bool build();

    //!
    //! \brief Prepares the model for inference by creating execution contexts and allocating buffers.
    //!
    bool prepare();

    //!
    //! \brief Runs inference using TensorRT on a random image.
    //!
    bool infer();

private:
    bool buildPreprocessorEngine(const SampleUniquePtr<nvinfer1::IBuilder>& builder);
    bool buildPredictionEngine(const SampleUniquePtr<nvinfer1::IBuilder>& builder);

    Dims loadPGMFile(const std::string& fileName);
    bool validateOutput(int digit);

    samplesCommon::OnnxSampleParams mParams; //!< The parameters for the sample.

    nvinfer1::Dims mPredictionInputDims;  //!< The dimensions of the input of the model.
    nvinfer1::Dims mPredictionOutputDims; //!< The dimensions of the output of the model.

    // Engines used for inference. The first is used for resizing inputs, the second for prediction.
    SampleUniquePtr<nvinfer1::ICudaEngine> mPreprocessorEngine{nullptr}, mPredictionEngine{nullptr};

    SampleUniquePtr<nvinfer1::IExecutionContext> mPreprocessorContext{nullptr}, mPredictionContext{nullptr};

    samplesCommon::ManagedBuffer mInput{};          //!< Host and device buffers for the input.
    samplesCommon::DeviceBuffer mPredictionInput{}; //!< Device buffer for the output of the preprocessor, i.e. the
                                                    //!< input to the prediction model.
    samplesCommon::ManagedBuffer mOutput{};         //!< Host and device buffers for the output.

    template <typename T>
    SampleUniquePtr<T> makeUnique(T* t)
    {
        return SampleUniquePtr<T>{t};
    }
};

//!
//! \brief Builds the two engines required for inference.
//!
//! \details This function creates one TensorRT engine for resizing inputs to the correct sizes,
//!          then creates a TensorRT network by parsing the ONNX model and builds
//!          an engine that will be used to run inference (mPredictionEngine).
//!
//! \return Returns false if building the preprocessor or prediction engine fails.
//!
bool Recognition::build()
{
    auto builder = makeUnique(nvinfer1::createInferBuilder(sample::gLogger.getTRTLogger()));
    if (!builder)
    {
        sample::gLogError << "Create inference builder failed." << std::endl;
        return false;
    }

    // This function will also set mPredictionInputDims and mPredictionOutputDims,
    // so it needs to be called before building the preprocessor.
    return buildPredictionEngine(builder) && buildPreprocessorEngine(builder);
}

//!
//! \brief Builds an engine for preprocessing (mPreprocessorEngine).
//!
//! \return Returns false if building the preprocessor engine fails.
//!
bool Recognition::buildPreprocessorEngine(const SampleUniquePtr<nvinfer1::IBuilder>& builder)
{
    // Create the preprocessor engine using a network that supports full dimensions (createNetworkV2).
    auto preprocessorNetwork = makeUnique(
        builder->createNetworkV2(1U << static_cast<int32_t>(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH)));
    if (!preprocessorNetwork)
    {
        sample::gLogError << "Create network failed." << std::endl;
        return false;
    }

    // Reshape a dynamically shaped input to the size expected by the model (mPredictionInputDims).
    auto input = preprocessorNetwork->addInput("input:0", nvinfer1::DataType::kFLOAT, Dims4{-1, 1, -1, -1});
    auto resizeLayer = preprocessorNetwork->addResize(*input);
    resizeLayer->setOutputDimensions(mPredictionInputDims);
    preprocessorNetwork->markOutput(*resizeLayer->getOutput(0));

    // Finally, configure and build the preprocessor engine.
    auto preprocessorConfig = makeUnique(builder->createBuilderConfig());
    if (!preprocessorConfig)
    {
        sample::gLogError << "Create builder config failed." << std::endl;
        return false;
    }

    // Create an optimization profile so that we can specify a range of input dimensions.
    // Here min, opt and max are all set to (1, 24, 94, 3), so the profile effectively fixes the input shape.
    // We do not need to check the return of setDimensions and addOptimizationProfile here as all dims are
    // explicitly set.
    auto profile = builder->createOptimizationProfile();
    profile->setDimensions(input->getName(), OptProfileSelector::kMIN, Dims4{1, 24, 94, 3});
    profile->setDimensions(input->getName(), OptProfileSelector::kOPT, Dims4{1, 24, 94, 3});
    profile->setDimensions(input->getName(), OptProfileSelector::kMAX, Dims4{1, 24, 94, 3});
    preprocessorConfig->addOptimizationProfile(profile);
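
    // Because the preprocessor input is dynamically shaped, the execution context must be given
    // concrete dimensions for this profile before inference, e.g. (as done in Recognition::infer() below):
    //   mPreprocessorContext->setBindingDimensions(0, inputDims);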
    // Create a calibration profile.
    auto profileCalib = builder->createOptimizationProfile();
    const int calibBatchSize{16};
    // We do not need to check the return of setDimensions and setCalibrationProfile here as all dims are
    // explicitly set.
    profileCalib->setDimensions(input->getName(), OptProfileSelector::kMIN, Dims4{calibBatchSize, 24, 94, 3});
    profileCalib->setDimensions(input->getName(), OptProfileSelector::kOPT, Dims4{calibBatchSize, 24, 94, 3});
    profileCalib->setDimensions(input->getName(), OptProfileSelector::kMAX, Dims4{calibBatchSize, 24, 94, 3});
    preprocessorConfig->setCalibrationProfile(profileCalib);

    std::unique_ptr<IInt8Calibrator> calibrator;
    if (mParams.int8)
    {
        preprocessorConfig->setFlag(BuilderFlag::kINT8);
        const int nCalibBatches{10};
        MNISTBatchStream calibrationStream(
            calibBatchSize, nCalibBatches, "train-images-idx3-ubyte", "train-labels-idx1-ubyte", mParams.dataDirs);
        calibrator.reset(
            new Int8EntropyCalibrator2<MNISTBatchStream>(calibrationStream, 0, "MNISTPreprocessor", "input"));
        preprocessorConfig->setInt8Calibrator(calibrator.get());
    }

    mPreprocessorEngine = makeUnique(builder->buildEngineWithConfig(*preprocessorNetwork, *preprocessorConfig));
    if (!mPreprocessorEngine)
    {
        sample::gLogError << "Preprocessor engine build failed." << std::endl;
        return false;
    }

    sample::gLogInfo << "Profile dimensions in preprocessor engine:" << std::endl;
    sample::gLogInfo << "    Minimum = " << mPreprocessorEngine->getProfileDimensions(0, 0, OptProfileSelector::kMIN)
                     << std::endl;
    sample::gLogInfo << "    Optimum = " << mPreprocessorEngine->getProfileDimensions(0, 0, OptProfileSelector::kOPT)
                     << std::endl;
    sample::gLogInfo << "    Maximum = " << mPreprocessorEngine->getProfileDimensions(0, 0, OptProfileSelector::kMAX)
                     << std::endl;
    return true;
}

//!
//! \brief Builds an engine for prediction (mPredictionEngine).
//!
//! \details This function builds an engine for the ONNX model, and updates mPredictionInputDims and
//!          mPredictionOutputDims according to the dimensions specified by the model. The preprocessor reshapes
//!          inputs to mPredictionInputDims.
//!
//! \return Returns false if building the prediction engine fails.
//!
bool Recognition::buildPredictionEngine(const SampleUniquePtr<nvinfer1::IBuilder>& builder)
{
    // Create a network using the parser.
    const auto explicitBatch = 1U << static_cast<int32_t>(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
    auto network = makeUnique(builder->createNetworkV2(explicitBatch));
    if (!network)
    {
        sample::gLogError << "Create network failed." << std::endl;
        return false;
    }

    auto parser = samplesCommon::infer_object(nvonnxparser::createParser(*network, sample::gLogger.getTRTLogger()));
    bool parsingSuccess = parser->parseFromFile(locateFile(mParams.onnxFileName, mParams.dataDirs).c_str(),
        static_cast<int>(sample::gLogger.getReportableSeverity()));
    if (!parsingSuccess)
    {
        sample::gLogError << "Failed to parse model." << std::endl;
        return false;
    }

    // Get information about the inputs/outputs directly from the model.
    mPredictionInputDims = network->getInput(0)->getDimensions();
    mPredictionOutputDims = network->getOutput(0)->getDimensions();

    // Create a builder config.
    auto config = makeUnique(builder->createBuilderConfig());
    if (!config)
    {
        sample::gLogError << "Create builder config failed." << std::endl;
        return false;
    }
    config->setMaxWorkspaceSize(16_MiB);
    if (mParams.fp16)
    {
        config->setFlag(BuilderFlag::kFP16);
    }

    auto profileCalib = builder->createOptimizationProfile();
    const auto inputName = mParams.inputTensorNames[0].c_str();
    const int calibBatchSize{16};
    // We do not need to check the return of setDimensions and setCalibrationProfile here as all dims are
    // explicitly set.
    profileCalib->setDimensions(inputName, OptProfileSelector::kMIN, Dims4{calibBatchSize, 24, 94, 3});
    profileCalib->setDimensions(inputName, OptProfileSelector::kOPT, Dims4{calibBatchSize, 24, 94, 3});
    profileCalib->setDimensions(inputName, OptProfileSelector::kMAX, Dims4{calibBatchSize, 24, 94, 3});
    config->setCalibrationProfile(profileCalib);

    // Create an optimization profile so that we can specify a range of input dimensions.
    auto profile = builder->createOptimizationProfile();
    const int batchSize{1};
    profile->setDimensions(inputName, OptProfileSelector::kMIN, Dims4{batchSize, 24, 94, 3});
    profile->setDimensions(inputName, OptProfileSelector::kOPT, Dims4{batchSize, 24, 94, 3});
    profile->setDimensions(inputName, OptProfileSelector::kMAX, Dims4{batchSize, 24, 94, 3});
    config->addOptimizationProfile(profile);

    std::unique_ptr<IInt8Calibrator> calibrator;
    if (mParams.int8)
    {
        config->setFlag(BuilderFlag::kINT8);
        int nCalibBatches{10};
        MNISTBatchStream calibrationStream(
            calibBatchSize, nCalibBatches, "train-images-idx3-ubyte", "train-labels-idx1-ubyte", mParams.dataDirs);
        calibrator.reset(
            new Int8EntropyCalibrator2<MNISTBatchStream>(calibrationStream, 0, "MNISTPrediction", inputName));
        config->setInt8Calibrator(calibrator.get());
    }

    // Build the prediction engine.
    mPredictionEngine = makeUnique(builder->buildEngineWithConfig(*network, *config));
    if (!mPredictionEngine)
    {
        sample::gLogError << "Prediction engine build failed." << std::endl;
        return false;
    }
    return true;
}

//!
//! \brief Prepares the model for inference by creating an execution context and allocating buffers.
//!
//! \details This function sets up the sample for inference. This involves allocating buffers for the inputs and
//!          outputs, as well as creating TensorRT execution contexts for both engines. This only needs to be called
//!          a single time.
//!
//! \return Returns false if building the preprocessor or prediction context fails.
//!
bool Recognition::prepare()
{
    mPreprocessorContext = makeUnique(mPreprocessorEngine->createExecutionContext());
    if (!mPreprocessorContext)
    {
        sample::gLogError << "Preprocessor context build failed." << std::endl;
        return false;
    }

    mPredictionContext = makeUnique(mPredictionEngine->createExecutionContext());
    if (!mPredictionContext)
    {
        sample::gLogError << "Prediction context build failed." << std::endl;
        return false;
    }

    // Since input dimensions are not known ahead of time, we only allocate the output buffer and preprocessor
    // output buffer.
    mPredictionInput.resize(mPredictionInputDims);
    mOutput.hostBuffer.resize(mPredictionOutputDims);
    mOutput.deviceBuffer.resize(mPredictionOutputDims);
    return true;
}

//!
//! \brief Runs inference for this sample.
//!
//! \details This function is the main execution function of the sample.
//!          It runs inference using a random image from the MNIST dataset as an input.
//!
bool Recognition::infer()
{
    // Load a random PGM file into a host buffer, then copy to device.
    std::random_device rd{};
    std::default_random_engine generator{rd()};
    std::uniform_int_distribution<int> digitDistribution{0, 9};
    int digit = digitDistribution(generator);

    Dims inputDims = loadPGMFile(locateFile(std::to_string(digit) + ".pgm", mParams.dataDirs));
    mInput.deviceBuffer.resize(inputDims);
    CHECK(cudaMemcpy(
        mInput.deviceBuffer.data(), mInput.hostBuffer.data(), mInput.hostBuffer.nbBytes(), cudaMemcpyHostToDevice));

    // Set the input size for the preprocessor.
    CHECK_RETURN_W_MSG(mPreprocessorContext->setBindingDimensions(0, inputDims), false, "Invalid binding dimensions.");

    // We can only run inference once all dynamic input shapes have been specified.
    if (!mPreprocessorContext->allInputDimensionsSpecified())
    {
        return false;
    }

    // Run the preprocessor to resize the input to the correct shape.
    std::vector<void*> preprocessorBindings = {mInput.deviceBuffer.data(), mPredictionInput.data()};
    // For engines using full dims, we can use executeV2, which does not include a separate batch size parameter.
    bool status = mPreprocessorContext->executeV2(preprocessorBindings.data());
    if (!status)
    {
        return false;
    }

    // Next, run the model to generate a prediction.
    std::vector<void*> predictionBindings = {mPredictionInput.data(), mOutput.deviceBuffer.data()};
    status = mPredictionContext->executeV2(predictionBindings.data());
    if (!status)
    {
        return false;
    }

    // Copy the outputs back to the host and verify the output.
    CHECK(cudaMemcpy(mOutput.hostBuffer.data(), mOutput.deviceBuffer.data(), mOutput.deviceBuffer.nbBytes(),
        cudaMemcpyDeviceToHost));
    return validateOutput(digit);
}

//!
//! \brief Loads a PGM file into mInput and returns the dimensions of the loaded image.
//!
//! \details This function loads the specified PGM file into the input host buffer.
//!
Dims Recognition::loadPGMFile(const std::string& fileName)
{
    std::ifstream infile(fileName, std::ifstream::binary);
    assert(infile.is_open() && "Attempting to read from a file that is not open.");

    std::string magic;
    int h, w, max;
    infile >> magic >> h >> w >> max;
    infile.seekg(1, infile.cur);

    Dims4 inputDims{1, 1, h, w};
    size_t vol = samplesCommon::volume(inputDims);
    std::vector<uint8_t> fileData(vol);
    infile.read(reinterpret_cast<char*>(fileData.data()), vol);

    // Print an ASCII representation of the image.
    sample::gLogInfo << "Input:\n";
    for (size_t i = 0; i < vol; i++)
    {
        sample::gLogInfo << (" .:-=+*#%@"[fileData[i] / 26]) << (((i + 1) % w) ? "" : "\n");
    }
    sample::gLogInfo << std::endl;

    // Normalize and copy to the host buffer.
    mInput.hostBuffer.resize(inputDims);
    float* hostDataBuffer = static_cast<float*>(mInput.hostBuffer.data());
    std::transform(fileData.begin(), fileData.end(), hostDataBuffer,
        [](uint8_t x) { return 1.0 - static_cast<float>(x / 255.0); });
    return inputDims;
}

//!
//! \brief Checks whether the model prediction (in mOutput) is correct.
//!
bool Recognition::validateOutput(int digit)
{
    const float* bufRaw = static_cast<const float*>(mOutput.hostBuffer.data());
    std::vector<float> prob(bufRaw, bufRaw + mOutput.hostBuffer.size());

    int curIndex{0};
    for (const auto& elem : prob)
    {
        sample::gLogInfo << " Prob " << curIndex << " " << std::fixed << std::setw(5) << std::setprecision(4) << elem
                         << " Class " << curIndex << ": " << std::string(int(std::floor(elem * 10 + 0.5f)), '*')
                         << std::endl;
        ++curIndex;
    }

    int predictedDigit = std::max_element(prob.begin(), prob.end()) - prob.begin();
    return digit == predictedDigit;
}

//!
//! \brief Initializes members of the params struct using the command line args.
//!
samplesCommon::OnnxSampleParams initializeSampleParams(const samplesCommon::Args& args)
{
    samplesCommon::OnnxSampleParams params;
    if (args.dataDirs.empty()) //!< Use default directories if user hasn't provided directory paths
    {
        params.dataDirs.push_back("data/mnist/");
        params.dataDirs.push_back("data/samples/mnist/");
    }
    else //!< Use the data directory provided by the user
    {
        params.dataDirs = args.dataDirs;
    }
    params.onnxFileName = "numplate_recg_sparsetodenseremoved.onnx";
    params.inputTensorNames.push_back("input:0");
    params.outputTensorNames.push_back("d_predictions:0");
    params.int8 = args.runInInt8;
    params.fp16 = args.runInFp16;
    return params;
}

//!
//! \brief Prints the help information for running this sample.
//!
void printHelpInfo()
{
    std::cout << "Usage: ./sample_dynamic_reshape [-h or --help] [-d or --datadir=<path to data directory>]"
              << std::endl;
    std::cout << "--help, -h      Display help information" << std::endl;
    std::cout << "--datadir       Specify path to a data directory, overriding the default. This option can be used "
                 "multiple times to add multiple directories. If no data directories are given, the default is to use "
                 "(data/samples/mnist/, data/mnist/)"
              << std::endl;
    std::cout << "--int8          Run in Int8 mode." << std::endl;
    std::cout << "--fp16          Run in FP16 mode." << std::endl;
}

// Example debug/run invocations:
// gdb --args ./platerecg_debug --useDLACore=-1 --fp16 --datadir=/usr/src/tensorrt/data/platerect/
//     --saveEngine=/usr/src/tensorrt/samples/NumPlateRecognition --loadEngine=/usr/src/tensorrt/samples/NumPlateRecognition
// ./platerecg --useDLACore=-1 --fp16 --datadir=/usr/src/tensorrt/data/platerect/
//     --saveEngine=/usr/src/tensorrt/samples/NumPlateRecognition --loadEngine=/usr/src/tensorrt/samples/NumPlateRecognition
int main(int argc, char** argv)
{
    samplesCommon::Args args;
    bool argsOK = samplesCommon::parseArgs(args, argc, argv);
    if (!argsOK)
    {
        sample::gLogError << "Invalid arguments" << std::endl;
        printHelpInfo();
        return EXIT_FAILURE;
    }
    if (args.help)
    {
        printHelpInfo();
        return EXIT_SUCCESS;
    }

    auto sampleTest = sample::gLogger.defineTest(gSampleName, argc, argv);
    sample::gLogger.reportTestStart(sampleTest);

    Recognition platerecg{initializeSampleParams(args)};

    if (!platerecg.build())
    {
        return sample::gLogger.reportFail(sampleTest);
    }
    if (!platerecg.prepare())
    {
        return sample::gLogger.reportFail(sampleTest);
    }
    if (!platerecg.infer())
    {
        return sample::gLogger.reportFail(sampleTest);
    }

    return sample::gLogger.reportPass(sampleTest);
}
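
// A minimal usage sketch (an illustration, not exercised by this sample): because
// CTCGreedyDecoderCreator is registered above via REGISTER_TENSORRT_PLUGIN, the plugin
// could also be instantiated manually from the global plugin registry, e.g.
//
//   auto* creator = getPluginRegistry()->getPluginCreator("CTCGreedyDecoder", "1");
//   PluginFieldCollection fc{0, nullptr};
//   IPluginV2* plugin = creator->createPlugin("CTCGreedyDecoder", &fc);
//
// The ONNX parser performs an equivalent lookup when it encounters the custom node
// while parsing numplate_recg_sparsetodenseremoved.onnx.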