DIGITS MNIST Example Fails on Multiple GPUs

Hey everyone, I just got DIGITS set up in our Kubernetes cluster and was playing around with training a model on the MNIST dataset, specifically the LeNet model. Everything works fine on a single GPU. However, we have an NVIDIA DGX system with 8 Tesla V100 GPUs, and when I try to train the model on anything more than 1 GPU I get errors. Training reaches 100% and then fails. From the logs, I think the issue is related to this error:

(0) Invalid argument: Number of ways to split should evenly divide the split dimension, but got split_dim 0 (size = 3) and num_split 2
[[{{node train/parallelize/split_batch}}]]
[[train/parallelize/split_batch/_47]]
(1) Invalid argument: Number of ways to split should evenly divide the split dimension, but got split_dim 0 (size = 3) and num_split 2
[[{{node train/parallelize/split_batch}}]]
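If I'm reading the error right, the final batch of an epoch ends up with only 3 samples, and the split across 2 GPUs fails because 3 isn't divisible by 2. Here's a minimal sketch of the arithmetic I think is going on (the function and the dataset/batch numbers are hypothetical, just to reproduce a trailing batch of size 3 — not actual DIGITS code):

```python
def check_multi_gpu_batches(dataset_size, batch_size, num_gpus):
    """Return the size of the final batch and whether every batch
    splits evenly across the GPUs (which the batch-split op requires)."""
    # Size of the last, possibly partial, batch in an epoch.
    last_batch = dataset_size % batch_size or batch_size
    # Both full batches and the trailing batch must divide evenly.
    ok = (batch_size % num_gpus == 0) and (last_batch % num_gpus == 0)
    return last_batch, ok

# Hypothetical numbers: a trailing batch of 3 samples cannot be
# split across 2 GPUs, matching the "size = 3, num_split 2" error.
print(check_multi_gpu_batches(dataset_size=6403, batch_size=64, num_gpus=2))
```

So presumably either the last partial batch needs to be dropped, or the batch size chosen so every batch (including the last) divides evenly by the GPU count.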

If anyone could provide some insight it would be greatly appreciated. Thanks in advance