使用Dataloader需要构建Dataset,dataset需要满足以下基本要求
class MyDataset(Dataset): def __init__(self,): # 初始化:加载所有样本路径、标签、预处理等 pass def __len__(self): # 返回数据集大小:Dataset 必须支持 len(dataset) return N def __getitem__(self, idx): # 根据索引返回一个样本,idx ∈ [0, N) # 通常返回 (data, label);也可以返回 dict、tuple 等任意容器 return sample
- h5py多线程读取 并不能真正并行 ,HDF5 库内部对所有 API 调用加了一把全局互斥锁,对同一个文件的读写实际上仍会被串行化,无法获得并行加速。。因此应使用多进程读取要注意的是不要在父进程打开文件后再 fork。也就是在__init__阶段读取基本信息后关闭文件(使用with操作),等init结束后( Linux 下是通过 fork 复制主进程,此阶段结束后)再重新打开文件,否则所有子进程都会继承同一个文件描述符(FD)和内部的 HDF5 状态。这种跨进程共享同一 FD,在底层 HDF5 C 库看来就像是多个线程同时操作同一个连接,会触发其全局锁。
为了避免反复打开关闭h5文件,Dataloader场景下,会在__getitem__中打开一次然后复用
self._h5_file = h5py.File(self.h5_path, 'r')
因此需要回收资源
def __del__(self): # 确保关闭文件句柄,避免资源泄露 try: if self._h5_file is not None: self._h5_file.close() except Exception: pass
可以定义多个Dataset,然后通过ConcatDataset连接起来
from torch.utils.data import ConcatDataset dataset = ConcatDataset([H5Dataset(p) for p in h5_paths])
示例代码
class MultiH5Dataset(torch.utils.data.Dataset): def __init__(self, h5_paths): self.paths = h5_paths # 预先读取每个文件的样本数,构建全局索引映射 self.cum_counts = [] total = 0 for p in h5_paths: with h5py.File(p,'r') as f: n = len(f['data']) total += n self.cum_counts.append(total) def __len__(self): return self.cum_counts[-1] def __getitem__(self, idx): # 找到 idx 属于哪一个文件 file_idx = bisect.bisect_right(self.cum_counts, idx) # 计算在该文件内的局部索引 prev = self.cum_counts[file_idx-1] if file_idx>0 else 0 local_idx = idx - prev with h5py.File(self.paths[file_idx],'r') as f: data = f['data'][local_idx] label = f['label'][local_idx] return data, label
评论 (0)