diff --git a/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/LightGBMBinaryClassificationModelTrainer.java b/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/LightGBMBinaryClassificationModelTrainer.java index 6879406..4e66d08 100644 --- a/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/LightGBMBinaryClassificationModelTrainer.java +++ b/openml-lightgbm/lightgbm-provider/src/main/java/com/feedzai/openml/provider/lightgbm/LightGBMBinaryClassificationModelTrainer.java @@ -148,33 +148,35 @@ static void fit(final Dataset dataset, final int numIterations = parseInt(params.get(NUM_ITERATIONS_PARAMETER_NAME)); logger.debug("LightGBM model trainParams: {}", trainParams); - final SWIGTrainData swigTrainData = new SWIGTrainData( + try (final SWIGTrainData swigTrainData = new SWIGTrainData( numFeatures, instancesPerChunk, FairGBMParamParserUtil.isFairnessConstrained(params), sampleWeightColIndex.isPresent() - ); - final SWIGTrainBooster swigTrainBooster = new SWIGTrainBooster(); + ); final SWIGTrainBooster swigTrainBooster = new SWIGTrainBooster()) { + /// Create LightGBM dataset + final int constraintGroupColIndex = FairGBMParamParserUtil.getConstraintGroupColumnIndex(params, schema) + .orElse(FairGBMParamParserUtil.NO_SPECIFIC); + + createTrainDataset( + dataset, + numFeatures, + trainParams, + constraintGroupColIndex, + sampleWeightColIndex, + swigTrainData + ); - /// Create LightGBM dataset - final int constraintGroupColIndex = FairGBMParamParserUtil.getConstraintGroupColumnIndex(params, schema).orElse( - FairGBMParamParserUtil.NO_SPECIFIC); - createTrainDataset( - dataset, - numFeatures, - trainParams, - constraintGroupColIndex, - sampleWeightColIndex, - swigTrainData - ); + /// Create Booster from dataset + createBoosterStructure(swigTrainBooster, swigTrainData, trainParams); + trainBooster(swigTrainBooster.swigBoosterHandle, numIterations); - /// Create Booster from dataset - createBoosterStructure(swigTrainBooster, swigTrainData, trainParams); - trainBooster(swigTrainBooster.swigBoosterHandle, numIterations); + /// Save model + saveModelFileToDisk(swigTrainBooster.swigBoosterHandle, outputModelFilePath); - /// Save model - saveModelFileToDisk(swigTrainBooster.swigBoosterHandle, outputModelFilePath); - swigTrainBooster.close(); // Explicitly release C++ resources right away. They're no longer needed. + // Note: By using try-with-resources, the call to both `swigTrainData.close()` + // and `swigTrainBooster.close()` to release C++ resources is guaranteed + } } /**