This section describes how to parallelize your existing app to support MALT-2. The MALT-2 torch package overloads the optim package: it provides functions for distributed optimization as well as additional helper functions to split/permute data correctly on each replica.
Simply add the dstoptim package to your training file.
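For instance (a minimal sketch; the exact require form, and whether it returns or augments the optim table, depends on how the package is installed):

```lua
-- Load the MALT-2 distributed optimization package in the training script.
-- It overloads the standard optim package with distributed variants and
-- helper functions (dstsgd, randperm, nProc, ...).
require 'dstoptim'
```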
Use the distributed SGD training procedure when calling optim. For example, replace ‘sgd’ or ‘adam’ with ‘dstsgd’ in your code as follows:
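A sketch of a single training step, assuming dstsgd keeps the usual optim feval/params/config calling convention (not spelled out here):

```lua
-- Before: single-machine SGD step.
-- optim.sgd(feval, params, optimparams)

-- After: distributed SGD step; gradients/parameters are exchanged and
-- averaged across replicas by the dstsgd optimizer itself.
optim.dstsgd(feval, params, optimparams)
```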
No other changes are required to communicate or average the parameters; the dstsgd optimizer takes care of communicating and averaging the parallel models. However, one may need to permute/split the data differently on each machine to ensure that the replicas are not processing the same inputs.
Other options that modify SGD and MALT-2 behavior can be controlled by passing options (similar to passing options to the optimizer). As an example, one may add the following lines to their code:
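The sketch below illustrates the idea; the MALT-2 specific field names shown in the comments are hypothetical placeholders, not the actual option names:

```lua
-- optimparams is the usual optim configuration table.
local optimparams = {
   learningRate = 0.01,
   momentum     = 0.9,
   -- MALT-2 specific options would be added here as extra fields, e.g.
   -- (hypothetical names; see the option list below for the real ones):
   -- syncEvery     = 1,     -- how often replicas communicate
   -- averageModels = true,  -- average parameters across replicas
}

optim.dstsgd(feval, params, optimparams)
```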
where optimparams is the table of optimization parameters passed through the optim API. The other available options, listed with their default values first, are below:
To make sure each parallel replica processes a different random split of the data, MALT-2 provides a function to accomplish this in your existing code. For example, instead of using torch.randperm(tensor), use optim.randperm(tensor). This function has the same semantics as torch.randperm(), but uses a different manual seed on each replica so that the replicas do not do the same work.
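A sketch of a typical per-epoch shuffle using this drop-in replacement (variable names are illustrative):

```lua
-- Shuffle training indices once per epoch. torch.randperm would yield the
-- same permutation on every replica (for the same seed); optim.randperm
-- seeds each replica differently so they draw different orders.
local nExamples = trainData:size(1)
local shuffle   = optim.randperm(nExamples)

for i = 1, nExamples do
   local sample = trainData[shuffle[i]]
   -- ... forward/backward pass and optim.dstsgd step ...
end
```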
Additionally, MALT-2 provides optim.nProc(), which returns the number of concurrent replicas. This is useful for splitting the data across replicas. See fb.resnet.lua for an example of how this is accomplished (see train.lua).
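For example, one simple scheme using only the helpers described above (the reference example in train.lua may split the data differently) is to have each replica process a 1/nProc fraction of its own, differently seeded permutation:

```lua
-- Split work across replicas: each one trains on roughly 1/nProc of the
-- data per epoch, drawn from its own permutation (seeded differently per
-- replica by MALT-2).
local nReplicas  = optim.nProc()
local nExamples  = trainData:size(1)
local perReplica = math.floor(nExamples / nReplicas)
local shuffle    = optim.randperm(nExamples)

for i = 1, perReplica do
   local sample = trainData[shuffle[i]]
   -- ... training step ...
end
```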