How to ignore invalid samples in a PyTorch dataset?

2020-02-09 21:50

I implemented a dataset class for my image samples, but it cannot handle a corrupted image. cv2.imread() returns None when a file cannot be read, and it is not obvious what __getitem__ should do in that case:

import cv2
import torch.utils.data as data

class MyDataset(data.Dataset):
  ...
  def __getitem__(self, index):
    image = cv2.imread(image_list[index])
    if image is None:
      # What should we do?
...

The PyTorch forum suggests a solution: return None from __getitem__ for a bad sample, then filter out the None entries in a custom collate_fn. So I changed my code:

class MyDataset(data.Dataset):
  ...
  def __getitem__(self, index):
    image = cv2.imread(image_list[index])
    if image is None:
      return None
    # Other preprocessing
    ...

def my_collate(batch):
    batch = filter(lambda img: img is not None, batch)
    return data.dataloader.default_collate(batch)

dataset = MyDataset()
loader = data.DataLoader(dataset, collate_fn=my_collate)

But running it raises an error:

Loading data exception: Traceback (most recent call last):
  File "/opt/conda/lib/python3.6/site-packages/torch/utils/data/_utils/worker.py", line 108, in _worker_loop
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "train.py", line 197, in my_collate
    return data.dataloader.default_collate(batch)
  File "/opt/conda/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 34, in default_collate
    elem_type = type(batch[0])
TypeError: 'filter' object is not subscriptable
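
The failure does not depend on PyTorch. A minimal sketch in plain Python 3, using a made-up batch, reproduces it and shows that list() makes the result indexable:

```python
# Hypothetical batch: two valid samples and one failed read (None).
batch = [1, None, 3]

# In Python 3, filter() returns a lazy iterator, not a list.
filtered = filter(lambda item: item is not None, batch)

try:
    filtered[0]  # the same kind of indexing default_collate does with batch[0]
except TypeError as err:
    error_message = str(err)

# Materializing the iterator gives an ordinary, indexable list.
valid = list(filtered)
first = valid[0]
```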

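default_collate() indexes batch[0], but in Python 3 filter() returns a lazy iterator rather than a list, so the 'filter' object is not subscriptable. The fix is small: wrap the filtered batch in list() before passing it on: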

def my_collate(batch):
  ...
  return data.dataloader.default_collate(list(batch))
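
One case the code above does not cover is a batch in which every sample is None. As the traceback shows, default_collate() indexes batch[0], so an empty list would fail there too. The following is only a sketch of one way to deal with that, not something taken from the forum thread. ToyDataset and skip_none_collate are hypothetical names, the failing indices are made up, and returning None for an empty batch is an assumption that the training loop must then check for:

```python
import torch
import torch.utils.data as data


class ToyDataset(data.Dataset):
    """Hypothetical stand-in for MyDataset; indices 1, 2 and 3 simulate failed reads."""

    def __len__(self):
        return 8

    def __getitem__(self, index):
        if index in (1, 2, 3):
            return None  # plays the role of cv2.imread() returning None
        return torch.full((3,), float(index))


def skip_none_collate(batch):
    # Drop failed samples; a list comprehension avoids the lazy filter object.
    batch = [sample for sample in batch if sample is not None]
    if not batch:
        return None  # assumption: the caller skips batches that are None
    return data.dataloader.default_collate(batch)


loader = data.DataLoader(ToyDataset(), batch_size=2, collate_fn=skip_none_collate)

shapes = []
for images in loader:
    if images is None:
        continue  # every sample in this batch was invalid
    shapes.append(tuple(images.shape))
```

Note that batches can end up smaller than batch_size, so any code that assumes a fixed batch size needs to read the size from the tensor instead.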


