DataParallel(module, device_ids=None, output_device=None, dim=0)
Implements data parallelism at the module level.
This container parallelizes the application of the given module by splitting the input across the specified devices, chunking along the batch dimension (other objects are copied once per device). In the forward pass, the module is replicated on each device, and each replica handles a portion of the input. During the backward pass, gradients from each replica are summed into the original module.
The batch size should be larger than the number of GPUs used.
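The splitting described above can be modeled in plain Python. This is only a sketch: it assumes the scatter step divides the batch the way torch.chunk does (chunks of size ceil(batch / devices), with the last chunk possibly smaller), and chunk_sizes is a hypothetical helper written for illustration, not a PyTorch function.

```python
import math
from typing import List


def chunk_sizes(batch_size: int, num_devices: int) -> List[int]:
    """Per-device chunk sizes, assuming torch.chunk-style splitting."""
    if batch_size <= 0 or num_devices <= 0:
        raise ValueError("batch_size and num_devices must be positive")
    per_chunk = math.ceil(batch_size / num_devices)
    sizes = []
    remaining = batch_size
    while remaining > 0:
        take = min(per_chunk, remaining)  # last chunk may be smaller
        sizes.append(take)
        remaining -= take
    return sizes


even = chunk_sizes(12, 3)      # batch divides evenly across 3 devices
uneven = chunk_sizes(10, 3)    # last device receives a smaller chunk
too_small = chunk_sizes(2, 3)  # fewer chunks than devices
```

Under this assumption, a batch smaller than the number of devices yields fewer chunks than devices, so some devices receive no input. That is one way to read the recommendation that the batch size exceed the number of GPUs.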
It is recommended to use DistributedDataParallel instead of this class for multi-GPU training, even if there is only a single node. See "Use nn.parallel.DistributedDataParallel instead of multiprocessing or nn.DataParallel" and "Distributed Data Parallel".
Arbitrary positional and keyword inputs may be passed into DataParallel, but some types are handled specially. Tensors are scattered along the specified dim (default 0). tuple, list, and dict types are shallow copied. All other types are shared among the threads and can be corrupted if written to in the model’s forward pass.
The module must have its parameters and buffers on device_ids[0] before running this DataParallel module.
In each forward, module is replicated on each device, so any updates to the running module in forward will be lost. For example, if module has a counter attribute that is incremented in each forward, it will always stay at the initial value because the update is done on the replicas, which are destroyed after forward.
DataParallel guarantees that the replica on device[0] will have its parameters and buffers sharing storage with the base parallelized module. So in-place updates to the parameters or buffers on device[0] will be recorded. E.g., spectral_norm() relies on this behavior to update the buffers.
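The difference between a lost attribute update and a recorded in-place update can be illustrated with a plain-Python analogy. The sketch below is not how DataParallel replicates modules: ToyModule is a hypothetical stand-in, and a shallow copy plays the role of a replica that shares storage with the base object.

```python
import copy


class ToyModule:
    """Hypothetical stand-in for a module (not a torch.nn.Module)."""

    def __init__(self):
        self.calls = 0       # plain attribute, rebound on every update
        self.buffer = [0.0]  # mutable storage, updated in place

    def forward(self, x):
        self.calls += 1      # rebinding: only this object sees the new value
        self.buffer[0] += x  # in place: visible to every object sharing the list
        return x * 2


base = ToyModule()
for x in (1.0, 2.0, 3.0):
    replica = copy.copy(base)  # new object that shares the same `buffer` list
    replica.forward(x)
    del replica                # the replica is discarded after the forward pass

calls_on_base = base.calls     # still the initial value
buffer_on_base = base.buffer   # accumulated through the shared storage
```

In this analogy the counter stays at its initial value on the base object, while the in-place write to shared storage survives, which mirrors the two behaviors described above.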
Forward and backward hooks defined on module and its submodules will be invoked len(device_ids) times, each with inputs located on a particular device. In particular, the hooks are only guaranteed to be executed in the correct order with respect to operations on the corresponding devices. For example, it is not guaranteed that hooks set via register_forward_pre_hook() will be executed before all forward() calls, only that each such hook will be executed before the corresponding forward() call on that device.
When module returns a scalar (i.e., a 0-dimensional tensor) in forward(), this wrapper returns a vector of length equal to the number of devices used in data parallelism, containing the result from each device.
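Because one value comes back per device, the caller usually has to reduce the gathered vector. The plain-Python sketch below models that with lists instead of tensors; the loss values and the 3 + 2 split are made up for illustration. It shows that averaging the per-device means matches the mean over the whole batch only when the chunks have equal size.

```python
def mean(values):
    return sum(values) / len(values)


# Hypothetical per-sample losses for a batch of 5 split across 2 devices (3 + 2).
chunks = [[1.0, 2.0, 3.0], [10.0, 20.0]]

# Each replica returns one scalar; the wrapper gathers them into one vector.
gathered = [mean(chunk) for chunk in chunks]

# Naive reduction: average of the per-device means.
mean_of_means = mean(gathered)

# Mean over every sample in the original batch.
global_mean = mean([value for chunk in chunks for value in chunk])

# Weighting each per-device mean by its chunk size recovers the global mean.
total = sum(len(chunk) for chunk in chunks)
weighted = sum(m * len(chunk) for m, chunk in zip(gathered, chunks)) / total
```

With equal chunk sizes the naive average and the weighted one coincide, so a simple mean over the returned vector is often sufficient.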
There is a subtlety in using the pack sequence -> recurrent network -> unpack sequence pattern with DataParallel. See the "My recurrent network doesn’t work with data parallelism" section in the FAQ for details.
module (Module) – the module to be parallelized
>>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
>>> output = net(input_var)  # input_var can be on any device, including CPU
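A slightly fuller sketch of the same usage follows, written so that it also runs on a machine with zero or one GPU. The layer sizes, batch size, and variable names are arbitrary choices for illustration; the wrapper is applied only when more than one GPU is visible, and the model is first moved to cuda:0, which is the first entry of the default device_ids.

```python
import torch
import torch.nn as nn

# Use the first GPU when available; otherwise stay on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Parameters and buffers must live on device_ids[0] before wrapping.
model = nn.Linear(16, 2).to(device)

if torch.cuda.device_count() > 1:
    # With device_ids left as None, all visible GPUs are used.
    model = nn.DataParallel(model)

# Batch of 8 samples; dim 0 is the dimension that gets split across devices.
batch = torch.randn(8, 16, device=device)
output = model(batch)
```

The output keeps the full batch dimension, since the per-device results are gathered back onto a single device before being returned.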