Python API of Rabit

This page documents the Python API of rabit.

Reliable Allreduce and Broadcast Library.

Author: Tianqi Chen

rabit.allreduce(data, op, prepare_fun=None)

Perform allreduce, return the result.

Parameters:
  • data (numpy array) – Input data.
  • op (int) – Reduction operator; one of MIN, MAX, SUM, or BITOR.
  • prepare_fun (function) – Lazy preprocessing function. If it is not None, prepare_fun(data) is called before the allreduce is performed, to initialize the data. If the result of the allreduce can be recovered directly, prepare_fun will NOT be called.
Returns:

result – The result of the allreduce; it has the same shape as data.

Return type:

array_like

Notes

This function is not thread-safe.
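
The reduction and the lazy prepare_fun behaviour described above can be sketched in a single process. This is only an illustration of the semantics, not rabit's implementation: simulate_allreduce, OPS, and the cached_result parameter are hypothetical names, and plain lists stand in for numpy arrays. In a real job every worker would call rabit.allreduce(data, op, prepare_fun) instead.

```python
import operator
from functools import reduce

# Hypothetical stand-ins for the operators named above. The real constants
# live on the rabit module; their values are not assumed here.
OPS = {
    "MIN": min,
    "MAX": max,
    "SUM": operator.add,
    "BITOR": operator.or_,
}


def simulate_allreduce(per_worker_data, op, prepare_fun=None, cached_result=None):
    """Return what every worker would receive from an allreduce.

    per_worker_data holds one list per worker. cached_result models the case
    where the result can be recovered directly, so prepare_fun is skipped.
    """
    if cached_result is not None:
        return list(cached_result)
    if prepare_fun is not None:
        for buf in per_worker_data:
            prepare_fun(buf)  # lazy, in-place initialization of each buffer
    combine = OPS[op]
    # Combine elementwise across workers; the result keeps the input shape.
    return [reduce(combine, column) for column in zip(*per_worker_data)]


prepare_calls = []


def fill(buf):
    """Fill one worker's buffer and record that the call happened."""
    prepare_calls.append(1)
    buf[:] = [1, 2, 4]


workers = [[0, 0, 0] for _ in range(3)]
total = simulate_allreduce(workers, "SUM", prepare_fun=fill)
replayed = simulate_allreduce(workers, "SUM", prepare_fun=fill, cached_result=total)
```

Here fill runs once per simulated worker for the first call and not at all for the replayed call, which mirrors the rule that prepare_fun is skipped when the result can be recovered.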

rabit.broadcast(data, root)

Broadcast object from one node to all other nodes.

Parameters:
  • data (any type that can be pickled) – Input data. If the current rank is not root, this can be None.
  • root (int) – Rank of the node to broadcast data from.
Returns:

object – the result of broadcast.

Return type:

any type that can be pickled
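
A minimal sketch of the usual calling pattern, in which only the root builds the object and every rank receives it. SingleProcessStub and share_config are hypothetical names introduced so the example runs without a cluster, and the config values are illustrative. In a real job the rabit module itself (after rabit.init()) would presumably be passed as engine.

```python
import pickle


class SingleProcessStub:
    """Hypothetical one-worker stand-in for the rabit module (not part of rabit)."""

    def get_rank(self):
        return 0

    def broadcast(self, data, root):
        # With a single worker, rank 0 is the only possible root. The pickle
        # round trip mimics the requirement that data be picklable.
        assert root == 0
        return pickle.loads(pickle.dumps(data))


def share_config(engine, root=0):
    """Build a config on the root and return the copy every rank receives."""
    config = None
    if engine.get_rank() == root:
        config = {"learning_rate": 0.1, "rounds": 10}  # illustrative values
    # Non-root ranks pass None and get the root's object back.
    return engine.broadcast(config, root)


received = share_config(SingleProcessStub())
```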

rabit.checkpoint(global_model, local_model=None)

Checkpoint the model.

Call this when a stage of execution has finished. Each call to checkpoint increases the version number by one.

Parameters:
  • global_model (any type that can be pickled) – Globally shared model/state. When calling this function, the caller needs to guarantee that global_model is the same on all nodes.
  • local_model (any type that can be pickled) – Local model that is specific to the current node/rank. This can be None when no local state is needed.

Notes

local_model requires explicit replication of the model for fault tolerance, which adds replication cost to the checkpoint function. global_model does not need explicit replication. It is recommended to use global_model if possible.

rabit.finalize()

Finalize the rabit engine.

Call this function after all jobs have finished.

rabit.get_processor_name()

Get the processor name.

Returns: name – The name of the processor (host).
Return type: str

rabit.get_rank()

Get rank of current process.

Returns: rank – Rank of the current process.
Return type: int

rabit.get_world_size()

Get the total number of workers.

Returns: n – Total number of processes.
Return type: int

rabit.init(args=None, lib='standard')

Initialize the rabit module. Call this once before using anything else.

Parameters:
  • args (list of str, optional) – The list of arguments used to initialize rabit; usually you pass in sys.argv. Defaults to sys.argv when it is None.
  • lib ({'standard', 'mock', 'mpi'}) – Type of library to load.
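
The call order implied here (init once at the start, finalize once at the end) can be sketched as follows. RecordingStub and run_job are hypothetical names; the stub only records calls so the example runs in a single process, and a real job would presumably pass the rabit module as engine.

```python
import sys


class RecordingStub:
    """Hypothetical stand-in for the rabit module that records each call."""

    def __init__(self):
        self.calls = []

    def init(self, args=None, lib="standard"):
        # Mirror the documented default: use sys.argv when args is None.
        if args is None:
            args = sys.argv
        self.calls.append("init")

    def get_rank(self):
        self.calls.append("get_rank")
        return 0

    def get_world_size(self):
        self.calls.append("get_world_size")
        return 1

    def finalize(self):
        self.calls.append("finalize")


def run_job(engine, args=None):
    """Initialize, do the work, then finalize."""
    engine.init(args)  # once, before any other call
    summary = "worker %d of %d" % (engine.get_rank(), engine.get_world_size())
    engine.finalize()  # once, after all work is done
    return summary


stub = RecordingStub()
summary = run_job(stub)
```
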
rabit.load_checkpoint(with_local=False)

Load the latest checkpoint.

Parameters: with_local (bool, optional) – Whether the checkpoint contains a local model.
Returns: tuple – If with_local is True, returns (version, global_model, local_model); otherwise returns (version, global_model). If the returned version is 0, no model has been checkpointed, and the returned global_model and local_model will be None.
Return type: tuple

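
A minimal sketch of a resume-from-checkpoint loop built on the return values described above. It assumes one checkpoint call per round, so the loaded version equals the number of completed rounds. CheckpointStub, train, and stop_after are hypothetical names; the stub keeps the checkpoint in memory so the example runs in a single process.

```python
import pickle


class CheckpointStub:
    """Hypothetical in-memory stand-in for rabit's checkpoint calls."""

    def __init__(self):
        self._version = 0
        self._global = None

    def load_checkpoint(self, with_local=False):
        if self._version == 0:
            # Nothing has been checkpointed yet: the models are None.
            return (0, None, None) if with_local else (0, None)
        model = pickle.loads(self._global)
        return (self._version, model, None) if with_local else (self._version, model)

    def checkpoint(self, global_model, local_model=None):
        self._global = pickle.dumps(global_model)
        self._version += 1  # each checkpoint call bumps the version by one

    def version_number(self):
        return self._version


def train(engine, total_rounds, stop_after=None):
    """Resume from the latest checkpoint and run the remaining rounds."""
    version, model = engine.load_checkpoint()
    if version == 0:
        model = {"rounds_done": 0}  # start fresh
    for _ in range(version, total_rounds):
        if stop_after is not None and model["rounds_done"] >= stop_after:
            return model  # simulate an interruption
        model["rounds_done"] += 1
        engine.checkpoint(model)
    return model


engine = CheckpointStub()
train(engine, 5, stop_after=2)   # interrupted after two rounds
version_after_stop = engine.version_number()
final_model = train(engine, 5)   # resumes at round 2 and finishes
```
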
rabit.tracker_print(msg)

Print message to the tracker.

This function can be used to report progress information to the tracker.

Parameters: msg (str) – The message to be printed to the tracker.

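
One possible way to use this for progress reporting is sketched below. Sending from rank 0 only is a design choice made here to avoid one line per worker, not something the API requires. TrackerStub and report_progress are hypothetical names; the stub collects messages in a list instead of contacting a tracker.

```python
class TrackerStub:
    """Hypothetical stand-in for the rabit module that collects messages."""

    def __init__(self, rank=0):
        self._rank = rank
        self.messages = []

    def get_rank(self):
        return self._rank

    def tracker_print(self, msg):
        self.messages.append(msg)


def report_progress(engine, round_index, loss):
    """Send one progress line to the tracker, from rank 0 only."""
    if engine.get_rank() != 0:
        return None
    msg = "round %d: loss=%.4f" % (round_index, loss)
    engine.tracker_print(msg)
    return msg


root = TrackerStub(rank=0)
other = TrackerStub(rank=1)
report_progress(root, 3, 0.25)
report_progress(other, 3, 0.25)
```
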
rabit.version_number()

Return the version number of the currently stored model.

This is the number of calls to checkpoint made so far.

Returns: version – Version number of the currently stored model.
Return type: int