通过H5文件构造Dataset满足Dataloader多进程读取需求

通过H5文件构造Dataset满足Dataloader多进程读取需求

连祈
2025-06-30 / 0 评论 / 4 阅读 / 正在检测是否收录...
  1. 使用Dataloader需要构建Dataset,dataset需要满足以下基本要求

    
    class MyDataset(Dataset):
     def __init__(self,):
       # 初始化:加载所有样本路径、标签、预处理等
       pass
    
     def __len__(self):
       # 返回数据集大小:Dataset 必须支持 len(dataset)
       return N
    
     def __getitem__(self, idx):
       # 根据索引返回一个样本,idx ∈ [0, N)
       # 通常返回 (data, label);也可以返回 dict、tuple 等任意容器
       return sample

  1. h5py多线程读取 并不能真正并行 ,HDF5 库内部对所有 API 调用加了一把全局互斥锁,对同一个文件的读写实际上仍会被串行化,无法获得并行加速。。因此应使用多进程读取要注意的是不要在父进程打开文件后再 fork。也就是在__init__阶段读取基本信息后关闭文件(使用with操作),等init结束后( Linux 下是通过 fork 复制主进程,此阶段结束后)再重新打开文件,否则所有子进程都会继承同一个文件描述符(FD)和内部的 HDF5 状态。这种跨进程共享同一 FD,在底层 HDF5 C 库看来就像是多个线程同时操作同一个连接,会触发其全局锁。
  2. 为了避免反复打开关闭h5文件,Dataloader场景下,会在__getitem__中打开一次然后复用

    self._h5_file = h5py.File(self.h5_path, 'r')

    因此需要回收资源

    def __del__(self):
      # 确保关闭文件句柄,避免资源泄露
      try:
          if self._h5_file is not None:
              self._h5_file.close()
      except Exception:
          pass
  3. 可以定义多个Dataset,然后通过ConcatDataset连接起来

    from torch.utils.data import ConcatDataset
    dataset = ConcatDataset([H5Dataset(p) for p in h5_paths])


  • 示例代码

    class MultiH5Dataset(torch.utils.data.Dataset):
      def __init__(self, h5_paths):
          self.paths = h5_paths
          # 预先读取每个文件的样本数,构建全局索引映射
          self.cum_counts = []
          total = 0
          for p in h5_paths:
              with h5py.File(p,'r') as f:
                  n = len(f['data'])
              total += n
              self.cum_counts.append(total)
    
      def __len__(self):
          return self.cum_counts[-1]
    
      def __getitem__(self, idx):
          # 找到 idx 属于哪一个文件
          file_idx = bisect.bisect_right(self.cum_counts, idx)
          # 计算在该文件内的局部索引
          prev = self.cum_counts[file_idx-1] if file_idx>0 else 0
          local_idx = idx - prev
          with h5py.File(self.paths[file_idx],'r') as f:
              data = f['data'][local_idx]
              label = f['label'][local_idx]
          return data, label
    
0

评论 (0)

取消