Image Dataset Example

from xaiunits.datagenerator.image_generation import BalancedImageDataset, ImbalancedImageDataset

# Balanced dataset: every foreground shape is generated n_variants times.
data = BalancedImageDataset(
    # Dataset size / reproducibility
    seed=0,
    backgrounds=5,
    shapes=10,
    n_variants=4,
    # Canvas and placement of the foreground shape
    background_size=(300, 300),
    position="random",
    overlay_scale=0.3,
    rotation=True,
    # Foreground shape appearance
    shape_type="geometric",
    shape_colors=["orange"],
    # Width of the border included around the shape in the ground-truth mask
    contour_thickness=20,
)

def jupyter_display(img_tensor):
    """Show an image tensor inline in a Jupyter notebook.

    The tensor is turned into a PIL image via torchvision's ``ToPILImage``,
    forced into RGB mode, and rendered with IPython's ``display``.

    Args:
        img_tensor: image tensor to render (presumably channel-first,
            e.g. (3, H, W) as printed for dataset samples — TODO confirm
            for other inputs such as masks).
    """
    from torchvision.transforms import ToPILImage
    from IPython.display import display

    to_pil = ToPILImage()
    rgb_image = to_pil(img_tensor).convert("RGB")
    display(rgb_image)

# Examine a single datapoint: each item is an (image, label, context) triple.
x, y_label, context = data[0]

# x is an image tensor; show its shape and render it.
print(f"x.shape: {x.shape}")
jupyter_display(x)

# y_label is an integer label representing the foreground shape.
print(f"y_label: {y_label}")

# context is a dict; among its keys is "ground_truth_attribute" (the mask).
print(f"context: {list(context)}")
gt_mask = context["ground_truth_attribute"]
jupyter_display(gt_mask)
x.shape: torch.Size([3, 300, 300])
_images/7462a101da1fbb16b4f9812c7134c9178d67a8fdb6cc30837654a83a5c456a62.png
y_label: 3
context: ['fg_shape', 'bg_label', 'fg_color', 'ground_truth_attribute']
_images/b1c796556abbe868a25e4bc9efd500457a0e5d95fd4286b7e73fbcdf6059a193.png
# The ground-truth mask covers the foreground shape plus a thin border around it,
# to account for CNNs responding to the shape's edges.
# The border width is controlled by the contour_thickness argument given when
# the dataset is created.
masked_image = context["ground_truth_attribute"] * x
jupyter_display(masked_image)
_images/4bb67a41b675236d007b19107e20c97a3e4e300783843d4d0d4af4d18fae17de.png
# The imbalanced dataset has an extra parameter, `imbalance`.
# This dataset also requires shapes >= 2 and backgrounds >= shapes.
# n_variants is the number of images per foreground shape; it is kept in one
# variable so the display loop below stays in sync if it is changed.
n_variants = 4

imbalanced_data = ImbalancedImageDataset(
    backgrounds=3,
    shapes=2,
    n_variants=n_variants,
    shape_colors="yellow",
    background_size=(300, 300),
    imbalance=0.5,
    shuffled=False,
)

# Here we have selected an imbalance of 0.5 and n_variants=4.
# So for any given foreground shape, 2 images will be associated with one background,
# and for that foreground shape, the remaining 2 images will be spread across the
# other backgrounds.
# With shuffled=False, the first n_variants items all belong to the first shape.
print("First foreground shape")
for idx in range(n_variants):
    jupyter_display(imbalanced_data[idx][0])
First foreground shape
_images/a63a2219c72b57f3440a82e0c4333fde70bf5b9a8edea02fc7c88f2252625f0d.png _images/88e113eee1750760c682fbbc08dbb74050c2ac25fa59e6647747b713096fe314.png _images/370de47c03573f1a4af51bed4edfa44a09711c3df284accc5c81a90bc1ca302d.png _images/b0af0edf6bfe7d2897c1a07d426d15d3e11de8af4cf080565a1234b5901d2dbe.png
# Items 4-7 hold the second (and, since shapes=2, last) foreground shape,
# because each shape contributes n_variants=4 consecutive, unshuffled items.
print("Second foreground shape")
for idx in range(4, 8):
    jupyter_display(imbalanced_data[idx][0])
Second foreground shape
_images/3e9e1c9e882da116fa80096623d0e21e4e41679f797a32dea2904b9a07235c44.png _images/d78aad383ecae87f845e716de9303df38a6143f682b9f79a2945ca387d5c85a5.png _images/876092a7de8400e91ad285a24c8484eb1b42621f99627d3fd18cf7f700c44685.png _images/4120b5dd205e6df0f6e368e3d34be8c56a7532a844196a9ec31a8b84c4cb5922.png