diff --git a/auglab/configs/transform_params_gpu.json b/auglab/configs/transform_params_gpu.json index 291b652..eda598e 100644 --- a/auglab/configs/transform_params_gpu.json +++ b/auglab/configs/transform_params_gpu.json @@ -43,6 +43,8 @@ "RedistributeSegTransform": { "in_seg": 0.25, "retain_stats": true, + "std_noise_range": [0.1, 0.3], + "dilation_iterations_range": [1, 5], "probability": 0.4 }, "GaussianNoiseTransform": { diff --git a/auglab/transforms/gpu/fromSeg.py b/auglab/transforms/gpu/fromSeg.py index 83b3b90..7a1fc82 100644 --- a/auglab/transforms/gpu/fromSeg.py +++ b/auglab/transforms/gpu/fromSeg.py @@ -29,12 +29,16 @@ def __init__( same_on_batch: bool = False, p: float = 1.0, keepdim: bool = True, + std_noise_range: list[float] = [0.1, 0.3], + dilation_iterations_range: list[int] = [1, 3], **kwargs, ) -> None: super().__init__(p=p, same_on_batch=same_on_batch, keepdim=keepdim) self.in_seg = in_seg self.apply_to_channel = apply_to_channel self.retain_stats = retain_stats + self.std_noise_range = std_noise_range + self.dilation_iterations_range = dilation_iterations_range @torch.no_grad() def apply_transform( @@ -96,7 +100,8 @@ def apply_transform( # Vectorized dilation for all regions (3 iterations) dilated = masks.float() - for _ in range(3): + dilation_iterations = torch.randint(self.dilation_iterations_range[0], self.dilation_iterations_range[1]+1, (1,), device=input.device)[0].item() + for _ in range(dilation_iterations): if spatial_dims == 3: dilated = F.max_pool3d(dilated.unsqueeze(0), 3, 1, 1).squeeze(0) else: @@ -125,8 +130,9 @@ def apply_transform( dil_stds = dil_vars.sqrt() # redist_std per region + std_noise_range = torch.rand(1, device=input.device)[0] * (self.std_noise_range[1] - self.std_noise_range[0]) + self.std_noise_range[0] redist_std = torch.maximum( - torch.rand(R, device=input.device) * 0.2 + 0.4 * torch.abs((means - dil_means) * stds / (dil_stds + 1e-6)), + torch.rand(R, device=input.device) * std_noise_range + 0.4 * torch.abs((means - dil_means) * stds / (dil_stds + 1e-6)), torch.full((R,), 0.01, device=input.device, dtype=input.dtype) ) diff --git a/auglab/transforms/gpu/transforms.py b/auglab/transforms/gpu/transforms.py index f471370..6e61cc6 100644 --- a/auglab/transforms/gpu/transforms.py +++ b/auglab/transforms/gpu/transforms.py @@ -93,6 +93,8 @@ def _build_transforms(self) -> list[nn.Module]: in_seg=redistribute_params.get('in_seg', 0.2), retain_stats=redistribute_params.get('retain_stats', False), p=redistribute_params.get('probability', 0), + std_noise_range=redistribute_params.get('std_noise_range', [0.1, 0.3]), + dilation_iterations_range=redistribute_params.get('dilation_iterations_range', [1, 3]), )) # Scharr filter @@ -395,22 +397,22 @@ def pad_numpy_array(arr, shape): augmentor = AugTransformsGPU(json_path) # Load images and masks tensors - img_path = '/home/GRAMES.POLYMTL.CA/p118739/data_nvme_p118739/data/datasets/data-multi-subject/sub-amu02/anat/sub-amu02_T1w.nii.gz' + img_path = '/home/ge.polymtl.ca/p118739/data/datasets/data-multi-subject/sub-amu02/anat/sub-amu02_T1w.nii.gz' img = Image(img_path).change_orientation('RSP') img = resample_nib(img, new_size=[1,1,1], new_size_type='mm', interpolation='linear') img_tensor = torch.from_numpy(img.data.copy()).to(torch.float32) - seg_path = '/home/GRAMES.POLYMTL.CA/p118739/data_nvme_p118739/data/datasets/data-multi-subject/derivatives/labels/sub-amu02/anat/sub-amu02_T1w_label-spine_dseg.nii.gz' + seg_path = '/home/ge.polymtl.ca/p118739/data/datasets/data-multi-subject/derivatives/labels/sub-amu02/anat/sub-amu02_T1w_label-spine_dseg.nii.gz' seg = Image(seg_path).change_orientation('RSP') seg = resample_nib(seg, new_size=[1,1,1], new_size_type='mm', interpolation='nn') seg_tensor_all = torch.from_numpy(seg.data.copy()) - img2_path = '/home/GRAMES.POLYMTL.CA/p118739/data_nvme_p118739/data/datasets/spider-challenge-2023/sub-002/anat/sub-002_acq-lowresSag_T2w.nii.gz' + img2_path = '/home/ge.polymtl.ca/p118739/data/datasets/spider-challenge-2023/sub-002/anat/sub-002_acq-lowresSag_T2w.nii.gz' img2 = Image(img2_path).change_orientation('RSP') img2 = resample_nib(img2, new_size=[1,1,1], new_size_type='mm', interpolation='linear') img2_tensor = torch.from_numpy(img2.data.copy()).to(torch.float32) - seg2_path = '/home/GRAMES.POLYMTL.CA/p118739/data_nvme_p118739/data/datasets/spider-challenge-2023/derivatives/labels/sub-002/anat/sub-002_acq-lowresSag_T2w_label-spine_dseg.nii.gz' + seg2_path = '/home/ge.polymtl.ca/p118739/data/datasets/spider-challenge-2023/derivatives/labels/sub-002/anat/sub-002_acq-lowresSag_T2w_label-spine_dseg.nii.gz' seg2 = Image(seg2_path).change_orientation('RSP') seg2 = resample_nib(seg2, new_size=[1,1,1], new_size_type='mm', interpolation='nn') seg2_tensor_all = torch.from_numpy(seg2.data.copy())