DEV Community

Super Kai (Kazuya Ito)


Caltech256 in PyTorch


*My post explains Caltech 256.

Caltech256() can load the Caltech 256 dataset as shown below:

*Memos:

  • The 1st argument is root(Required-Type:str or pathlib.Path). *An absolute or relative path is possible.
  • The 2nd argument is transform(Optional-Default:None-Type:callable).
  • The 3rd argument is target_transform(Optional-Default:None-Type:callable).
  • The 4th argument is download(Optional-Default:False-Type:bool): *Memos:
    • If it's True, the dataset is downloaded from the internet and extracted(unzipped) to root.
    • If it's True and the dataset is already downloaded, it's extracted.
    • If it's True and the dataset is already downloaded and extracted, nothing happens.
    • It should be False if the dataset is already downloaded and extracted because it's faster.
    • You can manually download and extract the dataset(256_ObjectCategories.tar) from here to data/caltech256/.
  • About the image indices of each category (the label is in parentheses): ak47(0) is 0~97, american-flag(1) is 98~194, backpack(2) is 195~345, baseball-bat(3) is 346~472, baseball-glove(4) is 473~620, basketball-hoop(5) is 621~710, bat(6) is 711~816, bathtub(7) is 817~1048, bear(8) is 1049~1150, beer-mug(9) is 1151~1244, etc.
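The index ranges in the last memo can be recomputed from a list of labels. Below is a minimal sketch, assuming the labels of each category are stored contiguously, as the memo implies. The helper name `category_ranges` and the synthetic `demo_labels` are hypothetical and only stand in for the real label list:

```python
from itertools import groupby


def category_ranges(labels):
    """Map each label to the (first, last) index where it occurs.

    Assumes equal labels are stored next to each other.
    """
    ranges = {}
    position = 0
    for label, run in groupby(labels):
        count = sum(1 for _ in run)
        ranges[label] = (position, position + count - 1)
        position += count
    return ranges


# Synthetic labels using the first three category sizes from the memo
# (98, 97 and 151 images). This is illustration data, not the dataset.
demo_labels = [0] * 98 + [1] * 97 + [2] * 151
demo_ranges = category_ranges(demo_labels)
print(demo_ranges)
# {0: (0, 97), 1: (98, 194), 2: (195, 345)}
```

With the dataset on disk, `category_ranges(my_data.y)` should give the same ranges as the memo, assuming your torchvision version keeps the per-image labels in the `y` attribute.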
from torchvision.datasets import Caltech256

my_data = Caltech256(
    root="data"
)

my_data = Caltech256(
    root="data",
    transform=None,
    target_transform=None,
    download=False
)

len(my_data)
# 30607

my_data
# Dataset Caltech256
#    Number of datapoints: 30607
#    Root location: data\caltech256

my_data.root
# 'data/caltech256'

print(my_data.transform)
# None

print(my_data.target_transform)
# None

my_data.download
# <bound method Caltech256.download of Dataset Caltech256
#     Number of datapoints: 30607
#     Root location: data\caltech256>

len(my_data.categories)
# 257

my_data.categories
# ['001.ak47', '002.american-flag', '003.backpack', '004.baseball-bat',
#  '005.baseball-glove', '006.basketball-hoop', '007.bat', '008.bathtub',
#  '009.bear', '010.beer-mug', '011.billiards', '012.binoculars',
#  ...
#  '254.greyhound', '255.tennis-shoes', '256.toad', '257.clutter']

my_data[0]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=499x278>, 0)

my_data[1]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=268x218>, 0)

my_data[2]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=300x186>, 0)

my_data[98]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=500x328>, 1)

my_data[195]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=375x500>, 2)

import matplotlib.pyplot as plt

def show_images(data, main_title=None):
    plt.figure(figsize=(10, 5))
    plt.suptitle(t=main_title, y=1.0, fontsize=14)
    # Three ak47 images, then the first image of categories 1 to 7:
    # exactly 10 indices to fill the 2x5 grid.
    ims = (0, 1, 2, 98, 195, 346, 473, 621, 711, 817)
    for i, j in enumerate(ims, start=1):
        plt.subplot(2, 5, i)
        im, lab = data[j]
        plt.title(label=lab)
        plt.imshow(X=im)
    plt.tight_layout()
    plt.show()

show_images(data=my_data, main_title="my_data")
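*Memos:

The labels above are plain integers, so the plot titles show 0, 1, 2, etc. A callable given as target_transform could turn them into readable names. Below is a minimal sketch of such a callable. The helper names `category_name` and `make_label_to_name` are hypothetical, and `sample_categories` is a short stand-in for `my_data.categories`:

```python
def category_name(category):
    """Strip the numeric prefix, e.g. '001.ak47' -> 'ak47'."""
    return category.split(".", 1)[1]


def make_label_to_name(categories):
    """Return a callable that maps an integer label to a readable name."""
    names = [category_name(c) for c in categories]

    def label_to_name(label):
        return names[label]

    return label_to_name


# Stand-in for my_data.categories (first three entries only).
sample_categories = ["001.ak47", "002.american-flag", "003.backpack"]
label_to_name = make_label_to_name(sample_categories)
print(label_to_name(0))
# ak47
print(label_to_name(2))
# backpack
```

Because the category list is only known once the dataset exists, one option is to assign `my_data.target_transform = make_label_to_name(my_data.categories)` after construction. This assumes target_transform is read each time an item is accessed, which appears to be how Caltech256 applies it.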

