在大规模分布式异步训练中,您可以使用WorkQueue进行弹性数据切分,以缓解长尾效应,从而降低模型训练所需的时间。本文介绍WorkQueue的调用格式、参数及其提供的方法。同时,以文件数据源和MaxCompute表数据源为例,介绍实现数据切分的经典示例。
背景信息
在大规模分布式异步训练中,如果每个Worker读取相同数量的样本,则慢节点的训练时长会远大于其他节点,造成长尾效应。并且随着训练规模扩大,长尾效应会越来越严重,导致训练的整体数据吞吐降低,进而增加训练时间。
为解决该问题,PAI提供了pai.data.WorkQueue
类,支持对多种数据源进行弹性数据切分,让慢节点获取较少的训练数据,快节点获取更多的训练数据,以缓解长尾效应,从而降低模型训练所需的时间。
版本配套关系
- Python版本:Python 2.7
- PAI-TensorFlow版本:PAI-TensorFlow 1.12
pai.data.WorkQueue
- 功能
工作项队列类,用于统一管理所有Worker上的工作项。每个Worker的当前剩余工作项被消费完后,会从同一个WorkQueue获得新的工作项,并将其作为数据源进行训练,从而使得训练快的Worker获得更多的工作项进行训练,以减少长尾效应。
- 格式
class pai.data.WorkQueue(works, num_epochs=1, shuffle=True, seed=None, prefix=None, num_slices=None, name='work_queue')
- 参数
参数名 描述 类型 是否必选 默认值 works 文件名或表名列表。 LIST of STRING 是 无 num_epochs 读取全部数据的次数。 INT 否 1 shuffle 是否每个Epoch都随机重洗数据,取值如下: - True:每个Epoch都随机重洗数据。
- False:不进行数据重洗。
BOOL 否 True seed 重洗数据的随机种子。取值为None时,表示系统自动选取随机种子。 INT 否 None prefix 工作项(文件名或表名)的前缀。取值为None时,表示无前缀。 STRING 否 None num_slices 工作项的总数量。集群越不稳定,需要将工作项总数量配置的越大,通常为Worker数量的10倍以上。取值为None时,表示不分片。 INT 否 None num_clients 工作队列支持的最大工作抢占并发数。 INT 否 1 name 工作队列的名称。 STRING 否 work_queue - 返回值
返回WorkQueue对象,您可以使用该对象调用
pai.data.WorkQueue
类提供的方法。
pai.data.WorkQueue提供的方法
pai.data.WorkQueue
类提供以下方法:
take
- 功能
从全局工作队列获取一个工作项,并下载至本地。
- 格式
WorkQueue.take()
- 参数
无
- 返回值
返回值类型为
tensorflow.Tensor
。
- 功能
input_dataset
- 功能
返回一个Dataset,其每个元素为一个工作项。
- 格式
WorkQueue.input_dataset()
- 参数
无
- 返回值
返回值类型为
tensorflow.data.Dataset
。
- 功能
input_producer
- 功能
返回全局工作队列在本地的代理队列,为Reader类Op使用。
- 格式
WorkQueue.input_producer()
- 参数
无
- 返回值
返回值类型为
tensorflow.FIFOQueue
。
- 功能
add_summary
- 功能
在Tensorboard中显示WorkQueue的资源水位信息。
- 格式
WorkQueue.add_summary()
- 参数
无
- 返回值
无
- 功能
典型示例
pai.data.WorkQueue
类支持对多种数据源进行弹性数据切分,以下分别以文件数据源和MaxCompute表数据源为例,介绍如何使用pai.data.WorkQueue
类实现弹性数据切分(仅提供核心代码片段):
- 文件数据源
import pai # ... # path1、path2及path3表示需要读取的文件列表。 # shuffle取值为True,表示每个Epoch都随机化打散文件路径。 work_queue = pai.data.WorkQueue([path1, path2, path3], shuffle=True) # 让WorkQueue支持TensorBoard。 work_queue.add_summary() # 创建文件读取器。 reader = tf.TextLineReader() # 从文件列表中读取2条记录。 keys, values = reader.read_up_to(work_queue.input_producer(), num_records=2) with tf.train.MonitoredTrainingSession() as sess: sess.run(...)
- MaxCompute表数据源
- TableRecordDataset数据源
import pai #... # odps_path1、odps_path2及odps_path3表示需要读取的MaxCompute表列表。 # shuffle取值为True,表示每个Epoch都随机化打散表路径。 # num_slices为工作项总数量。 # FLAGS.num_workers为训练中的Worker数量。 work_queue = pai.data.WorkQueue([odps_path1, odps_path2, odps_path3],shuffle=True, num_slices=FLAGS.num_workers * 10) # 创建文件名Dataset。 filenames_dataset = work_queue.input_dataset() # 将dataset作为文件名传入TableRecordDataset。 dataset = tf.data.TableRecordDataset(filenames_dataset, record_defaults=...)
关于
tf.data.TableRecordDataset
接口的调用,请参见TableRecordDataset。 - TableRecordReader数据源
import pai# ...# odps_path1、odps_path2及odps_path3表示需要读取的MaxCompute表列表。# shuffle取值为True,表示每个Epoch都随机化打散表路径。# num_slices为工作项总数量。# FLAGS.num_workers为训练中的Worker数量。work_queue = pai.data.WorkQueue( [odps_path1, odps_path2, odps_path3], shuffle=True, num_slices=FLAGS.num_workers * 10)# 创建表读取器。reader = tf.TableRecordReader()# 从表中读取2条记录。keys, values = reader.read_up_to(work_queue.input_producer(), num_records=2)
- TableRecordDataset数据源
内容没看懂? 不太想学习?想快速解决? 有偿解决: 联系专家
阿里云企业补贴进行中: 马上申请
腾讯云限时活动1折起,即将结束: 马上收藏
同尘科技为腾讯云授权服务中心。
购买腾讯云产品享受折上折,更有现金返利:同意关联,立享优惠
转转请注明出处:https://www.yunxiaoer.com/164525.html