*My post explains Caltech 256.
`Caltech256()` can use the Caltech 256 dataset as shown below:
*Memos:
- The 1st argument is `root`(Required-Type:`str` or `pathlib.Path`). *An absolute or relative path is possible.
- The 2nd argument is `transform`(Optional-Default:`None`-Type:`callable`).
- The 3rd argument is `target_transform`(Optional-Default:`None`-Type:`callable`).
- The 4th argument is `download`(Optional-Default:`False`-Type:`bool`):
  *Memos:
  - If it's `True`, the dataset is downloaded from the internet and extracted to the `caltech256` folder under `root` (e.g. `data/caltech256/`).
  - If it's `True` and the dataset is already downloaded, it's extracted.
  - If it's `True` and the dataset is already downloaded and extracted, nothing happens.
  - It should be `False` if the dataset is already downloaded and extracted, because it's faster.
  - You can manually download the dataset(`256_ObjectCategories.tar`) from here and extract it to `data/caltech256/`.
- For the categories of the image indices, ak47(0) is 0~97, american-flag(1) is 98~194, backpack(2) is 195~345, baseball-bat(3) is 346~472, baseball-glove(4) is 473~620, basketball-hoop(5) is 621~710, bat(6) is 711~816, bathtub(7) is 817~1048, bear(8) is 1049~1150, beer-mug(9) is 1151~1244, etc.
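Because the index ranges above are contiguous, the label of an image index can be found by searching for the range start it falls after. A minimal sketch, assuming only the first ten categories listed above (`category_of`, `CATEGORY_STARTS` and `CATEGORY_NAMES` are hypothetical names, not part of torchvision):

```python
from bisect import bisect_right

# First image index of each category, taken from the ranges listed above.
# Only the first 10 categories (ak47 ~ beer-mug) are included in this sketch.
CATEGORY_STARTS = [0, 98, 195, 346, 473, 621, 711, 817, 1049, 1151]
CATEGORY_NAMES = [
    "ak47", "american-flag", "backpack", "baseball-bat", "baseball-glove",
    "basketball-hoop", "bat", "bathtub", "bear", "beer-mug",
]
LAST_KNOWN_INDEX = 1244  # beer-mug(9) ends at 1244


def category_of(index: int) -> tuple[int, str]:
    """Return (label, name) for an index inside the first 10 categories."""
    if not 0 <= index <= LAST_KNOWN_INDEX:
        raise ValueError(f"index {index} is outside the ranges listed above")
    # bisect_right counts how many range starts are <= index,
    # so subtracting 1 gives the range the index belongs to.
    label = bisect_right(CATEGORY_STARTS, index) - 1
    return label, CATEGORY_NAMES[label]


print(category_of(0))     # (0, 'ak47')
print(category_of(98))    # (1, 'american-flag')
print(category_of(1244))  # (9, 'beer-mug')
```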
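```python
from torchvision.datasets import Caltech256

# Only root is required. download=False by default, so the dataset
# must already be extracted under data/caltech256/.
my_data = Caltech256(
    root="data"
)

# The same call with every argument written out at its default value.
my_data = Caltech256(
    root="data",
    transform=None,
    target_transform=None,
    download=False
)
```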
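The `download` memos describe four cases, which can be read as a small decision table. A sketch under that reading (`download_action` and its result strings are hypothetical; torchvision exposes no such helper). The "error" case is an assumption based on Caltech256 checking for the extracted folder and, as far as I know, raising a `RuntimeError` ("Dataset not found or corrupted...") when it is missing and `download=False`:

```python
def download_action(download: bool, archive_present: bool, extracted: bool) -> str:
    """Hypothetical summary of what Caltech256(download=...) does on startup."""
    if extracted:
        # Already downloaded and extracted: nothing happens, whatever download is.
        return "nothing"
    if not download:
        # Nothing usable on disk and downloading is disabled (assumed to raise).
        return "error"
    if archive_present:
        # The .tar is already there, so it is only extracted.
        return "extract"
    return "download and extract"


for case in [(True, False, False), (True, True, False), (True, True, True), (False, False, False)]:
    print(case, "->", download_action(*case))
```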
```python
len(my_data)
# 30607

my_data
# Dataset Caltech256
#     Number of datapoints: 30607
#     Root location: data\caltech256

my_data.root
# 'data/caltech256'

print(my_data.transform)
# None

print(my_data.target_transform)
# None

# The download argument is not stored, so this is the download() method.
my_data.download
# <bound method Caltech256.download of Dataset Caltech256
#     Number of datapoints: 30607
#     Root location: data\caltech256>
```
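The outputs show `root="data"` becoming `data/caltech256`, while the repr prints a backslash, which is what `os.path.join` produces on Windows. A sketch of that path composition, assuming Caltech256 joins a `caltech256` subfolder onto `root` (`dataset_dir` is a hypothetical helper; `posixpath` and `ntpath` are used so both separators can be shown on any OS):

```python
import ntpath
import posixpath


def dataset_dir(root: str, windows: bool = False) -> str:
    """Hypothetical: the folder the dataset lives in for a given root."""
    path_module = ntpath if windows else posixpath
    return path_module.join(root, "caltech256")


print(dataset_dir("data"))                # data/caltech256
print(dataset_dir("data", windows=True))  # data\caltech256
```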
```python
len(my_data.categories)
# 257

my_data.categories
# ['001.ak47', '002.american-flag', '003.backpack', '004.baseball-bat',
#  '005.baseball-glove', '006.basketball-hoop', '007.bat', '008.bathtub',
#  '009.bear', '010.beer-mug', '011.billiards', '012.binoculars',
#  ...
#  '254.greyhound', '255.tennis-shoes', '256.toad', '257.clutter']
```
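Each category name carries a 1-based number prefix, while labels are 0-based (e.g. `my_data[98]` has label 1, which is `'002.american-flag'`). A minimal sketch of splitting a folder name into a label and a plain name, assuming that offset of one holds for every category (`parse_category` is a hypothetical helper):

```python
def parse_category(folder_name: str) -> tuple[int, str]:
    """Split a folder name such as '002.american-flag' into (label, name)."""
    prefix, name = folder_name.split(".", 1)  # split only at the first dot
    return int(prefix) - 1, name              # 1-based prefix -> 0-based label


print(parse_category("001.ak47"))     # (0, 'ak47')
print(parse_category("257.clutter"))  # (256, 'clutter')
```

With the real dataset, `my_data.categories[lab]` already gives the folder name for a label `lab`.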
```python
# Each item is an (image, label) tuple; the label is the 0-based category index.
my_data[0]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=499x278>, 0)

my_data[1]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=268x218>, 0)

my_data[2]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=300x186>, 0)

my_data[98]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=500x328>, 1)

my_data[195]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=375x500>, 2)
```
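When `transform` or `target_transform` is set, torchvision datasets apply it to the image or the label inside indexing, before the tuple is returned. The class below is a hypothetical, dependency-free stand-in (not torchvision code) that shows this pattern with strings in place of PIL images; with Caltech256 you would pass something like a torchvision transform instead:

```python
from typing import Any, Callable, Optional


class TinyLabeledData:
    """Hypothetical stand-in for a torchvision-style (image, label) dataset."""

    def __init__(self, samples, transform: Optional[Callable] = None,
                 target_transform: Optional[Callable] = None):
        self.samples = samples  # list of (image, label) pairs
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, index: int) -> tuple[Any, Any]:
        image, label = self.samples[index]
        # Each callable is applied only when it is not None.
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label


data = TinyLabeledData(
    samples=[("img-a", 0), ("img-b", 1)],
    transform=str.upper,                          # stand-in for an image transform
    target_transform=lambda lab: f"class-{lab}",  # e.g. turn a label into text
)
print(data[1])  # ('IMG-B', 'class-1')
```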
```python
import matplotlib.pyplot as plt

def show_images(data, main_title=None):
    plt.figure(figsize=(10, 5))
    plt.suptitle(t=main_title, y=1.0, fontsize=14)
    # 0, 1 and 2 are ak47 images; the rest are the first image of categories 1~9.
    ims = (0, 1, 2, 98, 195, 346, 473, 621, 711, 817, 1049, 1151)
    for i, j in enumerate(ims, start=1):
        plt.subplot(2, 5, i)
        im, lab = data[j]  # (PIL image, label)
        plt.title(label=lab)
        plt.imshow(X=im)
        if i == 10:  # the 2x5 grid holds 10 images, so stop after the 10th
            break
    plt.tight_layout()
    plt.show()

show_images(data=my_data, main_title="my_data")
```
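Apart from the extra ak47 images 1 and 2, the indices in `ims` are the first image of each category. A sketch of deriving those start indices from per-category image counts instead of hardcoding them (the counts are computed from the ranges listed earlier; `COUNTS` and `first_indices` are hypothetical names):

```python
from itertools import accumulate

# Images per category for ak47 ~ beer-mug, from the ranges listed earlier
# (ak47: 0~97 -> 98 images, american-flag: 98~194 -> 97 images, ...).
COUNTS = [98, 97, 151, 127, 148, 90, 106, 232, 102, 94]

# The first image of category k is the sum of the counts before it;
# the last running total (1245) is one past beer-mug, so it is dropped.
first_indices = tuple(accumulate(COUNTS, initial=0))[:-1]
print(first_indices)
# (0, 98, 195, 346, 473, 621, 711, 817, 1049, 1151)
```

`(0, 1, 2) + first_indices[1:]` rebuilds the same `ims` tuple used in `show_images()`.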