2 Commits

100 changed files with 389 additions and 12164 deletions
Split View
  1. +1
    -1
      .gitignore
  2. +1
    -32
      mindnlp/__init__.py
  3. +0
    -79
      mindnlp/core/cuda/amp/autocast_mode.py
  4. +0
    -19
      mindnlp/core/distributed/_shard/checkpoint/__init__.py
  5. +0
    -13
      mindnlp/core/distributed/_shard/sharded_tensor/_ops/__init__.py
  6. +0
    -21
      mindnlp/core/distributed/_sharded_tensor/__init__.py
  7. +0
    -22
      mindnlp/core/distributed/_sharding_spec/__init__.py
  8. +0
    -9
      mindnlp/core/distributed/_tensor/api.py
  9. +0
    -10
      mindnlp/core/distributed/_tensor/placement_types.py
  10. +0
    -77
      mindnlp/core/distributed/elastic/__init__.py
  11. +0
    -41
      mindnlp/core/distributed/elastic/agent/server/__init__.py
  12. +0
    -957
      mindnlp/core/distributed/elastic/agent/server/api.py
  13. +0
    -65
      mindnlp/core/distributed/elastic/agent/server/health_check_server.py
  14. +0
    -417
      mindnlp/core/distributed/elastic/agent/server/local_elastic_agent.py
  15. +0
    -52
      mindnlp/core/distributed/elastic/control_plane.py
  16. +0
    -170
      mindnlp/core/distributed/elastic/events/__init__.py
  17. +0
    -114
      mindnlp/core/distributed/elastic/events/api.py
  18. +0
    -22
      mindnlp/core/distributed/elastic/events/handlers.py
  19. +0
    -164
      mindnlp/core/distributed/elastic/metrics/__init__.py
  20. +0
    -216
      mindnlp/core/distributed/elastic/metrics/api.py
  21. +0
    -233
      mindnlp/core/distributed/elastic/multiprocessing/__init__.py
  22. +0
    -923
      mindnlp/core/distributed/elastic/multiprocessing/api.py
  23. +0
    -383
      mindnlp/core/distributed/elastic/multiprocessing/errors/__init__.py
  24. +0
    -166
      mindnlp/core/distributed/elastic/multiprocessing/errors/error_handler.py
  25. +0
    -19
      mindnlp/core/distributed/elastic/multiprocessing/errors/handlers.py
  26. +0
    -104
      mindnlp/core/distributed/elastic/multiprocessing/redirects.py
  27. +0
    -16
      mindnlp/core/distributed/elastic/multiprocessing/subprocess_handler/__init__.py
  28. +0
    -34
      mindnlp/core/distributed/elastic/multiprocessing/subprocess_handler/handlers.py
  29. +0
    -78
      mindnlp/core/distributed/elastic/multiprocessing/subprocess_handler/subprocess_handler.py
  30. +0
    -158
      mindnlp/core/distributed/elastic/multiprocessing/tail_log.py
  31. +0
    -167
      mindnlp/core/distributed/elastic/rendezvous/__init__.py
  32. +0
    -384
      mindnlp/core/distributed/elastic/rendezvous/api.py
  33. +0
    -273
      mindnlp/core/distributed/elastic/rendezvous/c10d_rendezvous_backend.py
  34. +0
    -1431
      mindnlp/core/distributed/elastic/rendezvous/dynamic_rendezvous.py
  35. +0
    -1077
      mindnlp/core/distributed/elastic/rendezvous/etcd_rendezvous.py
  36. +0
    -217
      mindnlp/core/distributed/elastic/rendezvous/etcd_rendezvous_backend.py
  37. +0
    -248
      mindnlp/core/distributed/elastic/rendezvous/etcd_server.py
  38. +0
    -212
      mindnlp/core/distributed/elastic/rendezvous/etcd_store.py
  39. +0
    -96
      mindnlp/core/distributed/elastic/rendezvous/registry.py
  40. +0
    -128
      mindnlp/core/distributed/elastic/rendezvous/static_tcp_rendezvous.py
  41. +0
    -284
      mindnlp/core/distributed/elastic/rendezvous/utils.py
  42. +0
    -54
      mindnlp/core/distributed/elastic/timer/__init__.py
  43. +0
    -283
      mindnlp/core/distributed/elastic/timer/api.py
  44. +0
    -25
      mindnlp/core/distributed/elastic/timer/debug_info_logging.py
  45. +0
    -396
      mindnlp/core/distributed/elastic/timer/file_based_local_timer.py
  46. +0
    -128
      mindnlp/core/distributed/elastic/timer/local_timer.py
  47. +0
    -9
      mindnlp/core/distributed/elastic/utils/__init__.py
  48. +0
    -62
      mindnlp/core/distributed/elastic/utils/api.py
  49. +0
    -184
      mindnlp/core/distributed/elastic/utils/distributed.py
  50. +0
    -14
      mindnlp/core/distributed/elastic/utils/log_level.py
  51. +0
    -70
      mindnlp/core/distributed/elastic/utils/logging.py
  52. +0
    -225
      mindnlp/core/distributed/elastic/utils/store.py
  53. +0
    -14
      mindnlp/core/distributed/launcher/__init__.py
  54. +0
    -289
      mindnlp/core/distributed/launcher/api.py
  55. +0
    -7
      mindnlp/core/distributed/pipelining/README.md
  56. +0
    -20
      mindnlp/core/distributed/rpc/_testing/__init__.py
  57. +0
    -922
      mindnlp/core/distributed/run.py
  58. +0
    -5
      mindnlp/core/nn/attention/flex_attention.py
  59. +0
    -90
      mindnlp/core/npu/amp/autocast_mode.py
  60. +0
    -0
      mindnlp/core/testing/_internal/__init__.py
  61. +1
    -1
      mindnlp/dataset/transforms/lookup.py
  62. +3
    -3
      mindnlp/experimental/rwkv6/modeling_rwkv6.py
  63. +1
    -1
      mindnlp/experimental/rwkv6/sampler_rwkv6.py
  64. +16
    -16
      mindnlp/integrations/safetensors.py
  65. +2
    -2
      mindnlp/quant/mindbnb/bitsandbytes/nn/modules.py
  66. +1
    -1
      mindnlp/quant/mindbnb/bitsandbytes/utils.py
  67. +2
    -2
      mindnlp/quant/mindbnb/integrations/replace_modules.py
  68. +1
    -1
      mindnlp/quant/mindbnb/tests/test_mindbnb_linear.py
  69. +3
    -3
      mindnlp/quant/smooth_quant/quant.py
  70. +1
    -1
      mindnlp/quant/smooth_quant/smooth.py
  71. +1
    -1
      mindnlp/transformers/__init__.py
  72. +4
    -4
      mindnlp/transformers/generation/logits_process.py
  73. +111
    -111
      mindnlp/transformers/masking_utils.py
  74. +1
    -2
      mindnlp/transformers/modeling_utils.py
  75. +2
    -2
      mindnlp/transformers/ms_utils.py
  76. +7
    -7
      mindnlp/transformers/trainer.py
  77. +1
    -1
      mindnlp/triton/__init__.py
  78. +1
    -1
      mindnlp/utils/decorators.py
  79. +3
    -3
      mindnlp/utils/safetensors_patch.py
  80. +1
    -1
      mindnlp/utils/testing_utils.py
  81. +0
    -0
      mindtorch/_C/_ConvBackend.py
  82. +13
    -7
      mindtorch/_C/__init__.py
  83. +117
    -0
      mindtorch/_C/_distributed_c10d.py
  84. +8
    -8
      mindtorch/_C/_nn.py
  85. +1
    -1
      mindtorch/_C/size.py
  86. +0
    -0
      mindtorch/__future__.py
  87. +31
    -1
      mindtorch/__init__.py
  88. +0
    -0
      mindtorch/_apis/__init__.py
  89. +9
    -9
      mindtorch/_apis/cpu.py
  90. +4
    -4
      mindtorch/_apis/gpu.py
  91. +35
    -35
      mindtorch/_apis/meta.py
  92. +0
    -0
      mindtorch/_apis/npu.py
  93. +6
    -6
      mindtorch/_bind.py
  94. +0
    -0
      mindtorch/_custom_ops.py
  95. +0
    -0
      mindtorch/_dtype.py
  96. +0
    -0
      mindtorch/_dynamo/__init__.py
  97. +0
    -0
      mindtorch/_dynamo/_trace_wrapped_higher_order_op.py
  98. +0
    -0
      mindtorch/_dynamo/config.py
  99. +0
    -0
      mindtorch/_dynamo/decorators.py
  100. +0
    -0
      mindtorch/_dynamo/eval_frame.py

+ 1
- 1
.gitignore View File

@@ -154,7 +154,7 @@ RuntimeProfiler*
checkpoint-*/
data*/
!mindnlp/data/
!mindnlp/core/utils/data/
!mindtorch/utils/data/
!mindnlp/dataset/
!docs/api/data/
!data2vec/


+ 1
- 32
mindnlp/__init__.py View File

@@ -17,44 +17,13 @@
MindNLP library.
"""
import os
import platform

# huggingface env
if os.environ.get('HF_ENDPOINT', None) is None:
os.environ["HF_ENDPOINT"] = 'https://hf-mirror.com'

# for huawei cloud modelarts
if 'RANK_TABLE_FILE' in os.environ:
del os.environ['RANK_TABLE_FILE']

import mindspore
from mindspore._c_expression import MSContext # pylint: disable=no-name-in-module, import-error
try:
from mindspore._c_expression import disable_multi_thread
except:
disable_multi_thread = None

if os.environ.get('DEVICE_TARGET', None) is not None:
mindspore.set_device(os.environ.get('DEVICE_TARGET'))

# for different ascend devices
if platform.system().lower() == 'linux' and mindspore.get_context('device_target') == 'Ascend':
SOC = MSContext.get_instance().get_ascend_soc_version()
# enable vmm since only vmm can release device memory when del tensor.
if SOC != 'ascend310b':
os.environ["MS_ALLOC_CONF"] = 'enable_vmm:True,vmm_align_size:2MB'

if SOC in ('ascend910', 'ascend310b'):
# context.set_context(ascend_config={"precision_mode": "allow_mix_precision"})
mindspore.device_context.ascend.op_precision.precision_mode('allow_mix_precision')
if SOC == 'ascend310b' and disable_multi_thread is not None:
disable_multi_thread()

# set mindnlp.core to torch
from .utils.torch_proxy import initialize_torch_proxy, setup_metadata_patch
initialize_torch_proxy()
setup_metadata_patch()

import mindtorch
from .utils.safetensors_patch import setup_safetensors_patch
setup_safetensors_patch()



+ 0
- 79
mindnlp/core/cuda/amp/autocast_mode.py View File

@@ -1,79 +0,0 @@
# mypy: allow-untyped-defs
import functools
from typing import Any
from typing_extensions import deprecated
from mindnlp import core
__all__ = ["autocast", "custom_fwd", "custom_bwd"]
class autocast(core.amp.autocast_mode.autocast):
r"""See :class:`core.autocast`.
``core.cuda.amp.autocast(args...)`` is deprecated. Please use ``core.amp.autocast("cuda", args...)`` instead.
"""
@deprecated(
"`core.cuda.amp.autocast(args...)` is deprecated. "
"Please use `core.amp.autocast('cuda', args...)` instead.",
category=FutureWarning,
)
def __init__(
self,
enabled: bool = True,
dtype: core.dtype = core.float16,
cache_enabled: bool = True,
):
super().__init__(
"cuda", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled
)
def __enter__(self):
return super().__enter__()
# TODO: discuss a unified TorchScript-friendly API for autocast
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any): # type: ignore[override]
return super().__exit__(exc_type, exc_val, exc_tb)
def __call__(self, func):
return super().__call__(func)
# Preserved only for BC reasons
@deprecated(
"`core.cuda.amp.autocast_mode._cast(value, dtype)` is deprecated. "
"Please use `core.amp.autocast_mode._cast(value, 'cuda', dtype)` instead.",
category=FutureWarning,
)
def _cast(value, dtype):
return core.amp.autocast_mode._cast(value, "cuda", dtype)
@deprecated(
"`core.cuda.amp.custom_fwd(args...)` is deprecated. "
"Please use `core.amp.custom_fwd(args..., device_type='cuda')` instead.",
category=FutureWarning,
)
def custom_fwd(fwd=None, *, cast_inputs=None):
"""
``core.cuda.amp.custom_fwd(args...)`` is deprecated. Please use
``core.amp.custom_fwd(args..., device_type='cuda')`` instead.
"""
return functools.partial(core.amp.custom_fwd, device_type="cuda")(
fwd=fwd, cast_inputs=cast_inputs
)
@deprecated(
"`core.cuda.amp.custom_bwd(args...)` is deprecated. "
"Please use `core.amp.custom_bwd(args..., device_type='cuda')` instead.",
category=FutureWarning,
)
def custom_bwd(bwd):
"""
``core.cuda.amp.custom_bwd(args...)`` is deprecated. Please use
``core.amp.custom_bwd(args..., device_type='cuda')`` instead.
"""
return functools.partial(core.amp.custom_bwd, device_type="cuda")(bwd)

+ 0
- 19
mindnlp/core/distributed/_shard/checkpoint/__init__.py View File

@@ -1,19 +0,0 @@
# Keep old package for BC purposes, this file should be removed once
# everything moves to the `core.distributed.checkpoint` package.
import sys
import warnings
from mindnlp import core
from core.distributed.checkpoint import * # noqa: F403
with warnings.catch_warnings():
warnings.simplefilter("always")
warnings.warn(
"`core.distributed._shard.checkpoint` will be deprecated, "
"use `core.distributed.checkpoint` instead",
DeprecationWarning,
stacklevel=2,
)
sys.modules["core.distributed._shard.checkpoint"] = core.distributed.checkpoint

+ 0
- 13
mindnlp/core/distributed/_shard/sharded_tensor/_ops/__init__.py View File

@@ -1,13 +0,0 @@
from mindnlp import core.distributed._shard.sharded_tensor._ops.misc_ops
from mindnlp import core.distributed._shard.sharded_tensor._ops.tensor_ops
# Import all ChunkShardingSpec ops
from core.distributed._shard.sharding_spec.chunk_sharding_spec_ops.embedding import (
sharded_embedding,
)
from core.distributed._shard.sharding_spec.chunk_sharding_spec_ops.embedding_bag import (
sharded_embedding_bag,
)
from .binary_cmp import allclose, equal
from .init import constant_, kaiming_uniform_, normal_, uniform_

+ 0
- 21
mindnlp/core/distributed/_sharded_tensor/__init__.py View File

@@ -1,21 +0,0 @@
# Keep old package for BC purposes, this file should be removed once
# everything moves to the `core.distributed._shard` package.
import sys
import warnings
from mindnlp import core
from core.distributed._shard.sharded_tensor import * # noqa: F403
with warnings.catch_warnings():
warnings.simplefilter("always")
warnings.warn(
"`core.distributed._sharded_tensor` will be deprecated, "
"use `core.distributed._shard.sharded_tensor` instead",
DeprecationWarning,
stacklevel=2,
)
sys.modules[
"core.distributed._sharded_tensor"
] = core.distributed._shard.sharded_tensor

+ 0
- 22
mindnlp/core/distributed/_sharding_spec/__init__.py View File

@@ -1,22 +0,0 @@
# Keep old package for BC purposes, this file should be removed once
# everything moves to the `core.distributed._shard` package.
import sys
import warnings
from mindnlp import core
from core.distributed._shard.sharding_spec import * # noqa: F403
with warnings.catch_warnings():
warnings.simplefilter("always")
warnings.warn(
"`core.distributed._sharding_spec` will be deprecated, "
"use `core.distributed._shard.sharding_spec` instead",
DeprecationWarning,
stacklevel=2,
)
from mindnlp import core.distributed._shard.sharding_spec as _sharding_spec
sys.modules["core.distributed._sharding_spec"] = _sharding_spec

+ 0
- 9
mindnlp/core/distributed/_tensor/api.py View File

@@ -1,9 +0,0 @@
"""
NOTE: core.distributed._tensor has been moved to core.distributed.tensor.
The imports here are purely for backward compatibility. We will remove these
imports in a few releases
TODO: throw warnings when this module imported
"""
from core.distributed.tensor._api import * # noqa: F401, F403

+ 0
- 10
mindnlp/core/distributed/_tensor/placement_types.py View File

@@ -1,10 +0,0 @@
"""
NOTE: core.distributed._tensor has been moved to core.distributed.tensor.
The imports here are purely for backward compatibility. We will remove these
imports in a few releases
TODO: throw warnings when this module imported
"""
from core.distributed.tensor._dtensor_spec import * # noqa: F401, F403
from core.distributed.tensor.placement_types import * # noqa: F401, F403

+ 0
- 77
mindnlp/core/distributed/elastic/__init__.py View File

@@ -1,77 +0,0 @@
#!/usr/bin/env/python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Torchelastic agent and user worker failover contract:
**TL;DR;**:
* TE(torchelastic) expects user workers to finish with the 5 minutes drift
* It is better to design DDP app to fail for all workers, rather than a single one.
* TE does not synchronize number of restarts between agents
* TE re-rendezvous does not trigger restart decrease
* When a single agent finishes its job(successfully or not), it will close rendezvous.
If other agents still have workers in progress, they will be terminated.
* Based on above, scale down does not work if at least single agent finishes the job.
* When Scale up is detected by agents, it will not decrease ``max_restarts``
In general TE(torchelastic) can launch arbitrary user code, but there is some
clarifications need to be done around what failover mechanism torchelastic
provides and what failover mechanism it expects from user workers.
Torchelastic currently supports DDP style applications. That means that
TE expects *ALL* workers finish approximately at the same time. In practice,
it is nearly impossible to guarantee that all workers in an arbitrary
DDP application finish at the same time, so TE provides a finalization barrier
that waits for TIMEOUT(5 minutes) for worker finalization.
**Worker Failure**
When worker fails, TE will check the number of restarts
available, if there is more than 0 restarts, TE will start a new rendezvous
round and restart the worker process. A new rendezvous round will cause other
TE agents to terminate their workers.
.. note:: The TE agent does not synchronize restarts between themselves.
When a single agent performs restart, it will trigger a local ``max_restarts``
decrease, other agent will not decrease their ``max_restarts``.
the user to run the distributed application locally on a dev host.
A single worker failure can cause the whole cluster to fail:
If a single worker is constantly failing, it will cause the TE agent
``max_restarts`` to go to zero. This will cause an agent to finish its
work and close rendezvous. If there are any other workers on different
agents, they will be terminated.
**Re-Rendezvous**
Re-rendezvous occurs when TE agents detect a new node
trying to join a cluster. TE will not decrease ``max_restarts``. TE agents
will terminate its workers and start a new rendezvous round.
Note about DynamicRendezvous(etcd-v2, c10d-experimental): If the rendezvous
has already max_nodes, the new node won't be added to the wait list right
away since there is no need to tear down a rendezvous that is already fully
utilized. The new node will wait until its timeout (600 secs by default)
and periodically check the number of participants. If the number becomes
less than max_nodes, it will be added to the wait list; otherwise, it will time out after 600 secs.
*Scale up event*. When scale up event happens, torchelastic rendezvous
will detect that there are new nodes trying to join. Torchelastic agent
will stop all workers and perform re-rendezvous. Note: when scale up event
happens, *``max_restarts``* will *not* decrease.
*Scale down event*. When scale down event happens, rendezvous will not
notify the torchelastic agent about it. If TE agent launched with ``max_restarts=0`` ,
it relies on the underlying scheduler to handle job restart. If the ``max_restarts>0`` ,
TE agent will terminate workers and start a new rdzv round, which is a *Scale up event*.
"""

+ 0
- 41
mindnlp/core/distributed/elastic/agent/server/__init__.py View File

@@ -1,41 +0,0 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
The elastic agent is the control plane of torchelastic.
It is a process that launches and manages underlying worker processes.
The agent is responsible for:
1. Working with distributed torch: the workers are started with all the
necessary information to successfully and trivially call
``core.distributed.init_process_group()``.
2. Fault tolerance: monitors workers and upon detecting worker failures
or unhealthiness, tears down all workers and restarts everyone.
3. Elasticity: Reacts to membership changes and restarts workers with the new
members.
The simplest agents are deployed per node and works with local processes.
A more advanced agent can launch and manage workers remotely. Agents can
be completely decentralized, making decisions based on the workers it manages.
Or can be coordinated, communicating to other agents (that manage workers
in the same job) to make a collective decision.
"""
from .api import ( # noqa: F401
ElasticAgent,
RunResult,
SimpleElasticAgent,
Worker,
WorkerGroup,
WorkerSpec,
WorkerState,
)
from .local_elastic_agent import TORCHELASTIC_ENABLE_FILE_TIMER, TORCHELASTIC_TIMER_FILE

+ 0
- 957
mindnlp/core/distributed/elastic/agent/server/api.py View File

@@ -1,957 +0,0 @@
# mypy: ignore-errors
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import abc
import json
import os
import signal
import socket
import time
import traceback
import warnings
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from mindnlp import core.distributed.elastic.rendezvous as rdzv
from mindnlp import core.distributed.elastic.utils.store as store_util
from core.distributed.elastic.events import Event, EventSource, record
from core.distributed.elastic.metrics import prof, put_metric
from core.distributed.elastic.multiprocessing import ProcessFailure, SignalException
from core.distributed.elastic.rendezvous import RendezvousGracefulExitError
from core.distributed.elastic.utils.logging import get_logger
__all__ = [
"WorkerSpec",
"Worker",
"WorkerState",
"WorkerGroup",
"RunResult",
"ElasticAgent",
"SimpleElasticAgent",
]
_TERMINAL_STATE_SYNC_ID = "torchelastic/agent/terminal_state"
DEFAULT_ROLE = "default"
logger = get_logger(__name__)
@dataclass
class WorkerSpec:
"""Blueprint information about a particular type of worker.
For a given role, there must only exist a single worker spec.
Worker spec is expected to be homogeneous across all nodes (machine),
that is each node runs the same number of workers for a particular spec.
Args:
role: user-defined role for the workers with this spec
local_world_size: number local workers to run
fn: (deprecated use entrypoint instead)
entrypoint: worker function or command
args: arguments to pass to ``entrypoint``
rdzv_handler: handles rdzv for this set of workers
max_restarts: number of max retries for the workers
monitor_interval: monitor status of workers every ``n`` seconds
master_port: fixed port to run the c10d store on rank 0
if not specified then will chose a random free port
master_addr: fixed master_addr to run the c10d store on rank 0
if not specified then will chose hostname on agent rank 0
redirects: redirect std streams to a file,
selectively redirect for a particular
local rank by passing a map
tee: tees the specified std stream(s) to console + file,
selectively tee for a particular local rank by passing a map,
takes precedence over ``redirects`` settings.
"""
role: str
local_world_size: int
rdzv_handler: rdzv.RendezvousHandler
fn: Optional[Callable] = None
# TODO @kiuk - make entrypoint a required field
entrypoint: Union[Callable, str, None] = None
args: Tuple = ()
max_restarts: int = 3
monitor_interval: float = 0.1
master_port: Optional[int] = None
master_addr: Optional[str] = None
local_addr: Optional[str] = None
def __post_init__(self):
assert self.local_world_size > 0
assert self.monitor_interval > 0
if self.fn:
warnings.warn(
"WorkerSpec.fn will be deprecated,"
" please use WorkerSpec.entrypoint instead",
category=DeprecationWarning,
)
self.entrypoint = self.fn
assert self.entrypoint
def get_entrypoint_name(self):
"""Get the entry point name.
If the entrypoint is a function (e.g. ``Callable``) returns its ``__qualname__``
else if the entrypoint is a binary (e.g. ``str``), returns the binary name.
"""
if isinstance(self.entrypoint, str):
return os.path.basename(self.entrypoint)
else:
assert self.entrypoint is not None
return self.entrypoint.__qualname__
class Worker:
"""A worker instance.
Contrast this with ``WorkerSpec`` that represents the specifications of a
worker. A ``Worker`` is created from a ``WorkerSpec``. A ``Worker`` is to
a ``WorkerSpec`` as an object is to a class.
The ``id`` of the worker is interpreted
by the specific implementation of ``ElasticAgent``. For a local
agent, it could be the ``pid (int)`` of the worker, for a remote
agent it could be encoded as ``host:port (string)``.
Args:
id (Any): uniquely identifies a worker (interpreted by the agent)
local_rank (int): local rank of the worker
global_rank (int): global rank of the worker
role_rank (int): rank of the worker across all workers that have the same role
world_size (int): number of workers (globally)
role_world_size (int): number of workers that have the same role
"""
__slots__ = [
"id",
"local_rank",
"global_rank",
"role_rank",
"world_size",
"role_world_size",
]
def __init__(
self,
local_rank: int,
global_rank: int = -1,
role_rank: int = -1,
world_size: int = -1,
role_world_size: int = -1,
):
# unique identifier for this worker
self.id: Any = None
# rank of the worker among workers with the same role being monitored
# by the same ``agent`` instance.
self.local_rank: int = local_rank
# rank of the worker among all the workers across all roles
# across all ``agent`` instances.
# Global rank is not stable between re-rendezvous.
self.global_rank: int = global_rank
# rank of the worker among all the workers with the same role
# across all ``agent`` instances.
# Role rank is not stable between re-rendezvous.
self.role_rank: int = role_rank
# total number of workers (globally). Due to elasticity
# the world size may change between re-rendezvous.
self.world_size: int = world_size
# total number of workers that share the same role. Due to elasticity
# the role world size may change between re-rendezvous.
self.role_world_size: int = role_world_size
def __str__(self):
return (
f"local_rank={self.local_rank},global_rank={self.global_rank}"
f",role_rank={self.role_rank},world_size={self.world_size}"
f",role_world_size={self.role_world_size}"
)
def __repr__(self):
return str(self)
class WorkerState(str, Enum):
"""A state of the ``WorkerGroup``.
Workers in a worker group change state as a unit. If a single worker
in a worker group fails the entire set is considered failed::
UNKNOWN - agent lost track of worker group state, unrecoverable
INIT - worker group object created not yet started
HEALTHY - workers running and healthy
UNHEALTHY - workers running and unhealthy
STOPPED - workers stopped (interrupted) by the agent
SUCCEEDED - workers finished running (exit 0)
FAILED - workers failed to successfully finish (exit !0)
A worker group starts from an initial ``INIT`` state,
then progresses to ``HEALTHY`` or ``UNHEALTHY`` states,
and finally reaches a terminal ``SUCCEEDED`` or ``FAILED`` state.
Worker groups can be interrupted and temporarily put into ``STOPPED`` state
by the agent. Workers in ``STOPPED`` state are scheduled to be restarted
in the near future by the agent. Some examples of workers being put into
``STOPPED`` state are:
1. Worker group failure|unhealthy observed
2. Membership change detected
When actions (start, stop, rdzv, retry, etc) on worker group fails
and results in the action being partially applied to the worker group
the state will be ``UNKNOWN``. Typically this happens on uncaught/unhandled
exceptions during state change events on the agent. The agent is not
expected to recover worker groups in ``UNKNOWN`` state and is better off
self terminating and allowing the job manager to retry the node.
"""
UNKNOWN = "UNKNOWN"
INIT = "INIT"
HEALTHY = "HEALTHY"
UNHEALTHY = "UNHEALTHY"
STOPPED = "STOPPED"
SUCCEEDED = "SUCCEEDED"
FAILED = "FAILED"
@staticmethod
def is_running(state: "WorkerState") -> bool:
"""Return the state of the Worker.
Returns:
True if the worker state represents workers still running
(e.g. that the process exists but not necessarily healthy).
"""
return state in {WorkerState.HEALTHY, WorkerState.UNHEALTHY}
class WorkerGroup:
"""A set of ``Worker`` instances.
The class defines a set of ``Worker`` instances for the given ``WorkerSpec`` managed by ``ElasticAgent``. Whether the worker
group contains cross instance workers or not depends on the implementation of the agent.
"""
__slots__ = [
"spec",
"workers",
"store",
"group_rank",
"group_world_size",
"state",
"master_addr",
"master_port",
]
def __init__(self, spec: WorkerSpec):
self.spec = spec
self.workers = [Worker(local_rank=i) for i in range(self.spec.local_world_size)]
# assigned after rdzv
self.store = None
self.group_rank = None
self.group_world_size = None
self.master_addr = None
self.master_port = None
self.state = WorkerState.INIT
class _RoleInstanceInfo:
"""The class is used by the agent to exchange the information with other agents.
The information is used to determine the rank of the workers that agent
manages in heterogeneous environments, where different agents can have
different number of workers.
"""
__slots__ = ["role", "rank", "local_world_size"]
def __init__(self, role: str, rank: int, local_world_size: int):
r"""Initialize the agent class instance.
Args:
role (str): user-defined role for the workers with this spec
rank (int): the rank of the agent
local_world_size (int): number of local workers to run
"""
self.role = role
self.rank = rank
self.local_world_size = local_world_size
def serialize(self) -> bytes:
dict_data = {
"role": self.role,
"rank": self.rank,
"local_world_size": self.local_world_size,
}
return json.dumps(dict_data).encode(encoding="UTF-8")
@staticmethod
def deserialize(data: bytes):
dict_data = json.loads(data.decode(encoding="UTF-8"))
return _RoleInstanceInfo(
dict_data["role"], dict_data["rank"], dict_data["local_world_size"]
)
@staticmethod
def compare(obj1, obj2) -> int:
if obj1.role == obj2.role:
return obj1.rank - obj2.rank
elif obj1.role > obj2.role:
return 1
else:
return -1
@staticmethod
def find_role_boundaries(roles_infos: List, role: str) -> Tuple[int, int]:
start_idx, end_idx = -1, -1
for idx, role_info in enumerate(roles_infos):
if role_info.role == role:
if start_idx == -1:
start_idx = idx
end_idx = idx
return (start_idx, end_idx)
@dataclass
class RunResult:
"""Return results of the worker executions.
Run results follow an "all-or-nothing" policy where the run is successful if and
only if ALL local workers managed by this agent complete successfully.
If the result is successful (e.g. ``is_failed() = False``) then the ``return_values``
field contains the outputs (return values) of the workers managed by THIS agent mapped
by their GLOBAL ranks. That is ``result.return_values[0]`` is the return value of
global rank 0.
.. note:: ``return_values`` are only meaningful for when the worker entrypoint
is a function. Workers specified as a binary entrypoint do not canonically
have a return value and the ``return_values`` field is meaningless and
may be empty.
If ``is_failed()`` returns ``True`` then the ``failures`` field contains the
failure information, again, mapped by the GLOBAL rank of the worker that failed.
The keys in ``return_values`` and ``failures`` are mutually exclusive, that is,
a worker's final state can only be one of: succeeded, failed. Workers intentionally
terminated by the agent according to the agent's restart policy, are not represented
in either ``return_values`` nor ``failures``.
"""
state: WorkerState
return_values: Dict[int, Any] = field(default_factory=dict)
failures: Dict[int, ProcessFailure] = field(default_factory=dict)
def is_failed(self) -> bool:
return self.state == WorkerState.FAILED
def _get_fq_hostname() -> str:
return socket.getfqdn(socket.gethostname())
class ElasticAgent(abc.ABC):
    """Abstract base for an agent process that supervises one or more workers.

    Workers are assumed to be ordinary distributed PyTorch scripts; when the
    agent creates them it hands over the information they need to initialize
    a torch process group.

    The agent-to-worker ratio and the deployment topology are left to the
    concrete implementation and the user's placement preferences. For a
    distributed GPU training job with 8 trainers (one per GPU) any of the
    following layouts is valid:

    1. 8 x single-GPU instances, one agent per instance, 1 worker per agent.
    2. 4 x double-GPU instances, one agent per instance, 2 workers per agent.
    3. 2 x quad-GPU instances, one agent per instance, 4 workers per agent.
    4. 1 x 8-GPU instance, one agent per instance, 8 workers per agent.

    Usage
    ::
        group_result = agent.run()
        if group_result.is_failed():
            # workers failed
            failure = group_result.failures[0]
            logger.exception("worker 0 failed with exit code : %s", failure.exit_code)
        else:
            return group_result.return_values[0] # return rank 0's results
    """

    @abc.abstractmethod
    def run(self, role: str = DEFAULT_ROLE) -> RunResult:
        """Run the agent, retrying the worker group on failure up to ``max_restarts`` times.

        Returns:
            The execution result: each worker's return value or failure
            details, keyed by the worker's global rank.

        Raises:
            Exception - any other failures NOT related to worker process
        """
        raise NotImplementedError

    @abc.abstractmethod
    def get_worker_group(self, role: str = DEFAULT_ROLE) -> WorkerGroup:
        """Return the ``WorkerGroup`` for the given ``role``.

        The worker group is a mutable object, so in a multi-threaded or
        multi-process environment its state may change underneath the caller.
        Implementors are encouraged (but not required) to return a defensive
        read-only copy.
        """
        raise NotImplementedError
class SimpleElasticAgent(ElasticAgent):
    """An ``ElasticAgent`` that manages one particular type of worker role.

    An ``ElasticAgent`` that manages workers (``WorkerGroup``) for a single ``WorkerSpec``
    such as one particular type of worker role.
    """
    def __init__(self, spec: WorkerSpec, exit_barrier_timeout: float = 300):
        # The single worker group this agent manages (one role only).
        self._worker_group = WorkerGroup(spec)
        # Decremented on each restart; no further restarts once it reaches 0.
        self._remaining_restarts = self._worker_group.spec.max_restarts
        # Rendezvous store handle; populated by _rendezvous().
        self._store = None
        # Seconds to wait for peer agents at the exit barrier (_exit_barrier()).
        self._exit_barrier_timeout = exit_barrier_timeout
        # Wall-clock seconds of the last run(); reported in events/metrics.
        self._total_execution_time = 0
    def get_worker_group(self, role: str = DEFAULT_ROLE) -> WorkerGroup:
        # Single-role agent: ``role`` is accepted for interface compatibility
        # but not consulted.
        return self._worker_group
    @abc.abstractmethod
    def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:
        r"""Start ``worker_group.spec.local_world_size`` number of workers.

        This is according to worker spec for the worker group.
        Returns a map of ``local_rank`` to worker ``id``.
        """
        raise NotImplementedError
    @abc.abstractmethod
    def _stop_workers(
        self, worker_group: WorkerGroup, is_restart: bool = False
    ) -> None:
        r"""Stop all workers in the given worker group.

        Implementors must deal with workers in all states defined by
        ``WorkerState``. That is, it must gracefully handle stopping
        non-existent workers, unhealthy (stuck) workers, etc.
        """
        raise NotImplementedError
    @abc.abstractmethod
    def _monitor_workers(self, worker_group: WorkerGroup) -> RunResult:
        r"""Check on the workers for the ``worker_group``.

        This function also returns the new state of the worker group.
        """
        raise NotImplementedError
    @abc.abstractmethod
    def _shutdown(
        self, death_sig: signal.Signals = signal.SIGTERM, is_restart: bool = False
    ) -> None:
        """Clean up any resources that were allocated during the agent's work.

        Args:
            death_sig: Signal to send to the child process, SIGTERM is default
        """
        raise NotImplementedError
    @prof
    def _rendezvous(self, worker_group: WorkerGroup) -> None:
        r"""Run rendezvous for the workers specified by the worker spec.

        Assigns workers a new global rank and world size.
        Updates the rendezvous store for the worker group.
        """
        spec = worker_group.spec
        # Timed section: blocks until the rendezvous round completes.
        with self.record_duration("RENDEZVOUS"):
            rdzv_info = spec.rdzv_handler.next_rendezvous()
        store = rdzv_info.store
        group_rank = rdzv_info.rank
        group_world_size = rdzv_info.world_size
        # master_addr/master_port could be explicitly overridden by the spec;
        # otherwise fall back to the addresses chosen by the rendezvous.
        # TODO: BC - specific to static rdzv and can be simplified further
        master_addr = spec.master_addr or rdzv_info.bootstrap_store_info.master_addr
        master_port = spec.master_port or rdzv_info.bootstrap_store_info.master_port
        self._store = store
        with self.record_duration("ASSIGN_WORKER_RANKS"):
            workers = self._assign_worker_ranks(
                store, group_rank, group_world_size, spec
            )
        worker_group.workers = workers
        worker_group.store = store
        worker_group.group_rank = group_rank
        worker_group.group_world_size = group_world_size
        worker_group.master_addr = master_addr
        worker_group.master_port = master_port
        restart_count = spec.max_restarts - self._remaining_restarts
        logger.info(
            "[%(role)s] Rendezvous complete for workers. Result:\n"
            "  restart_count=%(restart_count)s\n"
            "  master_addr=%(master_addr)s\n"
            "  master_port=%(master_port)s\n"
            "  group_rank=%(group_rank)s\n"
            "  group_world_size=%(group_world_size)s\n"
            "  local_ranks=%(local_ranks)s\n"
            "  role_ranks=%(role_ranks)s\n"
            "  global_ranks=%(global_ranks)s\n"
            "  role_world_sizes=%(role_world_sizes)s\n"
            "  global_world_sizes=%(global_world_sizes)s\n",
            {
                "role": spec.role,
                "restart_count": restart_count,
                "master_addr": master_addr,
                "master_port": master_port,
                "group_rank": group_rank,
                "group_world_size": group_world_size,
                "local_ranks": [worker.local_rank for worker in workers],
                "role_ranks": [worker.role_rank for worker in workers],
                "global_ranks": [worker.global_rank for worker in workers],
                "role_world_sizes": [worker.role_world_size for worker in workers],
                "global_world_sizes": [worker.world_size for worker in workers],
            },
        )
    # pyre-fixme[56]: Pyre was not able to infer the type of the decorator
    #  `core.distributed.elastic.metrics.prof`.
    @prof
    def _assign_worker_ranks(
        self, store, group_rank: int, group_world_size: int, spec: WorkerSpec
    ) -> List[Worker]:
        """Determine proper ranks for worker processes.

        Fast Path: when all workers have the same role and world size. We calculate
        the global rank to be group_rank * group_world_size + local_rank. And the
        `role_world_size` is the same as `global_world_size`. No TCP store is used in
        this case. This is only enabled when users set the environment variable
        `TORCH_ELASTIC_WORKER_IDENTICAL` to 1.
        Time complexity: each worker O(1), overall O(1)

        Slow Path: when workers have different roles and world sizes. We use
        the following algorithm:
        1. Each agent writes its configuration(group_rank, group_world_size
           , num_workers) to the common store.
        2. The rank 0 agent reads all the role_info from the store and
           determines each agents worker ranks.
        3. Determine the global rank: the global rank of the workers is computed
           by cumulative sum of the local_world_size for all workers in front of it.
           For efficiency reasons each worker is assigned a base global rank
           such that it's workers are in the range [base_global_rank,
           base_global_rank + local_world_size).
        4. Determine the role rank: The role rank is determined using the algorithms
           in the point 3 with the exception that the ranks are calculated with
           respect to the role name.
        5. The rank 0 agent writes the assigned ranks to the store.
        6. Each agent reads the assigned ranks from the store.
        Time complexity: each worker O(1), rank0 O(n), overall O(n)
        """
        if os.environ.get("TORCH_ELASTIC_WORKER_IDENTICAL", "0") == "1":
            # Fast path: homogeneous job, ranks are pure arithmetic.
            global_world_size = group_world_size * spec.local_world_size
            base_global_rank = group_rank * spec.local_world_size
            base_role_rank = base_global_rank
            role_world_size = global_world_size
        else:
            ROLE_INFO_PREFIX = "torchelastic/role_info/"
            ASSIGNED_RANKS_PREFIX = "torchelastic/assigned_ranks/"
            agent_role_info = _RoleInstanceInfo(
                spec.role, group_rank, spec.local_world_size
            )
            store.set(f"{ROLE_INFO_PREFIX}{group_rank}", agent_role_info.serialize())
            # tcp store is collocated with rank 0 so we can use it to do extra compute to reduce overall # of operations.
            if group_rank == 0:
                # NOTE: the literal key below equals ROLE_INFO_PREFIX.
                role_infos_bytes = store.multi_get(
                    [f"torchelastic/role_info/{i}" for i in range(group_world_size)]
                )
                role_infos = [
                    _RoleInstanceInfo.deserialize(info_bytes)
                    for info_bytes in role_infos_bytes
                ]
                # Total worker count per role and overall (steps 2-3 above).
                role_sizes = defaultdict(lambda: 0)
                global_size = 0
                for role_info in role_infos:
                    role_sizes[role_info.role] += role_info.local_world_size
                    global_size += role_info.local_world_size
                base_global_rank = 0
                role_ranks = defaultdict(lambda: 0)
                keys = []
                values = []
                for i, role_info in enumerate(role_infos):
                    keys.append(f"{ASSIGNED_RANKS_PREFIX}{i}")
                    values.append(
                        json.dumps(
                            [
                                base_global_rank,
                                global_size,
                                role_ranks[role_info.role],
                                role_sizes[role_info.role],
                            ]
                        )
                    )
                    base_global_rank += role_info.local_world_size
                    role_ranks[role_info.role] += role_info.local_world_size
                store.multi_set(keys, values)
            # get will block until the data is available in the store.
            (
                base_global_rank,
                global_world_size,
                base_role_rank,
                role_world_size,
            ) = json.loads(store.get(f"{ASSIGNED_RANKS_PREFIX}{group_rank}"))
        workers = []
        for local_rank in range(spec.local_world_size):
            worker = Worker(
                local_rank=local_rank,
                global_rank=base_global_rank + local_rank,
                role_rank=base_role_rank + local_rank,
                world_size=global_world_size,
                role_world_size=role_world_size,
            )
            workers.append(worker)
        return workers
    # pyre-fixme[56]: Pyre was not able to infer the type of the decorator
    #  `core.distributed.elastic.metrics.prof`.
    @prof
    def _initialize_workers(self, worker_group: WorkerGroup) -> None:
        r"""Start a fresh set of workers for the worker_group.

        Essentially, a rendezvous followed by a ``start_workers``.
        The caller should first call ``_stop_workers()`` to stop running workers
        prior to calling this method.

        Optimistically sets the state of the worker group that
        just started as ``HEALTHY`` and delegates the actual monitoring
        of state to ``_monitor_workers()`` method
        """
        role = worker_group.spec.role
        logger.info("[%s] Rendezvous'ing worker group", role)
        # TODO after stopping workers, wait at least monitor_interval*2 for
        # workers on different nodes to fail on a collective op before waiting
        # on the rdzv barrier, this way we ensure that nodes enter rdzv
        # at around the same time and reduce false positive rdzv timeout errors
        self._rendezvous(worker_group)
        logger.info("[%s] Starting worker group", role)
        worker_ids = self._start_workers(worker_group)
        # Attach the implementation-specific worker ids (e.g. pids) to the
        # Worker objects produced by rendezvous.
        for local_rank, w_id in worker_ids.items():
            worker = worker_group.workers[local_rank]
            worker.id = w_id
        worker_group.state = WorkerState.HEALTHY
    # pyre-fixme[56]: Pyre was not able to infer the type of the decorator
    #  `core.distributed.elastic.metrics.prof`.
    @prof
    def _restart_workers(self, worker_group: WorkerGroup) -> None:
        """Restart (stops, rendezvous, starts) all local workers in the group."""
        role = worker_group.spec.role
        logger.info("[%s] Stopping worker group", role)
        self._stop_workers(worker_group, is_restart=True)
        worker_group.state = WorkerState.STOPPED
        self._initialize_workers(worker_group)
    # pyre-fixme[56]: Pyre was not able to infer the type of the decorator
    #  `core.distributed.elastic.metrics.prof`.
    @prof
    def run(self, role: str = DEFAULT_ROLE) -> RunResult:
        """Run the agent's main loop, recording metrics/events on completion.

        NOTE(review): on ``RendezvousGracefulExitError`` the exception is
        swallowed and control falls through to ``finally``, so this method
        implicitly returns ``None`` instead of a ``RunResult`` in that case —
        confirm callers tolerate this.
        """
        start_time = time.monotonic()
        shutdown_called: bool = False
        try:
            result = self._invoke_run(role)
            self._total_execution_time = int(time.monotonic() - start_time)
            self._record_metrics(result)
            self._record_worker_events(result)
            return result
        except RendezvousGracefulExitError as e:
            logger.info("Rendezvous gracefully exited: %s", e)
        except SignalException as e:
            logger.warning("Received %s death signal, shutting down workers", e.sigval)
            # Forward the death signal to the workers before re-raising.
            self._shutdown(e.sigval)
            shutdown_called = True
            raise
        finally:
            if not shutdown_called:
                self._shutdown()
            # record the execution time in case there were any exceptions during run.
            self._total_execution_time = int(time.monotonic() - start_time)
    def get_event_failed(self) -> Event:
        """Build a FAILED agent event carrying the current traceback."""
        return self._construct_event(
            state="FAILED",
            source=EventSource.AGENT,
            raw_error=traceback.format_exc(),
        )
    def get_event_succeeded(self) -> Event:
        """Build a SUCCEEDED agent event."""
        return self._construct_event(
            state="SUCCEEDED",
            source=EventSource.AGENT,
        )
    def _record_worker_events(self, result: RunResult) -> None:
        # Emit one event per worker with its terminal state and failure (if any).
        for worker in self._worker_group.workers:
            failure = result.failures.get(worker.global_rank)
            state: str = self._get_worker_state(worker, result)
            raw_error = json.dumps(failure.error_file_data) if failure else None
            record(self._construct_event(state, EventSource.WORKER, worker, raw_error))
    def _get_worker_state(self, worker: Worker, result: RunResult) -> str:
        """Map a worker to its terminal state string for event reporting."""
        failure = result.failures.get(worker.global_rank)
        if result.state in {WorkerState.UNHEALTHY, WorkerState.FAILED} and not failure:
            # The worker got terminated by the torchelastic agent via SIGTERM signal
            return "TERMINATED"
        elif failure or worker.global_rank in result.return_values:
            return result.state.value
        else:
            raise ValueError(f"Unknown worker: {worker.global_rank}")
    @contextmanager
    def record_duration(self, state: str):
        """Context manager that records an agent event with the elapsed duration (ms)."""
        start_time = time.perf_counter()
        try:
            yield
        finally:
            end_time = time.perf_counter()
            duration_ms = (end_time - start_time) * 1000
            record(
                self._construct_event(
                    state=state, source=EventSource.AGENT, duration_ms=duration_ms
                )
            )
    def _construct_event(
        self,
        state: str,
        source: EventSource,
        worker: Optional[Worker] = None,
        raw_error: Optional[str] = None,
        duration_ms: Optional[float] = None,
    ) -> Event:
        """Assemble an ``Event`` describing the agent or a single worker."""
        wg = self._worker_group
        spec = wg.spec
        md = {
            "group_world_size": wg.group_world_size,
            "entry_point": spec.get_entrypoint_name(),
        }
        if worker:
            # NOTE(review): the trailing commas make these 1-tuples, not
            # scalars; preserved as-is — confirm downstream consumers expect it.
            md["local_rank"] = (worker.local_rank,)
            md["role_rank"] = (worker.role_rank,)
            md["role_world_size"] = (worker.role_world_size,)
            global_rank = worker.global_rank
            worker_id = str(worker.id)
        else:
            global_rank = None
            worker_id = None
        md_str = json.dumps(md)
        metadata = {
            "run_id": spec.rdzv_handler.get_run_id(),
            "global_rank": global_rank,
            "group_rank": wg.group_rank,
            "worker_id": worker_id,
            "role": spec.role,
            "hostname": _get_fq_hostname(),
            "state": state,
            "total_run_time": self._total_execution_time,
            "rdzv_backend": spec.rdzv_handler.get_backend(),
            "raw_error": raw_error,
            "metadata": md_str,
            "agent_restarts": spec.max_restarts - self._remaining_restarts,
            "duration_ms": duration_ms,
        }
        return Event(
            f"torchelastic.worker.status.{state}", source=source, metadata=metadata
        )
    def _record_metrics(self, group_results: RunResult):
        """Publish success/failure counters, split by whether retries happened."""
        is_failed = group_results.is_failed()
        self._record_flakiness_metric(is_failed)
        spec = self._worker_group.spec
        restarts_happened = self._remaining_restarts != spec.max_restarts
        put_metric(f"workers.{spec.role}.run_total", 1)
        self._record_metric_with_condition(
            "run_success_with_retries", not is_failed and restarts_happened
        )
        self._record_metric_with_condition(
            "run_success_no_retries", not is_failed and not restarts_happened
        )
        self._record_metric_with_condition(
            "run_failed_with_retries", is_failed and restarts_happened
        )
        self._record_metric_with_condition(
            "run_failed_no_retries", is_failed and not restarts_happened
        )
    def _record_metric_with_condition(self, metric_name, condition):
        # Always emit the metric so the time series exists; value encodes
        # whether the condition held.
        spec = self._worker_group.spec
        if condition:
            put_metric(f"workers.{spec.role}.{metric_name}", 1)
        else:
            put_metric(f"workers.{spec.role}.{metric_name}", 0)
    def _record_flakiness_metric(self, is_failed: bool = False):
        # Flakiness: 100 on outright failure, otherwise the fraction of the
        # restart budget that was consumed.
        if is_failed:
            flakiness = 100.0
        else:
            spec = self._worker_group.spec
            flakiness = 100.0 - 100.0 * (self._remaining_restarts + 1) / (
                spec.max_restarts + 1
            )
        spec = self._worker_group.spec
        put_metric(f"workers.{spec.role}.flakiness", int(flakiness))
    def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult:
        """Main monitor loop: start workers, then poll until success/failure.

        Restarts the group on failure (while budget remains) and on
        membership changes (which do not consume the restart budget).
        """
        # NOTE: currently only works for a single role
        spec = self._worker_group.spec
        role = spec.role
        logger.info(
            "[%s] starting workers for entrypoint: %s", role, spec.get_entrypoint_name()
        )
        self._initialize_workers(self._worker_group)
        monitor_interval = spec.monitor_interval
        rdzv_handler = spec.rdzv_handler
        while True:
            assert self._worker_group.state != WorkerState.INIT
            time.sleep(monitor_interval)
            run_result = self._monitor_workers(self._worker_group)
            state = run_result.state
            self._worker_group.state = state
            put_metric(f"workers.{role}.remaining_restarts", self._remaining_restarts)
            put_metric(f"workers.{role}.{state.name.lower()}", 1)
            if state == WorkerState.SUCCEEDED:
                logger.info(
                    "[%s] worker group successfully finished."
                    " Waiting %s seconds for other agents to finish.",
                    role,
                    self._exit_barrier_timeout,
                )
                self._exit_barrier()
                return run_result
            elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}:
                if self._remaining_restarts > 0:
                    logger.info(
                        "[%s] Worker group %s. "
                        "%s/%s attempts left;"
                        " will restart worker group",
                        role,
                        state.name,
                        self._remaining_restarts,
                        spec.max_restarts,
                    )
                    self._remaining_restarts -= 1
                    self._restart_workers(self._worker_group)
                else:
                    # Restart budget exhausted: stop everything and report failure.
                    self._stop_workers(self._worker_group)
                    self._worker_group.state = WorkerState.FAILED
                    return run_result
            elif state == WorkerState.HEALTHY:
                # membership changes do not count as retries
                num_nodes_waiting = rdzv_handler.num_nodes_waiting()
                group_rank = self._worker_group.group_rank
                if num_nodes_waiting > 0:
                    logger.info(
                        "[%s] Detected %s "
                        "new nodes from group_rank=%s; "
                        "will restart worker group",
                        role,
                        num_nodes_waiting,
                        group_rank,
                    )
                    self._restart_workers(self._worker_group)
            else:
                raise Exception(  # noqa: TRY002
                    f"[{role}] Worker group in {state.name} state"
                )
    def _exit_barrier(self):
        """
        Define a barrier that keeps the agent process alive until all workers finish.

        Wait for ``exit_barrier_timeout`` seconds for all agents to finish
        executing their local workers (either successfully or not). This
        acts as a safety guard against user scripts that terminate at different
        times.
        """
        logger.info(
            "Local worker group finished (%s). "
            "Waiting %s seconds for other agents to finish",
            self._worker_group.state,
            self._exit_barrier_timeout,
        )
        start = time.time()
        try:
            store_util.barrier(
                store=self._store,
                world_size=self._worker_group.group_world_size,
                key_prefix=_TERMINAL_STATE_SYNC_ID,
                barrier_timeout=self._exit_barrier_timeout,
            )
            logger.info(
                "Done waiting for other agents. Elapsed: %s seconds",
                time.time() - start,
            )
        except SignalException as e:
            logger.warning("Got termination signal: %s", e.sigval)
            raise
        except Exception:
            # Best-effort barrier: log and continue so a straggler cannot
            # wedge this agent forever.
            logger.exception(
                "Error waiting on exit barrier. Elapsed: %s seconds",
                time.time() - start,
            )

+ 0
- 65
mindnlp/core/distributed/elastic/agent/server/health_check_server.py View File

@@ -1,65 +0,0 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Callable
from core.distributed.elastic.utils.logging import get_logger
log = get_logger(__name__)
__all__ = ["HealthCheckServer", "create_healthcheck_server"]
class HealthCheckServer:
    """
    No-op base implementation of a health-check monitoring server.

    Subclasses may extend this by actually serving tcp/http on the given port.

    Args:
        alive_callback: Callable[[], int], callback to last progress time of agent
        port: int, port number to start tcp/http server
        timeout: int, timeout seconds to decide agent is alive/dead
    """

    _alive_callback: Callable[[], int]
    _port: int
    _timeout: int

    def __init__(
        self, alive_callback: Callable[[], int], port: int, timeout: int
    ) -> None:
        # Stash the configuration; this base class never opens a socket.
        self._timeout = timeout
        self._port = port
        self._alive_callback = alive_callback

    def start(self) -> None:
        """
        Unsupported functionality for Pytorch, doesn't start any health check server
        """
        log.warning("No health check server started")

    def stop(self) -> None:
        """
        Function to stop health check server
        """
        log.info("Stopping noop health check server.")
def create_healthcheck_server(
    alive_callback: Callable[[], int],
    port: int,
    timeout: int,
) -> HealthCheckServer:
    """
    creates health check server object
    """
    server = HealthCheckServer(alive_callback, port, timeout)
    return server

+ 0
- 417
mindnlp/core/distributed/elastic/agent/server/local_elastic_agent.py View File

@@ -1,417 +0,0 @@
#!/usr/bin/env python3
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import json
import os
import signal
import socket
import time
import uuid
from string import Template
from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING

# FIX: the previous line ``from mindnlp import core.distributed.elastic.timer as timer``
# is invalid Python — ``from X import a.b as c`` is a SyntaxError. Import the
# timer package the same way the sibling modules below are imported.
from core.distributed.elastic import timer
from core.distributed.elastic import events
from core.distributed.elastic.agent.server.api import (
    RunResult,
    SimpleElasticAgent,
    WorkerGroup,
    WorkerSpec,
    WorkerState,
)
from core.distributed.elastic.agent.server.health_check_server import (
    create_healthcheck_server,
    HealthCheckServer,
)
from core.distributed.elastic.metrics.api import prof
from core.distributed.elastic.multiprocessing import (
    LogsSpecs,
    PContext,
    start_processes,
)
from core.distributed.elastic.utils import macros
from core.distributed.elastic.utils.logging import get_logger

if TYPE_CHECKING:
    from core.distributed.elastic.events.api import EventMetadataValue

logger = get_logger(__name__)

__all__ = [
    "LocalElasticAgent",
    "TORCHELASTIC_ENABLE_FILE_TIMER",
    "TORCHELASTIC_TIMER_FILE",
    "TORCHELASTIC_HEALTH_CHECK_PORT",
]

# Env var names controlling the optional file-based watchdog timer and the
# optional health-check server (see LocalElasticAgent below).
TORCHELASTIC_ENABLE_FILE_TIMER = "TORCHELASTIC_ENABLE_FILE_TIMER"
TORCHELASTIC_HEALTH_CHECK_PORT = "TORCHELASTIC_HEALTH_CHECK_PORT"
TORCHELASTIC_TIMER_FILE = "TORCHELASTIC_TIMER_FILE"
class LocalElasticAgent(SimpleElasticAgent):
    """An implementation of :py:class:`torchelastic.agent.server.ElasticAgent` that handles host-local workers.

    This agent is deployed per host and is configured to spawn ``n`` workers.
    When using GPUs, ``n`` maps to the number of GPUs available on the host.

    The local agent does not communicate to other local agents deployed on
    other hosts, even if the workers may communicate inter-host. The worker id
    is interpreted to be a local process. The agent starts and stops all worker
    processes as a single unit.

    The worker function and argument passed to the worker function must be
    python multiprocessing compatible. To pass multiprocessing data structures
    to the workers you may create the data structure in the same multiprocessing
    context as the specified ``start_method`` and pass it as a function argument.

    The ``exit_barrier_timeout`` specifies the amount of time (in seconds) to wait
    for other agents to finish. This acts as a safety net to handle cases where
    workers finish at different times, to prevent agents from viewing workers
    that finished early as a scale-down event. It is strongly advised that the
    user code deal with ensuring that workers are terminated in a synchronous
    manner rather than relying on the exit_barrier_timeout.

    A named pipe based watchdog can be enabled in ```LocalElasticAgent``` if an
    environment variable ``TORCHELASTIC_ENABLE_FILE_TIMER`` with value 1 has
    been defined in the ```LocalElasticAgent``` process.
    Optionally, another environment variable ```TORCHELASTIC_TIMER_FILE```
    can be set with a unique file name for the named pipe. If the environment
    variable ```TORCHELASTIC_TIMER_FILE``` is not set, ```LocalElasticAgent```
    will internally create a unique file name and set it to the environment
    variable ```TORCHELASTIC_TIMER_FILE```, and this environment variable will
    be propagated to the worker processes to allow them to connect to the same
    named pipe that ```LocalElasticAgent``` uses.

    Logs are written to the specified log directory. Each log line will be by default
    prefixed by ``[${role_name}${local_rank}]:`` (e.g. ``[trainer0]: foobar``).
    Log prefixes can be customized by passing a `template string
    <https://docs.python.org/3/library/string.html#template-strings>`_ as the
    ``log_line_prefix_template`` argument.
    The following macros (identifiers) are substituted at runtime:
    ``${role_name}, ${local_rank}, ${rank}``. For example, to prefix each log line with
    global rank instead of the local rank, set ``log_line_prefix_template = "[${rank}]:``.

    Example launching function
    ::
        def trainer(args) -> str:
            return "do train"
        def main():
            start_method="spawn"
            shared_queue= multiprocessing.get_context(start_method).Queue()
            spec = WorkerSpec(
                        role="trainer",
                        local_world_size=nproc_per_process,
                        entrypoint=trainer,
                        args=("foobar",),
                        ...<OTHER_PARAMS...>)
            agent = LocalElasticAgent(spec, start_method)
            results = agent.run()
            if results.is_failed():
                print("trainer failed")
            else:
                print(f"rank 0 return value: {results.return_values[0]}")
                # prints -> rank 0 return value: do train

    Example launching binary
    ::
        def main():
            spec = WorkerSpec(
                        role="trainer",
                        local_world_size=nproc_per_process,
                        entrypoint="/usr/local/bin/trainer",
                        args=("--trainer-args", "foobar"),
                        ...<OTHER_PARAMS...>)
            agent = LocalElasticAgent(spec)
            results = agent.run()
            if not results.is_failed():
                print("binary launches do not have return values")
    """
    def __init__(
        self,
        spec: WorkerSpec,
        logs_specs: LogsSpecs,
        start_method="spawn",
        exit_barrier_timeout: float = 300,
        log_line_prefix_template: Optional[str] = None,
    ):
        super().__init__(spec, exit_barrier_timeout)
        # Multiprocessing start method used for function entrypoints.
        self._start_method = start_method
        # Process context of the running workers; created in _start_workers().
        self._pcontext: Optional[PContext] = None
        self._rdzv_handler = spec.rdzv_handler
        self._log_line_prefix_template = log_line_prefix_template
        # Optional named-pipe watchdog; see _setup_local_watchdog().
        self._worker_watchdog: Optional[timer.FileTimerServer] = None
        self._logs_specs = logs_specs
        # Optional health-check endpoint; see _setup_healthcheck().
        self._health_check_server: Optional[HealthCheckServer] = None
    def _setup_local_watchdog(self, envs: Dict[int, Dict[str, str]]) -> None:
        """Start the FileTimerServer watchdog if enabled via env var.

        Mutates ``envs`` so worker processes inherit the watchdog file path.
        """
        enable_watchdog_env_name = TORCHELASTIC_ENABLE_FILE_TIMER
        watchdog_enabled = os.getenv(enable_watchdog_env_name)
        watchdog_file_env_name = TORCHELASTIC_TIMER_FILE
        watchdog_file_path = os.getenv(watchdog_file_env_name)
        if watchdog_enabled is not None and str(watchdog_enabled) == "1":
            if watchdog_file_path is None:
                # Generate a unique pipe path when none was provided.
                watchdog_file_path = "/tmp/watchdog_timer_" + str(uuid.uuid4())
            logger.info("Starting a FileTimerServer with %s ...", watchdog_file_path)
            if not envs:
                logger.warning(
                    "Empty envs variables, using empty run_id for FileTimerServer"
                )
                run_id = ""
            else:
                run_id = envs[0]["TORCHELASTIC_RUN_ID"]
            self._worker_watchdog = timer.FileTimerServer(
                file_path=watchdog_file_path,
                run_id=run_id,
                max_interval=0.1,
                daemon=True,
                log_event=self._log_watchdog_event,
            )
            self._worker_watchdog.start()
            logger.info("FileTimerServer started")
        else:
            logger.info(
                "Environment variable '%s' not found. Do not start FileTimerServer.",
                enable_watchdog_env_name,
            )
        # Propagate the watchdog file env to worker processes
        if watchdog_file_path is not None:
            for worker_env in envs.values():
                worker_env[watchdog_file_env_name] = watchdog_file_path
    @staticmethod
    def _get_current_time_secs() -> int:
        # Fallback "alive" callback when no watchdog exists.
        return int(time.time())
    def _setup_healthcheck(self) -> None:
        """Start the health-check server if its port env var is set.

        The alive callback is the watchdog's last progress time when a
        watchdog is running, otherwise the current time (dummy).
        """
        healthcheck_port_env_name = TORCHELASTIC_HEALTH_CHECK_PORT
        healthcheck_port = os.getenv(healthcheck_port_env_name)
        if healthcheck_port is not None:
            logger.info(
                "Found healthcheck port %s: %s",
                healthcheck_port_env_name,
                healthcheck_port,
            )
            if self._worker_watchdog is None:
                logger.info(
                    "FileTimerServer doesn't exist, using current time as dummy callback"
                )
                alive_callback = LocalElasticAgent._get_current_time_secs
            else:
                alive_callback = self._worker_watchdog.get_last_progress_time
            try:
                healthcheck_port_as_int = int(healthcheck_port)
                self._health_check_server = create_healthcheck_server(
                    alive_callback=alive_callback,
                    port=healthcheck_port_as_int,
                    timeout=60,
                )
                self._health_check_server.start()
            except ValueError:
                logger.info(
                    "Invalid healthcheck port value: '%s', expecting integer. Not starting healthcheck server.",
                    healthcheck_port,
                )
        else:
            logger.info(
                "Environment variable '%s' not found. Do not start health check.",
                healthcheck_port_env_name,
            )
    def _get_fq_hostname(self) -> str:
        """Return the fully-qualified domain name of the local host."""
        return socket.getfqdn(socket.gethostname())
    def _log_watchdog_event(
        self,
        name: str,
        request: Optional[timer.FileTimerRequest],
    ) -> None:
        """Record an agent event describing a watchdog occurrence."""
        wg = self._worker_group
        spec = wg.spec
        md = {"watchdog_event": name}
        if request is not None:
            md["worker_pid"] = str(request.worker_pid)
            md["scope_id"] = request.scope_id
            md["expiration_time"] = str(request.expiration_time)
            md["signal"] = str(request.signal)
        md_str = json.dumps(md)
        state = "RUNNING"
        metadata: Dict[str, EventMetadataValue] = {
            "run_id": spec.rdzv_handler.get_run_id(),
            "global_rank": None,
            "group_rank": wg.group_rank,
            "worker_id": None,
            "role": spec.role,
            "hostname": self._get_fq_hostname(),
            "state": state,
            "total_run_time": self._total_execution_time,
            "rdzv_backend": spec.rdzv_handler.get_backend(),
            "raw_error": None,
            "metadata": md_str,
            "agent_restarts": spec.max_restarts - self._remaining_restarts,
        }
        # Note: The 'metadata' field of the Event is converted to a TorchelasticStatusLogEntry later.
        # The 'name' field of the Event is NOT used in the TorchelasticStatusLogEntry.
        event = events.Event(
            name=name, source=events.EventSource.AGENT, metadata=metadata
        )
        events.record(event)
    # pyre-fixme[56]: Pyre was not able to infer the type of the decorator
    #  `core.distributed.elastic.metrics.prof`.
    @prof
    def _stop_workers(
        self, worker_group: WorkerGroup, is_restart: bool = False
    ) -> None:
        # For the local agent, stopping workers is a full local shutdown.
        self._shutdown(is_restart=is_restart)
    # pyre-fixme[56]: Pyre was not able to infer the type of the decorator
    #  `core.distributed.elastic.metrics.prof`.
    @prof
    def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:
        """Launch one process per local rank and return local_rank -> pid."""
        spec = worker_group.spec
        store = worker_group.store
        assert store is not None
        restart_count = spec.max_restarts - self._remaining_restarts
        use_agent_store: bool = spec.rdzv_handler.use_agent_store
        logger.info("use_agent_store: %s", use_agent_store)
        args: Dict[int, Tuple] = {}
        envs: Dict[int, Dict[str, str]] = {}
        # Only allocated when a prefix template was supplied.
        log_line_prefixes: Optional[Dict[int, str]] = (
            {} if self._log_line_prefix_template else None
        )
        for worker in worker_group.workers:
            local_rank = worker.local_rank
            # Standard torchelastic environment handed to every worker.
            worker_env = {
                "LOCAL_RANK": str(local_rank),
                "RANK": str(worker.global_rank),
                "GROUP_RANK": str(worker_group.group_rank),
                "ROLE_RANK": str(worker.role_rank),
                "ROLE_NAME": spec.role,
                "LOCAL_WORLD_SIZE": str(spec.local_world_size),
                "WORLD_SIZE": str(worker.world_size),
                "GROUP_WORLD_SIZE": str(worker_group.group_world_size),
                "ROLE_WORLD_SIZE": str(worker.role_world_size),
                "MASTER_ADDR": worker_group.master_addr,
                "MASTER_PORT": str(worker_group.master_port),
                "TORCHELASTIC_RESTART_COUNT": str(restart_count),
                "TORCHELASTIC_MAX_RESTARTS": str(spec.max_restarts),
                "TORCHELASTIC_RUN_ID": spec.rdzv_handler.get_run_id(),
                "TORCHELASTIC_USE_AGENT_STORE": str(use_agent_store),
                "TORCH_NCCL_ASYNC_ERROR_HANDLING": os.getenv(
                    "TORCH_NCCL_ASYNC_ERROR_HANDLING", str(1)
                ),
            }
            if "OMP_NUM_THREADS" in os.environ:
                worker_env["OMP_NUM_THREADS"] = os.environ["OMP_NUM_THREADS"]
            if self._log_line_prefix_template:
                log_line_prefix = Template(
                    self._log_line_prefix_template
                ).safe_substitute(
                    role_name=spec.role,
                    rank=worker.global_rank,
                    local_rank=local_rank,
                )
                log_line_prefixes[local_rank] = log_line_prefix
            envs[local_rank] = worker_env
            worker_args = list(spec.args)
            worker_args = macros.substitute(worker_args, str(local_rank))
            args[local_rank] = tuple(worker_args)
        self._setup_local_watchdog(envs=envs)
        self._setup_healthcheck()
        assert spec.entrypoint is not None
        assert self._logs_specs is not None
        self._pcontext = start_processes(
            name=spec.role,
            entrypoint=spec.entrypoint,
            args=args,
            envs=envs,
            logs_specs=self._logs_specs,
            log_line_prefixes=log_line_prefixes,
            start_method=self._start_method,
        )
        return self._pcontext.pids()
    def _shutdown(
        self, death_sig: signal.Signals = signal.SIGTERM, is_restart: bool = False
    ) -> None:
        """Stop watchdog, healthcheck, worker processes and (unless restarting) rendezvous."""
        if self._worker_watchdog is not None:
            self._worker_watchdog.stop()
            self._worker_watchdog = None
        if self._health_check_server is not None:
            self._health_check_server.stop()
            self._health_check_server = None
        if self._pcontext:
            self._pcontext.close(death_sig)
        # Keep the rendezvous alive across restarts; tear it down otherwise.
        if not is_restart and self._rdzv_handler:
            self._rdzv_handler.shutdown()
    # pyre-fixme[56]: Pyre was not able to infer the type of the decorator
    #  `core.distributed.elastic.metrics.prof`.
    @prof
    def _monitor_workers(self, worker_group: WorkerGroup) -> RunResult:
        """Poll the process context and translate its state to a RunResult."""
        role = worker_group.spec.role
        worker_pids = {w.id for w in worker_group.workers}
        assert self._pcontext is not None
        pc_pids = set(self._pcontext.pids().values())
        if worker_pids != pc_pids:
            # Book-keeping mismatch between agent and process context.
            logger.error(
                "[%s] worker pids do not match process_context pids."
                " Expected: %s, actual: %s",
                role,
                worker_pids,
                pc_pids,
            )
            return RunResult(state=WorkerState.UNKNOWN)
        result = self._pcontext.wait(0)
        if result:
            if result.is_failed():
                # map local rank failure to global rank
                worker_failures = {}
                for local_rank, failure in result.failures.items():
                    worker = worker_group.workers[local_rank]
                    worker_failures[worker.global_rank] = failure
                return RunResult(
                    state=WorkerState.FAILED,
                    failures=worker_failures,
                )
            else:
                # copy ret_val_queue into a map with a global ranks
                workers_ret_vals = {}
                for local_rank, ret_val in result.return_values.items():
                    worker = worker_group.workers[local_rank]
                    workers_ret_vals[worker.global_rank] = ret_val
                return RunResult(
                    state=WorkerState.SUCCEEDED,
                    return_values=workers_ret_vals,
                )
        else:
            # Still running and healthy.
            return RunResult(state=WorkerState.HEALTHY)

+ 0
- 52
mindnlp/core/distributed/elastic/control_plane.py View File

@@ -1,52 +0,0 @@
import os
from contextlib import contextmanager, ExitStack
from typing import Generator
from core.distributed.elastic.multiprocessing.errors import record
__all__ = [
"worker_main",
]
TORCH_WORKER_SERVER_SOCKET = "TORCH_WORKER_SERVER_SOCKET"
@contextmanager
def _worker_server(socket_path: str) -> Generator[None, None, None]:
    """Run a ``_WorkerServer`` bound to *socket_path* for the duration of the block.

    The server is shut down unconditionally on exit, even if the wrapped
    block raises.
    """
    from core._C._distributed_c10d import _WorkerServer
    server = _WorkerServer(socket_path)
    try:
        yield
    finally:
        # Always release the unix socket.
        server.shutdown()
@contextmanager
@record
def worker_main() -> Generator[None, None, None]:
    """
    This is a context manager that wraps your main entry function. This combines
    the existing ``errors.record`` logic as well as a new ``_WorkerServer`` that
    exposes handlers via a unix socket specified by
    ``TORCH_WORKER_SERVER_SOCKET``.
    Example
    ::
        @worker_main()
        def main():
            pass
        if __name__=="__main__":
            main()
    """
    with ExitStack() as stack:
        socket_path = os.environ.get(TORCH_WORKER_SERVER_SOCKET)
        # Only start the worker server when the socket env var is set;
        # otherwise behave exactly like a plain ``@record``-wrapped main.
        if socket_path is not None:
            stack.enter_context(_worker_server(socket_path))
        yield

+ 0
- 170
mindnlp/core/distributed/elastic/events/__init__.py View File

@@ -1,170 +0,0 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Module contains events processing mechanisms that are integrated with the standard python logging.
Example of usage:
::
from core.distributed.elastic import events
event = events.Event(name="test_event", source=events.EventSource.WORKER, metadata={...})
events.get_logging_handler(destination="console").info(event)
"""
import inspect
import logging
import os
import socket
import traceback
from typing import Dict, Optional
from core.distributed.elastic.events.handlers import get_logging_handler
from .api import ( # noqa: F401
Event,
EventMetadataValue,
EventSource,
NodeState,
RdzvEvent,
)
_events_loggers: Dict[str, logging.Logger] = {}
def _get_or_create_logger(destination: str = "null") -> logging.Logger:
    """Return the cached events logger for *destination*, creating it on first use.

    The logger is named ``torchelastic-events-<destination>`` and does not
    propagate to ancestor loggers, so each event is processed exactly once.

    Args:
        destination: The string representation of the event handler.
            Available handlers found in ``handlers`` module
    """
    global _events_loggers
    if destination in _events_loggers:
        return _events_loggers[destination]
    new_logger = logging.getLogger(f"torchelastic-events-{destination}")
    new_logger.setLevel(os.environ.get("LOGLEVEL", "INFO"))
    # Keep events out of the root logger so a single event is handled once.
    new_logger.propagate = False
    new_logger.addHandler(get_logging_handler(destination))
    # Cache in the global dictionary for subsequent calls.
    _events_loggers[destination] = new_logger
    return new_logger
def record(event: Event, destination: str = "null") -> None:
    # Serialize `event` and publish it through the logger configured for `destination`.
    _get_or_create_logger(destination).info(event.serialize())
def record_rdzv_event(event: RdzvEvent) -> None:
    # Rendezvous events always go through the dedicated "dynamic_rendezvous" handler.
    _get_or_create_logger("dynamic_rendezvous").info(event.serialize())
def construct_and_record_rdzv_event(
    run_id: str,
    message: str,
    node_state: NodeState,
    name: str = "",
    hostname: str = "",
    pid: Optional[int] = None,
    master_endpoint: str = "",
    local_id: Optional[int] = None,
    rank: Optional[int] = None,
) -> None:
    """
    Initialize rendezvous event object and record its operations.
    Args:
        run_id (str): The run id of the rendezvous.
        message (str): The message describing the event.
        node_state (NodeState): The state of the node (INIT, RUNNING, SUCCEEDED, FAILED).
        name (str): Event name. (E.g. Current action being performed).
        hostname (str): Hostname of the node.
        pid (Optional[int]): The process id of the node.
        master_endpoint (str): The master endpoint for the rendezvous store, if known.
        local_id (Optional[int]): The local_id of the node, if defined in dynamic_rendezvous.py
        rank (Optional[int]): The rank of the node, if known.
    Returns:
        None
    Example:
        >>> # See DynamicRendezvousHandler class
        >>> def _record(
        ...     self,
        ...     message: str,
        ...     node_state: NodeState = NodeState.RUNNING,
        ...     rank: Optional[int] = None,
        ... ) -> None:
        ...     construct_and_record_rdzv_event(
        ...         name=f"{self.__class__.__name__}.{get_method_name()}",
        ...         run_id=self._settings.run_id,
        ...         message=message,
        ...         node_state=node_state,
        ...         hostname=self._this_node.addr,
        ...         pid=self._this_node.pid,
        ...         local_id=self._this_node.local_id,
        ...         rank=rank,
        ...     )
    """
    # We don't want to perform an extra computation if not needed.
    if isinstance(get_logging_handler("dynamic_rendezvous"), logging.NullHandler):
        return
    # Set up parameters.
    if not hostname:
        hostname = socket.getfqdn()
    if not pid:
        pid = os.getpid()
    # Determines which file called this function.
    callstack = inspect.stack()
    filename = "no_file"
    if len(callstack) > 1:
        stack_depth_1 = callstack[1]
        filename = os.path.basename(stack_depth_1.filename)
        if not name:
            name = stack_depth_1.function
    # Delete the callstack variable. If kept, this can mess with python's
    # garbage collector as we are holding on to stack frame information in
    # the inspect module.
    del callstack
    # Set up error trace if this is an exception
    if node_state == NodeState.FAILED:
        error_trace = traceback.format_exc()
    else:
        error_trace = ""
    # Initialize event object.
    # Bug fix: the caller's filename was computed above but never used — the
    # event name previously hard-coded "(unknown)" as the file component.
    event = RdzvEvent(
        name=f"{filename}:{name}",
        run_id=run_id,
        message=message,
        hostname=hostname,
        pid=pid,
        node_state=node_state,
        master_endpoint=master_endpoint,
        rank=rank,
        local_id=local_id,
        error_trace=error_trace,
    )
    # Finally, record the event.
    record_rdzv_event(event)

+ 0
- 114
mindnlp/core/distributed/elastic/events/api.py View File

@@ -1,114 +0,0 @@
#!/usr/bin/env python3
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import json
from dataclasses import asdict, dataclass, field
from enum import Enum
from typing import Dict, Optional, Union
__all__ = ["EventSource", "Event", "NodeState", "RdzvEvent"]
EventMetadataValue = Union[str, int, float, bool, None]
class EventSource(str, Enum):
    """Known identifiers of the event producers."""
    AGENT = "AGENT"  # event emitted by the elastic agent process
    WORKER = "WORKER"  # event emitted by a worker process
@dataclass
class Event:
    """
    The class represents the generic event that occurs during the torchelastic job execution.
    The event can be any kind of meaningful action.
    Args:
        name: event name.
        source: the event producer, e.g. agent or worker
        timestamp: timestamp in milliseconds when event occurred.
        metadata: additional data that is associated with the event.
    """
    name: str
    source: EventSource
    timestamp: int = 0
    metadata: Dict[str, EventMetadataValue] = field(default_factory=dict)
    def __str__(self):
        return self.serialize()
    @staticmethod
    def deserialize(data: Union[str, "Event"]) -> "Event":
        """Return an ``Event`` from an existing instance or its JSON string.

        Raises:
            TypeError: if *data* is neither a str nor an Event (previously this
                path raised an opaque UnboundLocalError).
        """
        if isinstance(data, Event):
            return data
        if isinstance(data, str):
            data_dict = json.loads(data)
            # JSON stores the enum by name; rebuild the EventSource member.
            data_dict["source"] = EventSource[data_dict["source"]]
            return Event(**data_dict)
        raise TypeError(f"Cannot deserialize Event from {type(data).__name__}")
    def serialize(self) -> str:
        """JSON-encode this event's dataclass fields."""
        return json.dumps(asdict(self))
class NodeState(str, Enum):
    """The states that a node can be in rendezvous."""
    INIT = "INIT"  # node created but not yet participating
    RUNNING = "RUNNING"  # node actively participating in rendezvous
    SUCCEEDED = "SUCCEEDED"
    FAILED = "FAILED"
@dataclass
class RdzvEvent:
    """
    Dataclass to represent any rendezvous event.
    Args:
        name: Event name. (E.g. Current action being performed)
        run_id: The run id of the rendezvous
        message: The message describing the event
        hostname: Hostname of the node
        pid: The process id of the node
        node_state: The state of the node (INIT, RUNNING, SUCCEEDED, FAILED)
        master_endpoint: The master endpoint for the rendezvous store, if known
        rank: The rank of the node, if known
        local_id: The local_id of the node, if defined in dynamic_rendezvous.py
        error_trace: Error stack trace, if this is an error event.
    """
    name: str
    run_id: str
    message: str
    hostname: str
    pid: int
    node_state: NodeState
    master_endpoint: str = ""
    rank: Optional[int] = None
    local_id: Optional[int] = None
    error_trace: str = ""
    def __str__(self):
        return self.serialize()
    @staticmethod
    def deserialize(data: Union[str, "RdzvEvent"]) -> "RdzvEvent":
        """Return an ``RdzvEvent`` from an existing instance or its JSON string.

        Raises:
            TypeError: if *data* is neither a str nor an RdzvEvent (previously
                this path raised an opaque UnboundLocalError).
        """
        if isinstance(data, RdzvEvent):
            return data
        if isinstance(data, str):
            data_dict = json.loads(data)
            # JSON stores the enum by name; rebuild the NodeState member.
            data_dict["node_state"] = NodeState[data_dict["node_state"]]
            return RdzvEvent(**data_dict)
        raise TypeError(f"Cannot deserialize RdzvEvent from {type(data).__name__}")
    def serialize(self) -> str:
        """JSON-encode this event's dataclass fields."""
        return json.dumps(asdict(self))

+ 0
- 22
mindnlp/core/distributed/elastic/events/handlers.py View File

@@ -1,22 +0,0 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
from typing import Dict
# Registry of named, process-wide shared handlers for event logging.
_log_handlers: Dict[str, logging.Handler] = {
    "console": logging.StreamHandler(),
    "dynamic_rendezvous": logging.NullHandler(),
    "null": logging.NullHandler(),
}
def get_logging_handler(destination: str = "null") -> logging.Handler:
    """Return the shared logging handler registered for *destination*.

    Raises:
        KeyError: if *destination* is not a known handler name.
    """
    # Read-only access to the module-level dict: no ``global`` needed.
    return _log_handlers[destination]

+ 0
- 164
mindnlp/core/distributed/elastic/metrics/__init__.py View File

@@ -1,164 +0,0 @@
#!/usr/bin/env python3
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Metrics API.
**Overview**:
The metrics API in torchelastic is used to publish telemetry metrics.
It is designed to be used by torchelastic's internal modules to
publish metrics for the end user with the goal of increasing visibility
and helping with debugging. However you may use the same API in your
jobs to publish metrics to the same metrics ``sink``.
A ``metric`` can be thought of as timeseries data
and is uniquely identified by the string-valued tuple
``(metric_group, metric_name)``.
torchelastic makes no assumptions about what a ``metric_group`` is
and what relationship it has with ``metric_name``. It is totally up
to the user to use these two fields to uniquely identify a metric.
.. note:: The metric group ``torchelastic`` is reserved by torchelastic for
platform level metrics that it produces.
For instance torchelastic may output the latency (in milliseconds)
of a re-rendezvous operation from the agent as
``(torchelastic, agent.rendezvous.duration.ms)``
A sensible way to use metric groups is to map them to a stage or module
in your job. You may also encode certain high level properties
the job such as the region or stage (dev vs prod).
**Publish Metrics**:
Using torchelastic's metrics API is similar to using python's logging
framework. You first have to configure a metrics handler before
trying to add metric data.
The example below measures the latency for the ``calculate()`` function.
::
import time
from mindnlp import core.distributed.elastic.metrics as metrics
# makes all metrics other than the one from "my_module" to go /dev/null
metrics.configure(metrics.NullMetricsHandler())
metrics.configure(metrics.ConsoleMetricsHandler(), "my_module")
def my_method():
start = time.time()
calculate()
end = time.time()
metrics.put_metric("calculate_latency", int(end-start), "my_module")
You may also use the `core.distributed.elastic.metrics.prof` decorator
to conveniently and succinctly profile functions
::
# -- in module examples.foobar --
from mindnlp import core.distributed.elastic.metrics as metrics
metrics.configure(metrics.ConsoleMetricsHandler(), "foobar")
metrics.configure(metrics.ConsoleMetricsHandler(), "Bar")
@metrics.prof
def foo():
pass
class Bar():
@metrics.prof
def baz():
pass
``@metrics.prof`` will publish the following metrics
::
<leaf_module or classname>.success - 1 if the function finished successfully
<leaf_module or classname>.failure - 1 if the function threw an exception
<leaf_module or classname>.duration.ms - function duration in milliseconds
**Configuring Metrics Handler**:
`core.distributed.elastic.metrics.MetricHandler` is responsible for emitting
the added metric values to a particular destination. Metric groups can be
configured with different metric handlers.
By default torchelastic emits all metrics to ``/dev/null``.
By adding the following configuration metrics,
``torchelastic`` and ``my_app`` metric groups will be printed out to
console.
::
from mindnlp import core.distributed.elastic.metrics as metrics
metrics.configure(metrics.ConsoleMetricHandler(), group = "torchelastic")
metrics.configure(metrics.ConsoleMetricHandler(), group = "my_app")
**Writing a Custom Metric Handler**:
If you want your metrics to be emitted to a custom location, implement
the `core.distributed.elastic.metrics.MetricHandler` interface
and configure your job to use your custom metric handler.
Below is a toy example that prints the metrics to ``stdout``
::
from mindnlp import core.distributed.elastic.metrics as metrics
class StdoutMetricHandler(metrics.MetricHandler):
def emit(self, metric_data):
ts = metric_data.timestamp
group = metric_data.group_name
name = metric_data.name
value = metric_data.value
print(f"[{ts}][{group}]: {name}={value}")
metrics.configure(StdoutMetricHandler(), group="my_app")
Now all metrics in the group ``my_app`` will be printed to stdout as:
::
[1574213883.4182858][my_app]: my_metric=<value>
[1574213940.5237644][my_app]: my_metric=<value>
"""
from typing import Optional
from .api import ( # noqa: F401
configure,
ConsoleMetricHandler,
get_elapsed_time_ms,
getStream,
MetricData,
MetricHandler,
MetricsConfig,
NullMetricHandler,
prof,
profile,
publish_metric,
put_metric,
)
def initialize_metrics(cfg: Optional[MetricsConfig] = None):
    # Intentional no-op placeholder; a platform-specific implementation may be
    # provided by the optional ``static_init`` module imported below.
    pass
try:
from core.distributed.elastic.metrics.static_init import * # type: ignore[import] # noqa: F401 F403
except ModuleNotFoundError:
pass

+ 0
- 216
mindnlp/core/distributed/elastic/metrics/api.py View File

@@ -1,216 +0,0 @@
#!/usr/bin/env python3
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import abc
import time
from collections import namedtuple
from functools import wraps
from typing import Dict, Optional
from typing_extensions import deprecated
__all__ = [
"MetricsConfig",
"MetricHandler",
"ConsoleMetricHandler",
"NullMetricHandler",
"MetricStream",
"configure",
"getStream",
"prof",
"profile",
"put_metric",
"publish_metric",
"get_elapsed_time_ms",
"MetricData",
]
MetricData = namedtuple("MetricData", ["timestamp", "group_name", "name", "value"])
class MetricsConfig:
    """Bag of string key/value parameters used to configure a metrics handler."""
    __slots__ = ["params"]
    def __init__(self, params: Optional[Dict[str, str]] = None):
        # Normalize None to a fresh dict so instances never share state.
        self.params = {} if params is None else params
class MetricHandler(abc.ABC):
    """Abstract sink that receives ``MetricData`` points via :meth:`emit`."""
    @abc.abstractmethod
    def emit(self, metric_data: MetricData):
        """Publish *metric_data* to this handler's destination."""
        pass
class ConsoleMetricHandler(MetricHandler):
    """MetricHandler that prints each datapoint to stdout."""
    def emit(self, metric_data: MetricData):
        # Output format: [timestamp][group]: name=value
        print(
            f"[{metric_data.timestamp}][{metric_data.group_name}]: {metric_data.name}={metric_data.value}"
        )
class NullMetricHandler(MetricHandler):
    """MetricHandler that silently discards every datapoint (the default)."""
    def emit(self, metric_data: MetricData):
        pass
class MetricStream:
    """Binds a metric group to a handler and timestamps each emitted value."""
    def __init__(self, group_name: str, handler: MetricHandler):
        self.group_name = group_name
        self.handler = handler
    def add_value(self, metric_name: str, metric_value: int):
        """Emit one datapoint for this stream's group, stamped with the current time."""
        datapoint = MetricData(time.time(), self.group_name, metric_name, metric_value)
        self.handler.emit(datapoint)
_metrics_map: Dict[str, MetricHandler] = {}
_default_metrics_handler: MetricHandler = NullMetricHandler()
# pyre-fixme[9]: group has type `str`; used as `None`.
def configure(handler: MetricHandler, group: Optional[str] = None):
    """Register *handler* for *group*; with no group, replace the default handler."""
    if group is None:
        global _default_metrics_handler
        # pyre-fixme[9]: _default_metrics_handler has type `NullMetricHandler`; used
        # as `MetricHandler`.
        _default_metrics_handler = handler
    else:
        _metrics_map[group] = handler
def getStream(group: str):
    """Return a ``MetricStream`` for *group*.

    Uses the handler registered via :func:`configure`, falling back to the
    process-wide default handler for unconfigured groups.
    """
    # dict.get replaces the original membership-test-then-index pattern.
    handler = _metrics_map.get(group, _default_metrics_handler)
    return MetricStream(group, handler)
def _get_metric_name(fn):
qualname = fn.__qualname__
split = qualname.split(".")
if len(split) == 1:
module = fn.__module__
if module:
return module.split(".")[-1] + "." + split[0]
else:
return split[0]
else:
return qualname
def prof(fn=None, group: str = "torchelastic"):
    r"""
    @profile decorator publishes duration.ms, count, success, failure metrics for the function that it decorates.
    The metric name defaults to the qualified name (``class_name.def_name``) of the function.
    If the function does not belong to a class, it uses the leaf module name instead.
    Usage
    ::
        @metrics.prof
        def x():
            pass
        @metrics.prof(group="agent")
        def y():
            pass
    """
    def _decorate(func):
        @wraps(func)
        def timed(*args, **kwargs):
            metric_key = _get_metric_name(func)
            try:
                t0 = time.time()
                out = func(*args, **kwargs)
                put_metric(f"{metric_key}.success", 1, group)
            except Exception:
                put_metric(f"{metric_key}.failure", 1, group)
                raise
            finally:
                # duration is emitted on both the success and failure paths
                put_metric(f"{metric_key}.duration.ms", get_elapsed_time_ms(t0), group)  # type: ignore[possibly-undefined]
            return out
        return timed
    # supports both the bare form (@prof) and parameterized form (@prof(group=...))
    return _decorate(fn) if fn else _decorate
@deprecated("Deprecated, use `@prof` instead", category=FutureWarning)
def profile(group=None):
    """
    @profile decorator adds latency and success/failure metrics to any given function.
    Usage
    ::
        @metrics.profile("my_metric_group")
        def some_function(<arguments>):
    """
    def wrap(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            try:
                start_time = time.time()
                result = func(*args, **kwargs)
                publish_metric(group, f"{func.__name__}.success", 1)
            except Exception:
                publish_metric(group, f"{func.__name__}.failure", 1)
                raise
            finally:
                # duration is emitted on both success and failure paths
                publish_metric(
                    group,
                    f"{func.__name__}.duration.ms",
                    get_elapsed_time_ms(start_time),  # type: ignore[possibly-undefined]
                )
            return result
        return wrapper
    return wrap
def put_metric(metric_name: str, metric_value: int, metric_group: str = "torchelastic"):
    """
    Publish a metric data point.
    The datapoint is routed to whichever handler is configured for
    ``metric_group`` (see ``configure``); unconfigured groups use the default.
    Usage
    ::
        put_metric("metric_name", 1)
        put_metric("metric_name", 1, "metric_group_name")
    """
    getStream(metric_group).add_value(metric_name, metric_value)
@deprecated(
    "Deprecated, use `put_metric(metric_group)(metric_name, metric_value)` instead",
    category=FutureWarning,
)
def publish_metric(metric_group: str, metric_name: str, metric_value: int):
    """Deprecated equivalent of ``put_metric`` with group-first argument order."""
    metric_stream = getStream(metric_group)
    metric_stream.add_value(metric_name, metric_value)
def get_elapsed_time_ms(start_time_in_seconds: float):
    """Return the elapsed time in millis from the given start time."""
    return int((time.time() - start_time_in_seconds) * 1000)

+ 0
- 233
mindnlp/core/distributed/elastic/multiprocessing/__init__.py View File

@@ -1,233 +0,0 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Library that launches and manages ``n`` copies of worker subprocesses either specified by a function or a binary.
For functions, it uses ``core.multiprocessing`` (and therefore python
``multiprocessing``) to spawn/fork worker processes. For binaries it uses python
``subprocessing.Popen`` to create worker processes.
Usage 1: Launching two trainers as a function
::
from core.distributed.elastic.multiprocessing import Std, start_processes
def trainer(a, b, c):
pass # train
# runs two trainers
# LOCAL_RANK=0 trainer(1,2,3)
# LOCAL_RANK=1 trainer(4,5,6)
ctx = start_processes(
name="trainer",
entrypoint=trainer,
args={0: (1,2,3), 1: (4,5,6)},
envs={0: {"LOCAL_RANK": 0}, 1: {"LOCAL_RANK": 1}},
log_dir="/tmp/foobar",
redirects=Std.ALL, # write all worker stdout/stderr to a log file
tee={0: Std.ERR}, # tee only local rank 0's stderr to console
)
# waits for all copies of trainer to finish
ctx.wait()
Usage 2: Launching 2 echo workers as a binary
::
# same as invoking
# echo hello
# echo world > stdout.log
ctx = start_processes(
name="echo"
entrypoint="echo",
log_dir="/tmp/foobar",
args={0: "hello", 1: "world"},
redirects={1: Std.OUT},
)
Just like ``core.multiprocessing``, the return value of the function
:func:`start_processes` is a process context (:class:`api.PContext`). If a function
was launched, a :class:`api.MultiprocessContext` is returned and if a binary
was launched a :class:`api.SubprocessContext` is returned. Both are specific
implementations of the parent :class:`api.PContext` class.
"""
from typing import Callable, Dict, Optional, Tuple, Union
from core.distributed.elastic.multiprocessing.api import ( # noqa: F401
_validate_full_rank,
DefaultLogsSpecs,
LogsDest,
LogsSpecs,
MultiprocessContext,
PContext,
ProcessFailure,
RunProcsResult,
SignalException,
Std,
SubprocessContext,
to_map,
)
from core.distributed.elastic.utils.logging import get_logger
__all__ = [
"start_processes",
"MultiprocessContext",
"PContext",
"ProcessFailure",
"RunProcsResult",
"SignalException",
"Std",
"LogsDest",
"LogsSpecs",
"DefaultLogsSpecs",
"SubprocessContext",
"to_map",
]
def start_processes(
    name: str,
    entrypoint: Union[Callable, str],
    args: Dict[int, Tuple],
    envs: Dict[int, Dict[str, str]],
    logs_specs: LogsSpecs,
    log_line_prefixes: Optional[Dict[int, str]] = None,
    start_method: str = "spawn",
) -> PContext:
    """
    Start ``n`` copies of ``entrypoint`` processes with the provided options.
    ``entrypoint`` is either a ``Callable`` (function) or a ``str`` (binary).
    The number of copies is determined by the number of entries for ``args`` and
    ``envs`` arguments, which need to have the same key set.
    ``args`` and ``env`` parameters are the arguments and environment variables
    to pass down to the entrypoint mapped by the replica index (local rank).
    All local ranks must be accounted for.
    That is, the keyset should be ``{0,1,...,(nprocs-1)}``.
    .. note:: When the ``entrypoint`` is a binary (``str``), ``args`` can only be strings.
              If any other type is given, then it is casted to a string representation
              (e.g. ``str(arg1)``). Furthermore, a binary failure will only write
              an ``error.json`` error file if the main function is annotated with
              ``core.distributed.elastic.multiprocessing.errors.record``. For function launches,
              this is done by default and there is no need to manually annotate
              with the ``@record`` annotation.
    ``redirects`` and ``tee`` are bitmasks specifying which std stream(s) to redirect
    to a log file in the ``log_dir``. Valid mask values are defined in ``Std``.
    To redirect/tee only certain local ranks, pass ``redirects`` as a map with the key as
    the local rank to specify the redirect behavior for.
    Any missing local ranks will default to ``Std.NONE``.
    ``tee`` acts like the unix "tee" command in that it redirects + prints to console.
    To avoid worker stdout/stderr from printing to console, use the ``redirects`` parameter.
    For each process, the ``log_dir`` will contain:
    #. ``{local_rank}/error.json``: if the process failed, a file with the error info
    #. ``{local_rank}/stdout.json``: if ``redirect & STDOUT == STDOUT``
    #. ``{local_rank}/stderr.json``: if ``redirect & STDERR == STDERR``
    .. note:: It is expected that the ``log_dir`` exists, is empty, and is a directory.
    Example:
    ::
        log_dir = "/tmp/test"
        # ok; two copies of foo: foo("bar0"), foo("bar1")
        start_processes(
            name="trainer",
            entrypoint=foo,
            args:{0:("bar0",), 1:("bar1",),
            envs:{0:{}, 1:{}},
            log_dir=log_dir
        )
        # invalid; envs missing for local rank 1
        start_processes(
            name="trainer",
            entrypoint=foo,
            args:{0:("bar0",), 1:("bar1",),
            envs:{0:{}},
            log_dir=log_dir
        )
        # ok; two copies of /usr/bin/touch: touch file1, touch file2
        start_processes(
            name="trainer",
            entrypoint="/usr/bin/touch",
            args:{0:("file1",), 1:("file2",),
            envs:{0:{}, 1:{}},
            log_dir=log_dir
        )
        # caution; arguments casted to string, runs:
        # echo "1" "2" "3" and echo "[1, 2, 3]"
        start_processes(
            name="trainer",
            entrypoint="/usr/bin/echo",
            args:{0:(1,2,3), 1:([1,2,3],),
            envs:{0:{}, 1:{}},
            log_dir=log_dir
        )
    Args:
        name: a human readable short name that describes what the processes are
              (used as header when tee'ing stdout/stderr outputs)
        entrypoint: either a ``Callable`` (function) or ``cmd`` (binary)
        args: arguments to each replica
        envs: env vars to each replica
        logs_specs: defines where/how per-rank log files are produced
            (redirect, tee and log-dir behavior; see ``LogsSpecs``) —
            replaces the older ``log_dir``/``redirects``/``tee`` arguments
            referenced in the prose above
        log_line_prefixes: optional per-local-rank prefix for tee'd log lines
        start_method: multiprocessing start method (spawn, fork, forkserver)
                      ignored for binaries
    """
    nprocs = len(args)
    # both maps must cover exactly local ranks 0..nprocs-1
    _validate_full_rank(args, nprocs, "args")
    _validate_full_rank(envs, nprocs, "envs")
    context: PContext
    if isinstance(entrypoint, str):
        # binary entrypoint -> subprocess.Popen-based context
        context = SubprocessContext(
            name=name,
            entrypoint=entrypoint,
            args=args,
            envs=envs,
            logs_specs=logs_specs,
            log_line_prefixes=log_line_prefixes,
        )
    else:
        # function entrypoint -> core.multiprocessing-based context
        context = MultiprocessContext(
            name=name,
            entrypoint=entrypoint,
            args=args,
            envs=envs,
            log_line_prefixes=log_line_prefixes,
            start_method=start_method,
            logs_specs=logs_specs,
        )
    try:
        context.start()
        return context
    except Exception:
        # don't leak half-started worker processes on failure
        context.close()
        raise

+ 0
- 923
mindnlp/core/distributed/elastic/multiprocessing/api.py View File

@@ -1,923 +0,0 @@
#!/usr/bin/env python3
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import abc
import logging
import os
import re
import shutil
import signal
import subprocess
import sys
import tempfile
import threading
import time
from abc import ABC, abstractmethod
from contextlib import nullcontext
from dataclasses import dataclass, field
from enum import IntFlag
from multiprocessing import synchronize
from types import FrameType
from typing import Any, Callable, Dict, Optional, Set, Tuple, Union
from mindnlp import core.multiprocessing as mp
from core.distributed.elastic.multiprocessing.errors import ProcessFailure, record
from core.distributed.elastic.multiprocessing.redirects import (
redirect_stderr,
redirect_stdout,
)
from core.distributed.elastic.multiprocessing.subprocess_handler import (
get_subprocess_handler,
SubprocessHandler,
)
from core.distributed.elastic.multiprocessing.tail_log import TailLog
IS_WINDOWS = sys.platform == "win32"
IS_MACOS = sys.platform == "darwin"
logger = logging.getLogger(__name__)
__all__ = [
"DefaultLogsSpecs",
"SignalException",
"Std",
"to_map",
"RunProcsResult",
"PContext",
"get_std_cm",
"MultiprocessContext",
"SubprocessContext",
"LogsDest",
"LogsSpecs",
]
class SignalException(Exception):
    """
    Exception is raised inside the torchelastic agent process by the termination handler
    if the death signal got received by the process.
    """
    def __init__(self, msg: str, sigval: signal.Signals) -> None:
        super().__init__(msg)
        # the signal that triggered termination, kept for callers to inspect
        self.sigval = sigval
def _terminate_process_handler(signum: int, frame: Optional[FrameType]) -> None:
    """Termination handler that raises exceptions on the main process.
    When the process receives death signal(SIGTERM, SIGINT), this termination handler will
    be invoked. It raises the ``SignalException`` exception that should be processed by the
    user code. Python does not terminate process after the termination handler is finished,
    so the exception should not be silently ignored, otherwise the process will never
    be terminated.
    """
    received = signal.Signals(signum)
    raise SignalException(f"Process {os.getpid()} got signal: {received}", sigval=received)
def _get_kill_signal() -> signal.Signals:
    """Get the kill signal. SIGKILL for unix, CTRL_C_EVENT for windows."""
    if not IS_WINDOWS:
        return signal.SIGKILL
    return signal.CTRL_C_EVENT  # type: ignore[attr-defined] # noqa: F821
def _get_default_signal() -> signal.Signals:
    """Get the default termination signal. SIGTERM for unix, CTRL_C_EVENT for windows."""
    if not IS_WINDOWS:
        return signal.SIGTERM
    return signal.CTRL_C_EVENT  # type: ignore[attr-defined] # noqa: F821
def _validate_full_rank(d: Dict[int, Any], nprocs: int, what: str):
actual_keys = set(d.keys())
expected_keys = set(range(nprocs))
if actual_keys != expected_keys:
raise RuntimeError(
f"{what}, local rank mapping mismatch,"
f" expected: {expected_keys}, actual: {actual_keys}"
)
# Accepted shapes for Std.from_str: a single digit, or "rank:value" pairs.
_MAPPING_REGEX = r"^(\d:[0123],)*(\d:[0123])$"
_VALUE_REGEX = r"^[0123]$"
class Std(IntFlag):
    """Bitmask selecting which std streams to redirect/tee."""
    NONE = 0
    OUT = 1
    ERR = 2
    ALL = OUT | ERR
    @classmethod
    def from_str(cls, vm: str) -> Union["Std", Dict[int, "Std"]]:
        """
        Example:
        ::
            from_str("0") -> Std.NONE
            from_str("1") -> Std.OUT
            from_str("0:3,1:0,2:1,3:2") -> {0: Std.ALL, 1: Std.NONE, 2: Std.OUT, 3: Std.ERR}
        Any other input raises an exception
        """
        def to_std(v: str) -> Std:  # type: ignore[return]
            s = Std(int(v))
            if s in Std:
                return s
            # unreachable: the regexes above restrict v to 0..3
        if re.match(_VALUE_REGEX, vm):
            # single value applied to all ranks
            return to_std(vm)
        if re.match(_MAPPING_REGEX, vm):
            # per-rank mapping, e.g. "0:1,1:2"
            return {
                int(rank): to_std(value)
                for rank, value in (pair.split(":") for pair in vm.split(","))
            }
        raise ValueError(
            f"{vm} does not match: <{_VALUE_REGEX}> or <{_MAPPING_REGEX}>"
        )
def to_map(
    val_or_map: Union[Std, Dict[int, Std]], local_world_size: int
) -> Dict[int, Std]:
    """
    Certain APIs take redirect settings either as a single value (e.g. apply to all
    local ranks) or as an explicit user-provided mapping. This method is a convenience
    method that converts a value or mapping into a mapping.
    Example:
    ::
        to_map(Std.OUT, local_world_size=2) # returns: {0: Std.OUT, 1: Std.OUT}
        to_map({1: Std.OUT}, local_world_size=2) # returns: {0: Std.NONE, 1: Std.OUT}
        to_map({0: Std.OUT, 1: Std.OUT}, local_world_size=2) # returns: {0: Std.OUT, 1: Std.OUT}
    """
    if isinstance(val_or_map, Std):
        return dict.fromkeys(range(local_world_size), val_or_map)
    # partial user mapping: fill Std.NONE for any unspecified local rank
    # (dict comprehension also avoids shadowing the builtin `map`)
    return {i: val_or_map.get(i, Std.NONE) for i in range(local_world_size)}
@dataclass
class LogsDest:
    """
    For each log type, holds mapping of local rank ids to file paths.
    """
    # local_rank -> stdout log file path
    stdouts: Dict[int, str] = field(default_factory=dict)
    # local_rank -> stderr log file path
    stderrs: Dict[int, str] = field(default_factory=dict)
    # local_rank -> file whose contents are also echoed to console (stdout)
    tee_stdouts: Dict[int, str] = field(default_factory=dict)
    # local_rank -> file whose contents are also echoed to console (stderr)
    tee_stderrs: Dict[int, str] = field(default_factory=dict)
    # local_rank -> error.json path written on process failure
    error_files: Dict[int, str] = field(default_factory=dict)
class LogsSpecs(ABC):
    """
    Defines logs processing and redirection for each worker process.
    Args:
        log_dir:
            Base directory where logs will be written.
        redirects:
            Streams to redirect to files. Pass a single ``Std``
            enum to redirect for all workers, or a mapping keyed
            by local_rank to selectively redirect.
        tee:
            Streams to duplicate to stdout/stderr.
            Pass a single ``Std`` enum to duplicate streams for all workers,
            or a mapping keyed by local_rank to selectively duplicate.
        local_ranks_filter:
            Optional set of local ranks whose logs are surfaced; stored for
            use by concrete implementations.
    """
    def __init__(
        self,
        log_dir: Optional[str] = None,
        redirects: Union[Std, Dict[int, Std]] = Std.NONE,
        tee: Union[Std, Dict[int, Std]] = Std.NONE,
        local_ranks_filter: Optional[Set[int]] = None,
    ) -> None:
        # stored verbatim; interpretation is left to subclasses' reify()
        self._root_log_dir = log_dir
        self._redirects = redirects
        self._tee = tee
        self._local_ranks_filter = local_ranks_filter
    @abstractmethod
    def reify(
        self,
        envs: Dict[int, Dict[str, str]],
    ) -> LogsDest:
        """
        Given the environment variables, builds destination of log files for each of the local ranks.
        Envs parameter contains env variables dict for each of the local ranks, where entries are defined in:
        :func:`~torchelastic.distributed.elastic.agent.server.local_elastic_agent.LocalElasticAgent._start_workers`.
        """
    @property
    @abstractmethod
    def root_log_dir(self) -> str:
        """Root directory under which per-run/per-attempt log dirs are created."""
        pass
class DefaultLogsSpecs(LogsSpecs):
    """
    Default :class:`LogsSpecs` implementation.

    - ``log_dir`` will be created if it doesn't exist.
    - Generates nested folders for each attempt and rank:
      ``<log_dir>/<rdzv_run_id>/attempt_<attempt>/<rank>/``.
    - Passing ``os.devnull`` as ``log_dir`` discards all worker output.
    """
    def __init__(
        self,
        log_dir: Optional[str] = None,
        redirects: Union[Std, Dict[int, Std]] = Std.NONE,
        tee: Union[Std, Dict[int, Std]] = Std.NONE,
        local_ranks_filter: Optional[Set[int]] = None,
    ) -> None:
        # os.devnull is a sentinel meaning "discard all logs"; in that case
        # skip directory creation/validation entirely.
        if log_dir != os.devnull:
            if not log_dir:
                # No directory supplied: fall back to a fresh temp dir.
                log_dir = tempfile.mkdtemp(prefix="torchelastic_")
            elif not os.path.exists(log_dir):
                os.makedirs(log_dir, exist_ok=True)
            else:
                # Path exists: it must be a directory, not a file.
                if os.path.isfile(log_dir):
                    raise NotADirectoryError(f"log_dir: {log_dir} is a file")
        super().__init__(log_dir, redirects, tee, local_ranks_filter)
        # Per-run directory; lazily created (only once) on first reify().
        self._run_log_dir = None
    @property
    def root_log_dir(self) -> str:
        """Base log directory this spec was configured with."""
        return str(self._root_log_dir)
    def _make_log_dir(self, log_dir: Optional[str], rdzv_run_id: str):
        """Create and return a unique per-run directory under ``log_dir``."""
        base_log_dir = log_dir or tempfile.mkdtemp(prefix="torchelastic_")
        os.makedirs(base_log_dir, exist_ok=True)
        # mkdtemp guarantees uniqueness even if multiple agents share the base dir.
        dir = tempfile.mkdtemp(prefix=f"{rdzv_run_id}_", dir=base_log_dir)
        logger.info("log directory set to: %s", dir)
        return dir
    def reify(
        self,
        envs: Dict[int, Dict[str, str]],
    ) -> LogsDest:
        """
        Uses following scheme to build log destination paths:
        - `<log_dir>/<rdzv_run_id>/attempt_<attempt>/<rank>/stdout.log`
        - `<log_dir>/<rdzv_run_id>/attempt_<attempt>/<rank>/stderr.log`
        - `<log_dir>/<rdzv_run_id>/attempt_<attempt>/<rank>/error.json`

        Also sets ``TORCHELASTIC_ERROR_FILE`` in each rank's env dict
        (mutates ``envs`` in place).
        """
        nprocs = len(envs)
        global_env = {}  # used only to query properties that are not rank-dependent
        if nprocs > 0:
            # Rank-independent properties (run id, restart count) read from rank 0.
            global_env = envs[0]
        else:
            logger.warning(
                "Empty envs map provided when defining logging destinations."
            )
        # Keys are always defined, but values can be missing in unit tests
        run_id = global_env.get("TORCHELASTIC_RUN_ID", "test_run_id")
        restart_count = global_env.get("TORCHELASTIC_RESTART_COUNT", "0")
        attempt_log_dir: str = ""
        if self._root_log_dir != os.devnull:
            if not self._run_log_dir:
                self._run_log_dir = self._make_log_dir(self._root_log_dir, run_id)
            attempt_log_dir = os.path.join(self._run_log_dir, f"attempt_{restart_count}")  # type: ignore[call-overload]
            # A retried attempt reuses the same directory name: start clean.
            shutil.rmtree(attempt_log_dir, ignore_errors=True)
            os.makedirs(attempt_log_dir)
        if self._root_log_dir == os.devnull:
            attempt_log_dir = os.devnull
        # create subdirs for each local rank in the logs_dir
        # logs_dir
        #   |- 0
        #      |- error.json
        #      |- stdout.log
        #      |- stderr.log
        #   |- ...
        #   |- (nprocs-1)
        redirs = to_map(self._redirects, nprocs)
        ts = to_map(self._tee, nprocs)
        # to tee stdout/stderr we first redirect into a file
        # then tail -f stdout.log/stderr.log so add tee settings to redirects
        for local_rank, tee_std in ts.items():
            redirect_std = redirs[local_rank]
            redirs[local_rank] = redirect_std | tee_std
        SYS_STREAM = ""  # special case to indicate to output to console
        stdouts = dict.fromkeys(range(nprocs), SYS_STREAM)
        stderrs = dict.fromkeys(range(nprocs), SYS_STREAM)
        tee_stdouts: Dict[int, str] = {}
        tee_stderrs: Dict[int, str] = {}
        error_files = {}
        for local_rank in range(nprocs):
            if attempt_log_dir == os.devnull:
                # Discard everything; an empty TORCHELASTIC_ERROR_FILE
                # disables error-file recording on the worker side.
                tee_stdouts[local_rank] = os.devnull
                tee_stderrs[local_rank] = os.devnull
                error_files[local_rank] = os.devnull
                envs[local_rank]["TORCHELASTIC_ERROR_FILE"] = ""
            else:
                clogdir = os.path.join(attempt_log_dir, str(local_rank))
                os.mkdir(clogdir)
                rd = redirs[local_rank]
                if (rd & Std.OUT) == Std.OUT:
                    stdouts[local_rank] = os.path.join(clogdir, "stdout.log")
                if (rd & Std.ERR) == Std.ERR:
                    stderrs[local_rank] = os.path.join(clogdir, "stderr.log")
                t = ts[local_rank]
                if t & Std.OUT == Std.OUT:
                    tee_stdouts[local_rank] = stdouts[local_rank]
                if t & Std.ERR == Std.ERR:
                    tee_stderrs[local_rank] = stderrs[local_rank]
                if (
                    self._local_ranks_filter
                    and local_rank not in self._local_ranks_filter
                ):
                    # If stream is tee'd, only write to file, but don't tail
                    if local_rank in tee_stdouts:
                        tee_stdouts.pop(local_rank, None)
                    if local_rank in tee_stderrs:
                        tee_stderrs.pop(local_rank, None)
                    # If stream is not redirected, don't print
                    if stdouts[local_rank] == SYS_STREAM:
                        stdouts[local_rank] = os.devnull
                    if stderrs[local_rank] == SYS_STREAM:
                        stderrs[local_rank] = os.devnull
                error_file = os.path.join(clogdir, "error.json")
                error_files[local_rank] = error_file
                logger.info(
                    "Setting worker%s reply file to: %s", local_rank, error_file
                )
                envs[local_rank]["TORCHELASTIC_ERROR_FILE"] = error_file
        return LogsDest(stdouts, stderrs, tee_stdouts, tee_stderrs, error_files)
    def __repr__(self) -> str:
        return (
            f"DefaultLogsSpecs(root_log_dir={self._root_log_dir}, redirects={self._redirects}, "
            f"tee={self._tee}, local_ranks_filter={self._local_ranks_filter})"
        )
    def __eq__(self, other: object) -> bool:
        # NOTE: defining __eq__ implicitly makes instances unhashable.
        if not isinstance(other, DefaultLogsSpecs):
            return False
        return (
            self._root_log_dir == other._root_log_dir
            and self._redirects == other._redirects
            and self._tee == other._tee
            and self._local_ranks_filter == other._local_ranks_filter
        )
@dataclass
class RunProcsResult:
    """
    Outcome of a completed ``start_processes()`` run, as returned by
    ``PContext``.

    Notes:
        1. All fields are keyed by local rank.
        2. ``return_values`` is populated only for function entrypoints
           (never for binaries).
        3. ``stdouts`` / ``stderrs`` hold the paths to ``stdout.log`` /
           ``stderr.log`` (empty string when the stream was not redirected).
    """

    return_values: Dict[int, Any] = field(default_factory=dict)
    failures: Dict[int, ProcessFailure] = field(default_factory=dict)
    stdouts: Dict[int, str] = field(default_factory=dict)
    stderrs: Dict[int, str] = field(default_factory=dict)

    def is_failed(self) -> bool:
        """Return ``True`` iff at least one local rank recorded a failure."""
        return bool(self.failures)
class PContext(abc.ABC):
    """
    Base class standardizing operations over a set of processes launched
    via different mechanisms (functions vs. binaries).

    The name ``PContext`` is intentional, to disambiguate from
    ``core.multiprocessing.ProcessContext``.

    .. warning:: ``stdouts`` and ``stderrs`` should ALWAYS be a superset of
        ``tee_stdouts`` and ``tee_stderrs`` (respectively); this is because
        tee is implemented as a redirect + ``tail -f <stdout/stderr.log>``.
    """
    def __init__(
        self,
        name: str,
        entrypoint: Union[Callable, str],
        args: Dict[int, Tuple],
        envs: Dict[int, Dict[str, str]],
        logs_specs: LogsSpecs,
        log_line_prefixes: Optional[Dict[int, str]] = None,
    ):
        self.name = name
        # validate that all mappings have the same number of keys and
        # all local ranks are accounted for
        nprocs = len(args)
        # TODO log_line_prefixes can be expanded too
        # reify() also sets TORCHELASTIC_ERROR_FILE in each rank's env dict.
        logs_dest = logs_specs.reify(envs)
        _validate_full_rank(logs_dest.stdouts, nprocs, "stdouts")
        _validate_full_rank(logs_dest.stderrs, nprocs, "stderrs")
        self.entrypoint = entrypoint
        self.args = args
        self.envs = envs
        self.stdouts = logs_dest.stdouts
        self.stderrs = logs_dest.stderrs
        self.error_files = logs_dest.error_files
        self.nprocs = nprocs
        # Background threads that tail the tee'd log files to the console.
        self._stdout_tail = TailLog(
            name, logs_dest.tee_stdouts, sys.stdout, log_line_prefixes
        )
        self._stderr_tail = TailLog(
            name, logs_dest.tee_stderrs, sys.stderr, log_line_prefixes
        )
    def start(self) -> None:
        """Start processes using parameters defined in the constructor."""
        # Signal handlers can only be installed from the main thread.
        if threading.current_thread() is threading.main_thread():
            signal.signal(signal.SIGTERM, _terminate_process_handler)
            signal.signal(signal.SIGINT, _terminate_process_handler)
            if not IS_WINDOWS:
                # SIGHUP / SIGQUIT do not exist on Windows.
                signal.signal(signal.SIGHUP, _terminate_process_handler)
                signal.signal(signal.SIGQUIT, _terminate_process_handler)
        else:
            logger.warning(
                "Failed to register signal handlers since torchelastic is running on a child thread. "
                "This could lead to orphaned worker processes if the torchrun is terminated."
            )
        self._start()
        self._stdout_tail.start()
        self._stderr_tail.start()
    @abc.abstractmethod
    def _start(self) -> None:
        """Start processes using strategy defined in a particular context."""
        raise NotImplementedError
    @abc.abstractmethod
    def _poll(self) -> Optional[RunProcsResult]:
        """
        Poll the run status of the processes running under this context.
        This method follows an "all-or-nothing" policy and returns
        a ``RunProcessResults`` object if either all processes complete
        successfully or any process fails. Returns ``None`` if
        all processes are still running.
        """
        raise NotImplementedError
    def wait(self, timeout: float = -1, period: float = 1) -> Optional[RunProcsResult]:
        """
        Wait for the specified ``timeout`` seconds, polling every ``period`` seconds
        for the processes to be done. Returns ``None`` if the processes are still running
        on timeout expiry. Negative timeout values are interpreted as "wait-forever".
        A timeout value of zero simply queries the status of the processes (e.g. equivalent
        to a poll).

        ..note: Multiprocessing library registers SIGTERM and SIGINT signal handlers that raise
                ``SignalException`` when the signals received. It is up to the consumer of the code
                to properly handle the exception. It is important not to swallow the exception otherwise
                the process would not terminate. Example of the typical workflow can be:

        .. code-block:: python
            pc = start_processes(...)
            try:
                pc.wait(1)
                .. do some other work
            except SignalException as e:
                pc.shutdown(e.sigval, timeout=30)

        If SIGTERM or SIGINT occurs, the code above will try to shutdown child processes by propagating
        received signal. If child processes will not terminate in the timeout time, the process will send
        the SIGKILL.
        """
        if timeout == 0:
            return self._poll()
        if timeout < 0:
            # "wait forever" (well, until sys.maxsize seconds elapse)
            timeout = sys.maxsize
        expiry = time.time() + timeout
        while time.time() < expiry:
            pr = self._poll()
            if pr:
                return pr
            time.sleep(period)
        return None
    @abc.abstractmethod
    def pids(self) -> Dict[int, int]:
        """Return pids of processes mapped by their respective local_ranks."""
        raise NotImplementedError
    @abc.abstractmethod
    def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None:
        r"""
        Terminates all processes managed by this context and cleans up any
        meta resources (e.g. redirect, error_file files).
        """
        raise NotImplementedError
    def close(
        self, death_sig: Optional[signal.Signals] = None, timeout: int = 30
    ) -> None:
        r"""
        Terminates all processes managed by this context and cleans up any
        meta resources (e.g. redirect, error_file files).

        Args:
            death_sig: Death signal to terminate processes.
            timeout: Time to wait for processes to finish; if a process is
                still alive after this time, it will be terminated via SIGKILL.
        """
        if not death_sig:
            death_sig = _get_default_signal()
        self._close(death_sig=death_sig, timeout=timeout)
        # Stop the log-tailing threads only after the children are gone so
        # trailing output is not lost.
        if self._stdout_tail:
            self._stdout_tail.stop()
        if self._stderr_tail:
            self._stderr_tail.stop()
def get_std_cm(std_rd: str, redirect_fn):
    """
    Return a context manager that redirects a std stream to the file ``std_rd``.

    Yields a no-op ``nullcontext`` when ``std_rd`` is empty or when running
    on Windows/macOS (where the redirect mechanism is unsupported);
    otherwise delegates to ``redirect_fn`` (e.g. ``redirect_stdout``).
    """
    if std_rd and not (IS_WINDOWS or IS_MACOS):
        return redirect_fn(std_rd)
    return nullcontext()
def _wrap(
    local_rank: int,
    fn: Callable,
    args: Dict[int, Tuple],
    envs: Dict[int, Dict[str, str]],
    stdout_redirects: Dict[int, str],  # redirect file for stdout (to console if None)
    stderr_redirects: Dict[int, str],  # redirect file for stderr (to console if None)
    ret_vals: Dict[int, mp.SimpleQueue],
    queue_finished_reading_event: synchronize.Event,
) -> None:
    """
    Per-worker entrypoint run inside each spawned child process.

    Applies the rank's env vars to ``os.environ``, redirects stdout/stderr
    as configured, runs ``fn`` (wrapped with ``record`` so uncaught
    exceptions land in the error file), puts the return value on this
    rank's queue, then blocks until the parent signals it has drained
    all queues.
    """
    # get the per-rank params up front so we fail fast if no mapping is found
    args_ = args[local_rank]
    env_ = envs[local_rank]
    ret_val_ = ret_vals[local_rank]
    stdout_rd = stdout_redirects[local_rank]
    stderr_rd = stderr_redirects[local_rank]
    stdout_cm = get_std_cm(stdout_rd, redirect_stdout)
    stderr_cm = get_std_cm(stderr_rd, redirect_stderr)
    # Mutates this child process's environment (safe: we are post-fork/spawn).
    for k, v in env_.items():
        os.environ[k] = v
    with stdout_cm, stderr_cm:
        # record() writes uncaught exceptions to TORCHELASTIC_ERROR_FILE.
        ret = record(fn)(*args_)
    ret_val_.put(ret)
    # Keep the process alive until the parent has read all return values,
    # otherwise large queued values could deadlock on teardown.
    queue_finished_reading_event.wait()
class MultiprocessContext(PContext):
    """``PContext`` holding worker processes invoked as a function."""
    def __init__(
        self,
        name: str,
        entrypoint: Callable,
        args: Dict[int, Tuple],
        envs: Dict[int, Dict[str, str]],
        start_method: str,
        logs_specs: LogsSpecs,
        log_line_prefixes: Optional[Dict[int, str]] = None,
    ):
        super().__init__(
            name,
            entrypoint,
            args,
            envs,
            logs_specs,
            log_line_prefixes,
        )
        self.start_method = start_method
        # each ret_val queue will always contain a single element.
        self._ret_vals = {
            local_rank: mp.get_context(self.start_method).SimpleQueue()
            for local_rank in range(self.nprocs)
        }
        # see comments in ``join()`` for what this is
        self._return_values: Dict[int, Any] = {}
        self._pc: Optional[mp.ProcessContext] = None
        # Note: set method should ONLY be invoked for the use case when all processes finished
        # successfully. If any process died on event.wait() calling set() method will deadlock.
        self._worker_finished_event = mp.get_context(self.start_method).Event()
    def _start(self):
        """Spawn ``nprocs`` child processes, each running ``_wrap``."""
        if self._pc:
            raise ValueError(
                "The process context already initialized."
                " Most likely the start method got called twice."
            )
        self._pc = mp.start_processes(
            fn=_wrap,
            args=(
                self.entrypoint,
                self.args,
                self.envs,
                self.stdouts,
                self.stderrs,
                self._ret_vals,
                self._worker_finished_event,
            ),
            nprocs=self.nprocs,
            join=False,
            daemon=False,
            start_method=self.start_method,
        )
    def _is_done(self) -> bool:
        # Done == every rank's user function has produced a return value.
        return len(self._return_values) == self.nprocs
    def _poll(self) -> Optional[RunProcsResult]:
        """
        Single poll step: drain available return-value queues, and either
        report completion/failure or ``None`` if workers are still running.
        """
        assert self._pc is not None  # assertion for mypy type checker
        try:
            # core.mp.ProcessContext Throws an Exception if some/all of
            # worker processes failed
            # timeout < 0 checks worker status and return immediately
            # Join will never return success since we use synchronize.Event to wait
            # for all processes to finish.
            self._pc.join(-1)
            # IMPORTANT: we use multiprocessing.Queue to carry worker return values
            # back to the parent, the worker process will wait before terminating
            # until all the buffered items are fed by the feeder thread to the underlying
            # pipe. Hence to prevent deadlocks on large return values,
            # we opportunistically try queue.get on each join call
            # See: https://docs.python.org/2/library/multiprocessing.html#all-platforms
            for local_rank in range(0, self.nprocs):
                return_queue = self._ret_vals[local_rank]
                if not return_queue.empty():
                    # save the return values temporarily into a member var
                    self._return_values[local_rank] = return_queue.get()
            if self._is_done():
                # we should ALWAYS have ALL the return values when all the processes are done
                self._worker_finished_event.set()
                # At this point workers finished running the user function
                # But the child process might still have not exited. Wait for them.
                # pc.join() blocks [forever] until "a" proc exits. Loop until all of them exits.
                while not self._pc.join():
                    logger.debug(
                        "entrypoint fn finished, waiting for all child procs to exit..."
                    )
                _validate_full_rank(
                    self._return_values, self.nprocs, "return_value queue"
                )
                self.close()
                return RunProcsResult(
                    return_values=self._return_values,
                    stdouts=self.stdouts,
                    stderrs=self.stderrs,
                )
            else:
                return None
        except (mp.ProcessRaisedException, mp.ProcessExitedException) as e:
            # At least one child raised or exited abnormally: report the
            # first failure and tear everything down.
            failed_local_rank = e.error_index
            # entrypoint for MultiprocessContext will always be a Callable
            fn_name = self.entrypoint.__qualname__  # type: ignore[union-attr]
            failed_proc = self._pc.processes[failed_local_rank]
            error_filepath = self.error_files[failed_local_rank]
            logger.exception(
                "failed (exitcode: %s)"
                " local_rank: %s (pid: %s)"
                " of fn: %s (start_method: %s)",
                failed_proc.exitcode,
                failed_local_rank,
                e.pid,
                fn_name,
                self.start_method,
            )
            self.close()
            return RunProcsResult(
                failures={
                    failed_local_rank: ProcessFailure(
                        local_rank=failed_local_rank,
                        pid=e.pid,
                        exitcode=failed_proc.exitcode,
                        error_file=error_filepath,
                    )
                },
                stdouts=self.stdouts,
                stderrs=self.stderrs,
            )
    def pids(self) -> Dict[int, int]:
        """Map of local_rank -> child pid."""
        assert self._pc is not None  # assertion for mypy type checking
        return dict(enumerate(self._pc.pids()))
    def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None:
        """Signal children with ``death_sig``; escalate to SIGKILL after ``timeout``."""
        if not self._pc:
            return
        # Phase 1: send the (graceful) death signal to every live child.
        for proc in self._pc.processes:
            if proc.is_alive():
                logger.warning(
                    "Closing process %s via signal %s", proc.pid, death_sig.name
                )
                try:
                    os.kill(proc.pid, death_sig)
                except ProcessLookupError:
                    # If the process exited because of some reason,
                    # `ProcessLookupError` will be raised, it is safe to ignore it.
                    pass
        # Phase 2: give all children a shared `timeout` budget to exit.
        end = time.monotonic() + timeout
        for proc in self._pc.processes:
            time_to_wait = end - time.monotonic()
            if time_to_wait <= 0:
                break
            proc.join(time_to_wait)
        # Phase 3: SIGKILL stragglers and reap them.
        for proc in self._pc.processes:
            if proc.is_alive():
                logger.warning(
                    "Unable to shutdown process %s via %s, forcefully exiting via %s",
                    proc.pid,
                    death_sig,
                    _get_kill_signal(),
                )
                try:
                    os.kill(proc.pid, _get_kill_signal())
                except ProcessLookupError:
                    # If the process exited because of some reason,
                    # `ProcessLookupError` will be raised, it is safe to ignore it.
                    pass
            proc.join()
class SubprocessContext(PContext):
    """``PContext`` holding worker processes invoked as a binary."""
    def __init__(
        self,
        name: str,
        entrypoint: str,
        args: Dict[int, Tuple],
        envs: Dict[int, Dict[str, str]],
        logs_specs: LogsSpecs,
        log_line_prefixes: Optional[Dict[int, str]] = None,
    ):
        super().__init__(
            name,
            entrypoint,
            args,
            envs,
            logs_specs,
            log_line_prefixes,
        )
        # state vector; _vdone[local_rank] -> is local_rank finished or not
        self._running_local_ranks: Set[int] = set(range(self.nprocs))
        self._failures: Dict[int, ProcessFailure] = {}
        self.subprocess_handlers: Dict[int, SubprocessHandler] = {}
    def _start(self):
        """Launch one subprocess per local rank via ``get_subprocess_handler``."""
        if self.subprocess_handlers:
            raise ValueError(
                "The subprocess handlers already initialized. Most likely the start method got called twice."
            )
        self.subprocess_handlers = {
            local_rank: get_subprocess_handler(
                entrypoint=self.entrypoint,  # type: ignore[arg-type] # entrypoint is always a str
                args=self.args[local_rank],
                env=self.envs[local_rank],
                stdout=self.stdouts[local_rank],
                stderr=self.stderrs[local_rank],
                local_rank_id=local_rank,
            )
            for local_rank in range(self.nprocs)
        }
    def _poll(self) -> Optional[RunProcsResult]:
        """
        Single poll step: collect exit codes of finished subprocesses;
        return a result when all succeeded or any failed, else ``None``.
        """
        done_local_ranks = set()
        for local_rank in self._running_local_ranks:
            handler = self.subprocess_handlers[local_rank]
            exitcode = handler.proc.poll()  # None while still running
            if exitcode is not None:
                done_local_ranks.add(local_rank)
                if exitcode != 0:  # failed or signaled
                    self._failures[local_rank] = ProcessFailure(
                        local_rank=local_rank,
                        pid=handler.proc.pid,
                        exitcode=exitcode,
                        error_file=self.error_files[local_rank],
                    )
                # else: --> succeeded; nothing to do
        self._running_local_ranks.difference_update(done_local_ranks)
        # if ALL procs are finished or ANY have failed
        if not self._running_local_ranks or self._failures:
            self.close()  # terminate all running procs
            result = RunProcsResult(
                failures=self._failures,
                stdouts=self.stdouts,
                stderrs=self.stderrs,
            )
            if result.is_failed():
                # Root cause == failure with the smallest timestamp.
                first_failure = min(result.failures.values(), key=lambda f: f.timestamp)
                logger.error(
                    "failed (exitcode: %s)"
                    " local_rank: %s (pid: %s)"
                    " of binary: %s",
                    first_failure.exitcode,
                    first_failure.local_rank,
                    first_failure.pid,
                    self.entrypoint,
                )
            else:
                # Populate return with dummy values. This provides consistency with MultiprocessingHandler
                result.return_values = dict.fromkeys(range(self.nprocs))
            return result
        else:  # there are no failures and procs still running
            return None
    def pids(self) -> Dict[int, int]:
        """Map of local_rank -> subprocess pid."""
        return {
            local_rank: sh.proc.pid
            for local_rank, sh in self.subprocess_handlers.items()
        }
    def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None:
        """Signal subprocesses with ``death_sig``; escalate to SIGKILL after ``timeout``."""
        if not self.subprocess_handlers:
            return
        # Phase 1: send the (graceful) death signal to every live subprocess.
        for handler in self.subprocess_handlers.values():
            if handler.proc.poll() is None:
                logger.warning(
                    "Sending process %s closing signal %s",
                    handler.proc.pid,
                    death_sig.name,
                )
                handler.close(death_sig=death_sig)
        # Phase 2: give all subprocesses a shared `timeout` budget to exit.
        end = time.monotonic() + timeout
        for handler in self.subprocess_handlers.values():
            time_to_wait = end - time.monotonic()
            if time_to_wait <= 0:
                break
            try:
                handler.proc.wait(time_to_wait)
            except subprocess.TimeoutExpired:
                # Ignore the timeout expired exception, since
                # the child process will be forcefully terminated via SIGKILL
                pass
        # Phase 3: SIGKILL stragglers and reap them.
        for handler in self.subprocess_handlers.values():
            if handler.proc.poll() is None:
                logger.warning(
                    "Unable to shutdown process %s via %s, forcefully exiting via %s",
                    handler.proc.pid,
                    death_sig,
                    _get_kill_signal(),
                )
                handler.close(death_sig=_get_kill_signal())
            handler.proc.wait()

+ 0
- 383
mindnlp/core/distributed/elastic/multiprocessing/errors/__init__.py View File

@@ -1,383 +0,0 @@
#!/usr/bin/env python3
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Each host in a distributed PyTorch job runs with a single TorchElastic agent,
and multiple workers (as children processes of the TorchElastic agent).
Since the workers are user-provided (your PyTorch script/job), TorchElastic
has a way to propagate errors on the trainers through the agent and up to the
scheduler, which ultimately informs the end-user about the state of the job
and applies any retry policies.
TorchElastic categorizes errors into 3 categories:
+----------------+----------------+--------------------------------------------------------------+
| Category | Sub-Category | Description |
+================+================+==============================================================+
| User Error | Input Error | invalid inputs to TorchElastic APIs (e.g. min > max nodes) |
| +----------------+--------------------------------------------------------------+
| | Worker Failure | any failures on the worker child process |
+----------------+----------------+--------------------------------------------------------------+
| Platform Error | n/a | failures caused by the agent |
+----------------+----------------+--------------------------------------------------------------+
| Infra Error | n/a | failures outside the domain of the agent and workers |
| | | (e.g. host failures) |
+----------------+----------------+--------------------------------------------------------------+
All errors other than "Worker Failure" are either raised canonically from the
agent process or implicitly or explicitly crash the agent process. So the
standard language (python) provided exception handling strategies apply.
Worker Failures are special because the exception/failure originates on a different
process from the agent so the error needs to be propagated inter-process
(e.g. the agent cannot simply ``try-catch`` an exception raised on the worker process).
TorchElastic agents use :func:`core.distributed.elastic.multiprocessing.start_processes`
to launch the workers which has a simple file based inter-process error propagation
built-in.
Any function or binary entrypoint decorated with :func:`record`
will write uncaught exceptions (with the trace information) to a file specified by the
environment variable ``TORCHELASTIC_ERROR_FILE``. The parent process (e.g. agent)
sets this env var on each child it launches, then aggregates the error files for all
children, and propagates the one with the **smallest** timestamp (e.g. the **first** error).
"""
import json
import os
import signal
import socket
import time
import warnings
from dataclasses import dataclass, field
from datetime import datetime
from functools import wraps
from string import Template
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
from core.distributed.elastic.utils.logging import get_logger
from .error_handler import ErrorHandler # noqa: F401
from .handlers import get_error_handler # noqa: F401
# Public API of the errors package.
__all__ = [
    "ProcessFailure",
    "ChildFailedError",
    "record",
    "ErrorHandler",
    "get_error_handler",
]
logger = get_logger(__name__)
# Error-file payloads are JSON-decoded plain dicts.
JSON = Dict
# Placeholder payload / marker used when no error file was produced.
_EMPTY_ERROR_DATA = {"message": "<NONE>"}
_NOT_AVAILABLE = "<N/A>"
# Return type of the function wrapped by @record.
T = TypeVar("T")
@dataclass
class ProcessFailure:
    """
    Represent the failed process result. When the worker process fails, it may record failure root cause into the file.

    Tries to read the failure timestamp from the provided ``error_file``;
    if the ``error_file`` does not exist, the timestamp is the current
    timestamp (seconds since epoch).

    The ``message`` field is a concise explanation of the failure. If
    the error file exists then the message is obtained from the error file.
    Otherwise one is generated based on the failure signature.

    .. note:: It is assumed that the ``error_file`` is written by
              ``core.distributed.elastic.multiprocessing.errors.error_handler.ErrorHandler``.
              Otherwise the behavior is undefined.
    """
    local_rank: int
    pid: int
    exitcode: int
    error_file: str
    # The three fields below are derived in __post_init__, not constructor args.
    error_file_data: JSON = field(init=False)
    message: str = field(init=False)
    timestamp: int = field(init=False)
    def __post_init__(self):
        self.error_file_data = _EMPTY_ERROR_DATA
        if os.path.isfile(self.error_file):
            try:
                with open(self.error_file) as fp:
                    self.error_file_data = json.load(fp)
                    logger.debug(
                        "User process failed with error data: %s",
                        json.dumps(self.error_file_data, indent=2),
                    )
                    self.message, self.timestamp = self._get_error_data(
                        self.error_file_data
                    )
            except Exception:
                # Re-raise: an unreadable/garbled reply file is a hard error.
                logger.exception("Failed to parse reply file: %s", self.error_file)
                raise
        else:
            self._set_no_reply_file()
        # make up an informative message if not already present
        if not self.message:
            # signals typically do not generate an error file message
            if self.exitcode < 0:
                self.message = (
                    f"Signal {-self.exitcode} ({self.signal_name()})"
                    f" received by PID {self.pid}"
                )
            else:
                self.message = "To enable traceback see: https://pycore.org/docs/stable/elastic/errors.html"
    def _get_error_data(self, error_file_data: Dict[str, Any]) -> Tuple[str, int]:
        """Extract (message, timestamp) from decoded error-file data.

        ``message`` may be a plain string (timestamp stored alongside it)
        or a structured dict carrying the timestamp in ``extraInfo``.
        """
        message = error_file_data["message"]
        if isinstance(message, str):
            timestamp = int(error_file_data.get("timestamp", 0))
        else:
            timestamp = int(message["extraInfo"]["timestamp"])
        return (message, timestamp)
    def _set_no_reply_file(self):
        # No error file on disk: fall back to "not available" markers and "now".
        self.error_file = _NOT_AVAILABLE
        self.error_file_data = _EMPTY_ERROR_DATA
        self.message = ""
        self.timestamp = int(time.time())
    def signal_name(self) -> str:
        """Name of the terminating signal, or ``<N/A>`` if not signal-killed."""
        if self.exitcode < 0:
            # We don't want to kill the parent process trying to find the signal name.
            # if the signal doesn't map to a known name, use not available.
            try:
                return signal.Signals(-self.exitcode).name
            except Exception:
                return _NOT_AVAILABLE
        else:
            return _NOT_AVAILABLE
    def timestamp_isoformat(self):
        """Return timestamp in ISO format (YYYY-MM-DD_HH:MM:SS)."""
        return datetime.fromtimestamp(self.timestamp).isoformat(sep="_")
GlobalRank = int
_FAILURE_FORMAT_TEMPLATE = """[${idx}]:
time : ${time}
host : ${hostname}
rank : ${rank} (local_rank: ${local_rank})
exitcode : ${exitcode} (pid: ${pid})
error_file: ${error_file}
traceback : ${message}"""
# extra new lines before and after are intentional
_MSG_FORMAT_TEMPLATE = """
${boarder}
${title}
${section}
Failures:
${other_failures}
${section}
Root Cause (first observed failure):
${root_failure}
${boarder}"""
class ChildFailedError(Exception):
    """
    Special exception type that can be raised from a function annotated with the
    ``@record`` decorator to have the child process' (root exception) propagate
    up the stack as-is (e.g. without being wrapped in the parent's traceback).

    Useful in cases where the parent is a simple nanny process
    and the child (worker) processes are actually doing meaningful compute.
    In this case, errors typically occur on the child process as the parent
    is not doing anything non-trivial, and child errors should be propagated
    to the scheduler for accurate root cause diagnostics.

    .. note:: The propagation relies on error files rather than exception handling to
              support both function and binary launches.

    Example:
    ::
     # process tree on a host (container)
     0: scheduler-init-process:
                |- 1: torchelastic_agent:
                         |- 2: trainer_0 (ok)
                         |- 3: trainer_1 (fail) -> error.json
                         |- ...
                         |- n+2: trainer_n (ok)
                |- n+3: other processes
                |- ...

    In the example above, trainer 1's failure (written into error.json) is
    the root cause and should be reported to the scheduler's init process.
    The torchelastic agent raises a ``ChildFailedError("trainer", {1: "trainer_1/error.json"})``
    upon detecting trainer 1's failure which would propagate the contents
    of trainer 1's error file to the scheduler's init process.
    """
    def __init__(self, name: str, failures: Dict[GlobalRank, ProcessFailure]):
        self.name = name
        self.failures = failures
        assert (
            self.failures
        )  # does not make sense to create a ChildFaileError with no failures
        super().__init__(self.format_msg())
    def get_first_failure(self) -> Tuple[GlobalRank, ProcessFailure]:
        """Return (rank, failure) of the earliest failure by timestamp."""
        rank = min(self.failures.keys(), key=lambda r: self.failures[r].timestamp)
        return rank, self.failures[rank]
    def format_msg(self, boarder_delim="=", section_delim="-"):
        """Render all failures into a human-readable, sectioned report."""
        title = f"{self.name} FAILED"
        root_rank, _root_failure = self.get_first_failure()
        root_failure_fmt: str = ""
        other_failures_fmt: List[str] = []
        # Delimiter width adapts to the widest line rendered so far.
        width = len(title)
        for idx, (rank, failure) in enumerate(self.failures.items()):
            fmt, w = self._format_failure(idx, rank, failure)
            width = max(width, w)
            if rank == root_rank:
                root_failure_fmt = fmt
            else:
                other_failures_fmt.append(fmt)
        # upper boundary on width
        width = min(width, 60)
        return Template(_MSG_FORMAT_TEMPLATE).substitute(
            boarder=boarder_delim * width,
            title=title,
            section=section_delim * width,
            root_failure=root_failure_fmt,
            other_failures="\n".join(other_failures_fmt or ["  <NO_OTHER_FAILURES>"]),
        )
    def _format_failure(
        self, idx: int, rank: int, failure: ProcessFailure
    ) -> Tuple[str, int]:
        """Render one failure; returns (text, widest-line-length)."""
        # failure.message is either a str (when the failure does not generate a traceback - e.g. signals)
        # or a dict (json) of the form
        # {"message": $ERROR_MSG, "extraInfo": {"py_callstack": $TRACEBACK, timestamp: $TS}}
        # so the display logic is:
        # 1. if failure.message is not a dict (it is a str) just show it as is
        # 2. else try to get the traceback (py_callstack)
        # 3.      if the traceback is not there, use the message
        # 4.      if the message is not there show <N/A>
        msg = failure.message
        if isinstance(failure.message, dict):
            msg = (
                failure.message.get("extraInfo", {})
                .get("py_callstack", failure.message.get("message", "<N/A>"))
                .replace("\n", "\n  ")  # to properly indent the traceback
            )
        fmt = Template(_FAILURE_FORMAT_TEMPLATE).substitute(
            idx=idx,
            time=failure.timestamp_isoformat(),
            hostname=socket.getfqdn(),
            rank=rank,
            local_rank=failure.local_rank,
            exitcode=failure.exitcode,
            pid=failure.pid,
            error_file=failure.error_file,
            message=msg,
        )
        width = 0
        for line in fmt.split("\n"):
            width = max(width, len(line))
        return fmt, width
def record(
    fn: Callable[..., T], error_handler: Optional[ErrorHandler] = None
) -> Callable[..., T]:
    """
    Syntactic sugar to record errors/exceptions that happened in the decorated
    function using the provided ``error_handler``.

    Using this decorator is equivalent to:

    ::

     error_handler = get_error_handler()
     error_handler.initialize()
     try:
        foobar()
     except ChildFailedError as e:
        _, failure = e.get_first_failure()
        error_handler.dump_error_file(failure.error_file, failure.exitcode)
        raise
     except Exception as e:
        error_handler.record(e)
        raise

    .. important:: use this decorator once per process at the top level method,
                   typically this is the main method.

    Example

    ::

     @record
     def main():
         pass

     if __name__=="__main__":
        main()

    Args:
        fn: entrypoint function to wrap.
        error_handler: handler used to record/propagate errors; defaults to
            ``get_error_handler()``.

    Returns:
        A wrapper with the same signature as ``fn``.
    """
    if not error_handler:
        error_handler = get_error_handler()

    def wrap(f):
        @wraps(f)
        def wrapper(*args, **kwargs):
            assert error_handler is not None  # assertion for mypy type checker
            error_handler.initialize()
            try:
                return f(*args, **kwargs)
            except SystemExit as se:
                # For run_path based entrypoints, SystemExit with code = 0 will never exit.
                # Handling it here by returning a value:
                if se.code == 0:
                    return None
                else:
                    raise
            except ChildFailedError as e:
                rank, failure = e.get_first_failure()
                if failure.error_file != _NOT_AVAILABLE:
                    error_handler.dump_error_file(failure.error_file, failure.exitcode)
                else:
                    # BUGFIX: the message and `rank` were previously wrapped in an
                    # extra tuple and passed as the single argument to logger.info,
                    # so "%s" was never interpolated and the raw tuple repr was
                    # logged. Pass `rank` as a lazy %-style argument instead.
                    logger.info(
                        "local_rank %s FAILED with no error file."
                        " Decorate your entrypoint fn with @record for traceback info."
                        " See: https://pycore.org/docs/stable/elastic/errors.html",
                        rank,
                    )
                raise
            except Exception as e:
                error_handler.record_exception(e)
                raise

        return wrapper

    return wrap(fn)

+ 0
- 166
mindnlp/core/distributed/elastic/multiprocessing/errors/error_handler.py View File

@@ -1,166 +0,0 @@
#!/usr/bin/env python3
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import faulthandler
import json
import logging
import os
import time
import traceback
import warnings
from typing import Any, Dict, Optional
__all__ = ["ErrorHandler"]
logger = logging.getLogger(__name__)
class ErrorHandler:
    """
    Write the provided exception object along with some other metadata about
    the error in a structured way in JSON format to an error file specified by the
    environment variable: ``TORCHELASTIC_ERROR_FILE``. If this environment
    variable is not set, then simply logs the contents of what would have been
    written to the error file.

    This handler may be subclassed to customize the handling of the error.
    Subclasses should override ``initialize()`` and ``record_exception()``.
    """

    def _get_error_file_path(self) -> Optional[str]:
        """
        Return the error file path.

        May return ``None`` to have the structured error be logged only.
        """
        return os.environ.get("TORCHELASTIC_ERROR_FILE", None)

    def initialize(self) -> None:
        """
        Call prior to running code that we wish to capture errors/exceptions.

        Typically registers signal/fault handlers. Users can override this
        function to add custom initialization/registrations that aid in
        propagation/information of errors/signals/exceptions/faults.
        """
        try:
            faulthandler.enable(all_threads=True)
        except Exception as e:
            # Best effort: faulthandler can fail (e.g. sys.stderr replaced by an
            # object with no file descriptor); never break the entrypoint for it.
            warnings.warn(f"Unable to enable fault handler. {type(e).__name__}: {e}")

    def _write_error_file(self, file_path: str, error_msg: str) -> None:
        """Write ``error_msg`` to ``file_path``; warn instead of raising on failure."""
        try:
            with open(file_path, "w") as fp:
                fp.write(error_msg)
        except Exception as e:
            warnings.warn(f"Unable to write error to file. {type(e).__name__}: {e}")

    def record_exception(self, e: BaseException) -> None:
        """
        Write a structured information about the exception into an error file in JSON format.

        Must be called from within an ``except`` clause so that
        ``traceback.format_exc()`` captures the active exception's traceback.
        If no error file is configured (``TORCHELASTIC_ERROR_FILE`` unset),
        this is a no-op.
        """
        file = self._get_error_file_path()
        if file:
            data = {
                "message": {
                    "message": f"{type(e).__name__}: {e}",
                    "extraInfo": {
                        "py_callstack": traceback.format_exc(),
                        "timestamp": str(int(time.time())),
                    },
                }
            }
            # Route through _write_error_file (instead of a bare open/dump) so
            # that an unwritable error file degrades to a warning rather than
            # raising and masking the exception currently being recorded.
            self._write_error_file(file, json.dumps(data))

    def override_error_code_in_rootcause_data(
        self,
        rootcause_error_file: str,
        rootcause_error: Dict[str, Any],
        error_code: int = 0,
    ):
        """Modify the rootcause_error read from the file, to correctly set the exit code.

        Needed because a child terminated by a signal (e.g. SIGSEGV) cannot
        record its own exit code; the parent patches it in after the fact.
        """
        if "message" not in rootcause_error:
            logger.warning(
                "child error file (%s) does not have field `message`. \n"
                "cannot override error code: %s",
                rootcause_error_file,
                error_code,
            )
        elif isinstance(rootcause_error["message"], str):
            # `message` stored as a plain string has no structured slot to
            # carry an error code, so leave it untouched.
            logger.warning(
                "child error file (%s) has a new message format. \n"
                "skipping error code override",
                rootcause_error_file,
            )
        else:
            rootcause_error["message"]["errorCode"] = error_code

    def dump_error_file(self, rootcause_error_file: str, error_code: int = 0):
        """Dump parent error file from child process's root cause error and error code."""
        with open(rootcause_error_file) as fp:
            rootcause_error = json.load(fp)
            # Override error code since the child process cannot capture the error code if it
            # is terminated by signals like SIGSEGV.
            if error_code:
                self.override_error_code_in_rootcause_data(
                    rootcause_error_file, rootcause_error, error_code
                )
            logger.debug(
                "child error file (%s) contents:\n" "%s",
                rootcause_error_file,
                json.dumps(rootcause_error, indent=2),
            )

        my_error_file = self._get_error_file_path()
        if my_error_file:
            # Guard against existing error files
            # This can happen when the child is created using multiprocessing
            # and the same env var (TORCHELASTIC_ERROR_FILE) is used on the
            # parent and child to specify the error files (respectively)
            # because the env vars on the child is set in the wrapper function
            # and by default the child inherits the parent's env vars, if the child
            # process receives a signal before the wrapper function kicks in
            # and the signal handler writes to the error file, then the child
            # will write to the parent's error file. In this case just log the
            # original error file contents and overwrite the error file.
            self._rm(my_error_file)
            self._write_error_file(my_error_file, json.dumps(rootcause_error))
            logger.info("dumped error file to parent's %s", my_error_file)
        else:
            logger.error(
                "no error file defined for parent, to copy child error file (%s)",
                rootcause_error_file,
            )

    def _rm(self, my_error_file):
        """Remove ``my_error_file`` if present, logging its prior contents first."""
        if os.path.isfile(my_error_file):
            # Log the contents of the original file.
            with open(my_error_file) as fp:
                try:
                    original = json.dumps(json.load(fp), indent=2)
                    logger.warning(
                        "%s already exists"
                        " and will be overwritten."
                        " Original contents:\n%s",
                        my_error_file,
                        original,
                    )
                except json.decoder.JSONDecodeError:
                    logger.warning(
                        "%s already exists"
                        " and will be overwritten."
                        " Unable to load original contents:\n",
                        my_error_file,
                    )
            os.remove(my_error_file)

+ 0
- 19
mindnlp/core/distributed/elastic/multiprocessing/errors/handlers.py View File

@@ -1,19 +0,0 @@
#!/usr/bin/env python3
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# Multiprocessing error-reporting module
from core.distributed.elastic.multiprocessing.errors.error_handler import ErrorHandler
__all__ = ["get_error_handler"]
def get_error_handler():
    """Return a fresh :class:`ErrorHandler` instance used to record worker errors."""
    handler = ErrorHandler()
    return handler

+ 0
- 104
mindnlp/core/distributed/elastic/multiprocessing/redirects.py View File

@@ -1,104 +0,0 @@
# mypy: allow-untyped-defs
# !/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# Taken and modified from original source:
# https://eli.thegreenplace.net/2015/redirecting-all-kinds-of-stdout-in-python/
import ctypes
import logging
import os
import sys
from contextlib import contextmanager
from functools import partial
IS_WINDOWS = sys.platform == "win32"
IS_MACOS = sys.platform == "darwin"
logger = logging.getLogger(__name__)
def get_libc():
    """Load libc via ctypes on Linux; return ``None`` (with a warning) on Windows/macOS."""
    if not (IS_WINDOWS or IS_MACOS):
        return ctypes.CDLL("libc.so.6")
    logger.warning(
        "NOTE: Redirects are currently not supported in Windows or MacOs."
    )
    return None
libc = get_libc()
def _c_std(stream: str):
    """Return the C-level ``stdout``/``stderr`` symbol resolved from libc."""
    symbol = ctypes.c_void_p.in_dll(libc, stream)
    return symbol
def _python_std(stream: str):
    """Return the python-level stream object named by ``stream`` (``stdout``/``stderr``)."""
    streams = {"stdout": sys.stdout, "stderr": sys.stderr}
    return streams[stream]
_VALID_STD = {"stdout", "stderr"}
@contextmanager
def redirect(std: str, to_file: str):
    """
    Redirect ``std`` (one of ``"stdout"`` or ``"stderr"``) to a file in the path specified by ``to_file``.

    This method redirects the underlying std file descriptor (not just python's ``sys.stdout|stderr``).
    See usage for details.

    Directory of ``dst_filename`` is assumed to exist and the destination file
    is overwritten if it already exists.

    .. note:: Due to buffering cross source writes are not guaranteed to
              appear in wall-clock order. For instance in the example below
              it is possible for the C-outputs to appear before the python
              outputs in the log file.

    Usage:

    ::

     # syntactic-sugar for redirect("stdout", "tmp/stdout.log")
     with redirect_stdout("/tmp/stdout.log"):
        print("python stdouts are redirected")
        libc = ctypes.CDLL("libc.so.6")
        libc.printf(b"c stdouts are also redirected")
        os.system("echo system stdouts are also redirected")

     print("stdout restored")
    """
    if std not in _VALID_STD:
        raise ValueError(
            f"unknown standard stream <{std}>, must be one of {_VALID_STD}"
        )

    c_std = _c_std(std)
    python_std = _python_std(std)
    std_fd = python_std.fileno()

    def _redirect(dst):
        # Flush both the C and the python buffers before swapping the fd so
        # pending output lands in the stream it was originally written to.
        libc.fflush(c_std)
        python_std.flush()
        os.dup2(dst.fileno(), std_fd)

    # Duplicate the original fd first so it can be restored on exit; the
    # finally branch guarantees restoration even if the body raises.
    with os.fdopen(os.dup(std_fd)) as orig_std, open(to_file, mode="w+b") as dst:
        _redirect(dst)
        try:
            yield
        finally:
            _redirect(orig_std)
redirect_stdout = partial(redirect, "stdout")
redirect_stderr = partial(redirect, "stderr")

+ 0
- 16
mindnlp/core/distributed/elastic/multiprocessing/subprocess_handler/__init__.py View File

@@ -1,16 +0,0 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from core.distributed.elastic.multiprocessing.subprocess_handler.handlers import (
get_subprocess_handler,
)
from core.distributed.elastic.multiprocessing.subprocess_handler.subprocess_handler import (
SubprocessHandler,
)
__all__ = ["SubprocessHandler", "get_subprocess_handler"]

+ 0
- 34
mindnlp/core/distributed/elastic/multiprocessing/subprocess_handler/handlers.py View File

@@ -1,34 +0,0 @@
#!/usr/bin/env python3
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, Tuple
from core.distributed.elastic.multiprocessing.subprocess_handler.subprocess_handler import (
SubprocessHandler,
)
__all__ = ["get_subprocess_handler"]
def get_subprocess_handler(
    entrypoint: str,
    args: Tuple,
    env: Dict[str, str],
    stdout: str,
    stderr: str,
    local_rank_id: int,
):
    """Construct a :class:`SubprocessHandler` for a single local worker process."""
    handler_kwargs = dict(
        entrypoint=entrypoint,
        args=args,
        env=env,
        stdout=stdout,
        stderr=stderr,
        local_rank_id=local_rank_id,
    )
    return SubprocessHandler(**handler_kwargs)

+ 0
- 78
mindnlp/core/distributed/elastic/multiprocessing/subprocess_handler/subprocess_handler.py View File

@@ -1,78 +0,0 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
import signal
import subprocess
import sys
from typing import Any, Dict, Optional, Tuple
__all__ = ["SubprocessHandler"]
IS_WINDOWS = sys.platform == "win32"
def _get_default_signal() -> signal.Signals:
    """Get the default termination signal. SIGTERM for unix, CTRL_C_EVENT for windows."""
    if not IS_WINDOWS:
        return signal.SIGTERM
    return signal.CTRL_C_EVENT  # type: ignore[attr-defined] # noqa: F821
class SubprocessHandler:
    """
    Convenience wrapper around python's ``subprocess.Popen``. Keeps track of
    meta-objects associated to the process (e.g. stdout and stderr redirect fds).
    """

    def __init__(
        self,
        entrypoint: str,
        args: Tuple,
        env: Dict[str, str],
        stdout: Optional[str],
        stderr: Optional[str],
        local_rank_id: int,
    ):
        # When a redirect path is given, open it for writing; ``None`` means
        # the child inherits the parent's stdout/stderr.
        self._stdout = open(stdout, "w") if stdout else None
        self._stderr = open(stderr, "w") if stderr else None
        # inherit parent environment vars
        env_vars = os.environ.copy()
        env_vars.update(env)

        # Popen requires string args; stringify everything after the entrypoint.
        args_str = (entrypoint, *[str(e) for e in args])
        self.local_rank_id = local_rank_id
        self.proc: subprocess.Popen = self._popen(args_str, env_vars)

    def _popen(self, args: Tuple, env: Dict[str, str]) -> subprocess.Popen:
        """Spawn the child process; on POSIX, start it in its own session so the
        whole process group can be signalled in ``close()``."""
        kwargs: Dict[str, Any] = {}
        if not IS_WINDOWS:
            kwargs["start_new_session"] = True
        return subprocess.Popen(
            # pyre-fixme[6]: Expected `Union[typing.Sequence[Union[_PathLike[bytes],
            # _PathLike[str], bytes, str]], bytes, str]` for 1st param but got
            # `Tuple[str, *Tuple[Any, ...]]`.
            args=args,
            env=env,
            stdout=self._stdout,
            stderr=self._stderr,
            **kwargs,
        )

    def close(self, death_sig: Optional[signal.Signals] = None) -> None:
        """Signal the child (whole process group on POSIX) and close redirect fds.

        Uses the platform default termination signal when ``death_sig`` is None.
        """
        if not death_sig:
            death_sig = _get_default_signal()
        if IS_WINDOWS:
            self.proc.send_signal(death_sig)
        else:
            # Child was started with start_new_session=True, so its pid is the
            # process-group id; this signals the child and its descendants.
            os.killpg(self.proc.pid, death_sig)
        if self._stdout:
            self._stdout.close()
        if self._stderr:
            self._stderr.close()

+ 0
- 158
mindnlp/core/distributed/elastic/multiprocessing/tail_log.py View File

@@ -1,158 +0,0 @@
#!/usr/bin/env python3
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
import time
from concurrent.futures.thread import ThreadPoolExecutor
from threading import Event
from typing import Dict, List, Optional, TextIO, TYPE_CHECKING
if TYPE_CHECKING:
from concurrent.futures._base import Future
__all__ = ["tail_logfile", "TailLog"]
logger = logging.getLogger(__name__)
def tail_logfile(
    header: str, file: str, dst: TextIO, finished: Event, interval_sec: float
):
    """Follow ``file`` (tail -f style), copying each line to ``dst`` prefixed with ``header``.

    Waits for the file to appear first; returns once ``finished`` is set and
    EOF has been reached, or immediately if ``finished`` is set before the
    file ever appears.
    """
    # Wait for the producer to create the file, bailing out if we are told
    # to finish before it ever shows up.
    while not os.path.exists(file):
        if finished.is_set():
            return
        time.sleep(interval_sec)

    with open(file, errors="replace") as fp:
        while True:
            line = fp.readline()
            if not line:
                # Reached EOF: stop if the producer is done, otherwise poll
                # again after a short pause.
                if finished.is_set():
                    break
                time.sleep(interval_sec)
                continue
            dst.write(f"{header}{line}")
class TailLog:
    """
    Tail the given log files.

    The log files do not have to exist when the ``start()`` method is called. The tail-er will gracefully wait until
    the log files are created by the producer and will tail the contents of the
    log files until the ``stop()`` method is called.

    .. warning:: ``TailLog`` will wait indefinitely for the log file to be created!

    Each log file's line will be suffixed with a header of the form: ``[{name}{idx}]:``,
    where the ``name`` is user-provided and ``idx`` is the index of the log file
    in the ``log_files`` mapping. ``log_line_prefixes`` can be used to override the
    header for each log file.

    Usage:

    ::

     log_files = {0: "/tmp/0_stdout.log", 1: "/tmp/1_stdout.log"}
     tailer = TailLog("trainer", log_files, sys.stdout).start()
     # actually run the trainers to produce 0_stdout.log and 1_stdout.log
     run_trainers()
     tailer.stop()

     # once run_trainers() start writing the ##_stdout.log files
     # the tailer will print to sys.stdout:
     # >>> [trainer0]:log_line1
     # >>> [trainer1]:log_line1
     # >>> [trainer0]:log_line2
     # >>> [trainer0]:log_line3
     # >>> [trainer1]:log_line2

    .. note:: Due to buffering log lines between files may not necessarily
              be printed out in order. You should configure your application's
              logger to suffix each log line with a proper timestamp.
    """

    def __init__(
        self,
        name: str,
        log_files: Dict[int, str],
        dst: TextIO,
        log_line_prefixes: Optional[Dict[int, str]] = None,
        interval_sec: float = 0.1,
    ):
        n = len(log_files)
        self._threadpool = None
        if n > 0:
            # One tailer thread per log file.
            self._threadpool = ThreadPoolExecutor(
                max_workers=n,
                thread_name_prefix=f"{self.__class__.__qualname__}_{name}",
            )

        self._name = name
        self._dst = dst
        self._log_files = log_files
        self._log_line_prefixes = log_line_prefixes
        # Per-rank events used to tell each tailer thread to finish.
        self._finished_events: Dict[int, Event] = {
            local_rank: Event() for local_rank in log_files.keys()
        }
        self._futs: List[Future] = []
        self._interval_sec = interval_sec
        self._stopped = False

    def start(self) -> "TailLog":
        """Start one tailer thread per log file; no-op when there are none. Returns self."""
        if not self._threadpool:
            return self

        for local_rank, file in self._log_files.items():
            header = f"[{self._name}{local_rank}]:"
            if self._log_line_prefixes and local_rank in self._log_line_prefixes:
                header = self._log_line_prefixes[local_rank]
            self._futs.append(
                self._threadpool.submit(
                    tail_logfile,
                    header=header,
                    file=file,
                    dst=self._dst,
                    finished=self._finished_events[local_rank],
                    interval_sec=self._interval_sec,
                )
            )
        return self

    def stop(self) -> None:
        """Signal all tailers to finish, wait for them, and shut the pool down."""
        for finished in self._finished_events.values():
            finished.set()

        # _futs was appended in the iteration order of _log_files (see start()),
        # so zipping with the keys recovers the true local_rank for each future.
        # (enumerate would report positional indices, which are wrong when the
        # rank keys are sparse or do not start at zero.)
        for local_rank, f in zip(self._log_files, self._futs):
            try:
                f.result()
            except Exception as e:
                logger.error(
                    "error in log tailor for %s%s. %s: %s",
                    self._name,
                    local_rank,
                    e.__class__.__qualname__,
                    e,
                )

        if self._threadpool:
            self._threadpool.shutdown(wait=True)

        self._stopped = True

    def stopped(self) -> bool:
        """Return True once ``stop()`` has completed."""
        return self._stopped

+ 0
- 167
mindnlp/core/distributed/elastic/rendezvous/__init__.py View File

@@ -1,167 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
In the context of Torch Distributed Elastic we use the term *rendezvous* to
refer to a particular functionality that combines a **distributed
synchronization** primitive with **peer discovery**.
It is used by Torch Distributed Elastic to gather participants of a training
job (i.e. nodes) such that they all agree on the same list of participants and
everyone's roles, as well as make a consistent collective decision on when
training can begin/resume.
Torch Distributed Elastic rendezvous provides the following critical
functionalities:
**Barrier**:
Nodes performing rendezvous will all block until the rendezvous is considered
complete - this happens when at least ``min`` total number of nodes have joined
the rendezvous barrier (for the same job). This also implies the barrier is not
necessarily of fixed size.
There's an additional small waiting time after reaching ``min`` number of
nodes - this is used to ensure the rendezvous is not completed "too quickly"
(which could potentially exclude additional nodes attempting to join at
approximately the same time).
If ``max`` number of nodes is gathered at the barrier, the rendezvous is
completed immediately.
There's also an overall timeout which causes the rendezvous to fail if ``min``
number of nodes is never reached - this is meant to be a simple fail-safe to
help release partially allocated job resources, in case there's a problem with
the resource manager, and is meant to be interpreted as non-retryable.
**Exclusivity**:
A simple distributed barrier would not be sufficient, as we also need to ensure
that only one group of nodes exists at any given time (for a given job). In
other words, new nodes (i.e. joining late) should not be able to form a parallel
independent group of workers for the same job.
Torch Distributed Elastic rendezvous ensures that if a group of nodes has
already completed a rendezvous (and hence might already be training), then
additional "late" nodes attempting to rendezvous will only announce themselves
as waiting, and will have to wait until the (previously completed) existing
rendezvous is destroyed first.
**Consistency**:
When a rendezvous is completed, all its members will agree on the job membership
and everyone's role in it. This role is represented using an integer, called
rank, that is between 0 and world size.
Note that ranks are *not stable*, in the sense that the same node can be
assigned a different rank in the next (re-)rendezvous.
**Fault-tolerance**:
Torch Distributed Elastic rendezvous is designed to tolerate node failures
during the rendezvous process. Should a process crash (or lose network
connectivity, etc), between joining the rendezvous and it being completed, then
a re-rendezvous with remaining healthy nodes will happen automatically.
A node can also fail *after* it has completed (or *has been observed* by other
nodes to have completed) the rendezvous - this scenario will be handled by the
Torch Distributed Elastic ``train_loop`` instead (where it will also trigger a
re-rendezvous).
**Shared key-value store**:
When the rendezvous is completed, a shared key-value store is created and
returned. This store implements a ``core.distributed.Store`` API (see
`distributed communication docs
<https://pycore.org/docs/stable/distributed.html>`__).
This store is only shared by the members of the completed rendezvous. It
is intended to be used by Torch Distributed Elastic to exchange information
necessary to initialize job control and data-planes.
**Waiting workers and rendezvous closing**:
Torch Distributed Elastic rendezvous handler object provides additional
functionalities, which are technically not part of the rendezvous process:
1. Querying how many workers arrived late at the barrier, who can participate in
*next* rendezvous.
2. Setting the rendezvous *closed* to signal all nodes not to participate in
next rendezvous.
**DynamicRendezvousHandler**:
Torch Distributed Elastic comes with the :py:class:`.DynamicRendezvousHandler`
class that implements the rendezvous mechanism described above. It is a backend-
agnostic type that expects a particular :py:class:`.RendezvousBackend` instance
to be specified during construction.
Torch distributed users can either implement their own backend type or use one
of the following implementations that come with PyTorch:
- :py:class:`.C10dRendezvousBackend`: Uses a C10d store (by default
``TCPStore``) as the rendezvous backend. The main advantage of using a C10d
store is that it requires no 3rd-party dependency (such as etcd) to establish
a rendezvous.
- :py:class:`.EtcdRendezvousBackend`: Supersedes the legacy
:py:class:`.EtcdRendezvousHandler` class. Passing an
:py:class:`.EtcdRendezvousBackend` instance to
:py:class:`.DynamicRendezvousHandler` is functionally equivalent to
instantiating an :py:class:`.EtcdRendezvousHandler`.
::
store = TCPStore("localhost")
backend = C10dRendezvousBackend(store, "my_run_id")
rdzv_handler = DynamicRendezvousHandler.from_backend(
run_id="my_run_id",
store=store,
backend=backend,
min_nodes=2,
max_nodes=4
)
"""
from .api import (
rendezvous_handler_registry,
RendezvousClosedError,
RendezvousConnectionError,
RendezvousError,
RendezvousGracefulExitError,
RendezvousHandler,
RendezvousHandlerCreator,
RendezvousHandlerRegistry,
RendezvousInfo,
RendezvousParameters,
RendezvousStateError,
RendezvousStoreInfo,
RendezvousTimeoutError,
)
from .registry import _register_default_handlers, _register_out_of_tree_handlers
_register_default_handlers()
_register_out_of_tree_handlers()
__all__ = [
"RendezvousClosedError",
"RendezvousConnectionError",
"RendezvousError",
"RendezvousGracefulExitError",
"RendezvousHandler",
"RendezvousHandlerCreator",
"RendezvousHandlerRegistry",
"RendezvousInfo",
"RendezvousParameters",
"RendezvousStateError",
"RendezvousStoreInfo",
"RendezvousTimeoutError",
"rendezvous_handler_registry",
]

+ 0
- 384
mindnlp/core/distributed/elastic/rendezvous/api.py View File

@@ -1,384 +0,0 @@
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import socket
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Callable, ClassVar, Dict, Optional
from core.distributed import Store
from core.distributed.elastic.utils.distributed import get_free_port
__all__ = [
"RendezvousClosedError",
"RendezvousConnectionError",
"RendezvousError",
"RendezvousGracefulExitError",
"RendezvousHandler",
"RendezvousHandlerCreator",
"RendezvousHandlerRegistry",
"RendezvousInfo",
"RendezvousParameters",
"RendezvousStateError",
"RendezvousStoreInfo",
"RendezvousTimeoutError",
"rendezvous_handler_registry",
]
class RendezvousError(Exception):
    """Represents the base type for rendezvous errors."""


class RendezvousClosedError(RendezvousError):
    """Raised when a rendezvous is closed."""


class RendezvousTimeoutError(RendezvousError):
    """Raised when a rendezvous did not complete on time."""


class RendezvousConnectionError(RendezvousError):
    """Raised when the connection to a rendezvous backend has failed."""


class RendezvousStateError(RendezvousError):
    """Raised when the state of a rendezvous is corrupt."""


class RendezvousGracefulExitError(RendezvousError):
    """Raised when node wasn't included in rendezvous and gracefully exits.

    Exception is a mechanism to exit the stack, however does not mean a failure.
    """
@dataclass
class RendezvousStoreInfo:
    """Store address and port that can be used to bootstrap trainer distributed comms"""

    # Well-known store keys under which rank 0 publishes the bootstrap address/port.
    MASTER_ADDR_KEY: ClassVar[str] = "MASTER_ADDR"
    MASTER_PORT_KEY: ClassVar[str] = "MASTER_PORT"
    master_addr: str
    master_port: int

    @staticmethod
    def build(
        rank: int,
        store: Store,
        local_addr: Optional[str],
        server_port: Optional[int] = None,
    ) -> "RendezvousStoreInfo":
        """Factory method, finds unused new port on rank0 host and addr/port info with all ranks.

        If master_addr/master_port is known (useful when sharing existing tcp store server) use the constructor.

        Args:
            rank: rank of the current node
            store: store to use for rendezvous
            local_addr: address of the current node, if not provided will be resolved from hostname
            server_port: port of the TCPStore server, when the TCPStore is shared.
        """
        # TODO swap to collectives comms API
        if rank == 0:
            addr = local_addr or socket.getfqdn()
            # When TCPStore is not shared, we fallback to get_free_port.
            port = server_port or get_free_port()
            # Publish as bytes; non-zero ranks block in store.get() until these keys appear.
            store.set(RendezvousStoreInfo.MASTER_ADDR_KEY, addr.encode(encoding="UTF-8"))  # type: ignore[arg-type]
            store.set(RendezvousStoreInfo.MASTER_PORT_KEY, str(port).encode(encoding="UTF-8"))  # type: ignore[arg-type]

        # Every rank (including 0) reads back the canonical published values.
        addr = store.get(RendezvousStoreInfo.MASTER_ADDR_KEY).decode(encoding="UTF-8")
        port = int(
            store.get(RendezvousStoreInfo.MASTER_PORT_KEY).decode(encoding="UTF-8")
        )
        return RendezvousStoreInfo(master_addr=addr, master_port=port)
class RendezvousInfo:
    """Immutable view of the outcome of a completed rendezvous."""

    def __init__(
        self,
        store: Store,
        rank: int,
        world_size: int,
        bootstrap_store_info: RendezvousStoreInfo,
    ):
        # Values are fixed at construction and exposed read-only via properties.
        self._store = store
        self._rank = rank
        self._world_size = world_size
        self._bootstrap_store_info = bootstrap_store_info

    @property
    def store(self) -> Store:
        """The store used by the torchelastic control plane."""
        return self._store

    @property
    def rank(self) -> int:
        """This node's rank within the formed group."""
        return self._rank

    @property
    def world_size(self) -> int:
        """The total size of the formed group."""
        return self._world_size

    @property
    def bootstrap_store_info(self) -> Optional[RendezvousStoreInfo]:
        """Store address/port that trainer code can use to bootstrap distributed comms."""
        return self._bootstrap_store_info
class RendezvousHandler(ABC):
    """Main rendezvous interface.

    Note:
        Distributed Torch users normally **do not** need to implement their own
        ``RendezvousHandler``. An implementation based on C10d Store is already
        provided, and is recommended for most users.
    """

    @abstractmethod
    def get_backend(self) -> str:
        """Return the name of the rendezvous backend."""

    @property
    def use_agent_store(self) -> bool:
        """Indicates that store reference returned by :py:meth:`next_rendezvous` can be shared with user
        applications and will be available during application lifecyle.

        Rendezous handler impl will share store details as instance of :py:class:`RendezvousStoreInfo`.
        Applications as a convention use `MASTER_ADDR`/`MASTER_PORT` env variables to lookup the store.
        """
        # Conservative default; concrete handlers opt in by overriding.
        return False

    @abstractmethod
    def next_rendezvous(self) -> RendezvousInfo:
        """Main entry-point into the rendezvous barrier.

        Blocks until the rendezvous is complete and the current process is
        included in the formed worker group, or a timeout occurs, or the
        rendezvous was marked closed.

        Returns:
            Instance of :py:class:`RendezvousInfo`.

        Raises:
            RendezvousClosedError:
                The rendezvous is closed.
            RendezvousConnectionError:
                The connection to the rendezvous backend has failed.
            RendezvousStateError:
                The rendezvous state is corrupt.
            RendezvousTimeoutError:
                The rendezvous did not complete on time.
        """

    @abstractmethod
    def is_closed(self) -> bool:
        """Check whether the rendezvous has been closed.

        A closed rendezvous means all future attempts to re-rendezvous within
        same job will fail.

        ``is_closed()`` and :py:meth:`set_closed` have semantics of eventual
        propagation and should not be used for synchronization. The intention is
        that if at least one node decides the job is finished, it will close the
        rendezvous, and other nodes will soon observe this and stop running as
        well.
        """

    @abstractmethod
    def set_closed(self):
        """Mark the rendezvous as closed."""

    @abstractmethod
    def num_nodes_waiting(self) -> int:
        """Return the number of nodes who arrived late at the rendezvous
        barrier, hence were not included in the current worker group.

        Callers should periodically call this method to check whether new
        nodes are waiting to join the job and if so admit them by calling
        :py:meth:`next_rendezvous()` (re-rendezvous).
        """

    @abstractmethod
    def get_run_id(self) -> str:
        """Return the run id of the rendezvous.

        The run id is a user-defined id that uniquely identifies an instance of
        a distributed application. It typically maps to a job id and is used to
        allow nodes to join the correct distributed application.
        """

    @abstractmethod
    def shutdown(self) -> bool:
        """Close all resources that were open for the rendezvous.

        Example::

            rdzv_handler = ...
            try:
                store, rank, world_size = rdzv_handler.next_rendezvous()
            finally:
                rdzv_handler.shutdown()
        """
class RendezvousParameters:
    """Hold the parameters to construct a :py:class:`RendezvousHandler`.

    Args:
        backend:
            The name of the backend to use to handle the rendezvous.
        endpoint:
            The endpoint of the rendezvous, usually in form <hostname>[:<port>].
        run_id:
            The id of the rendezvous.
        min_nodes:
            The minimum number of nodes to admit to the rendezvous.
        max_nodes:
            The maximum number of nodes to admit to the rendezvous.
        local_addr:
            The address of the local node.
        **kwargs:
            Additional parameters for the specified backend.
    """

    def __init__(
        self,
        backend: str,
        endpoint: str,
        run_id: str,
        min_nodes: int,
        max_nodes: int,
        local_addr: Optional[str] = None,
        **kwargs,
    ):
        # Validate eagerly so misconfigurations fail fast at construction time.
        if not backend:
            raise ValueError("The rendezvous backend name must be a non-empty string.")
        if min_nodes < 1:
            raise ValueError(
                f"The minimum number of rendezvous nodes ({min_nodes}) must be greater than zero."
            )
        if max_nodes < min_nodes:
            raise ValueError(
                f"The maximum number of rendezvous nodes ({max_nodes}) must be greater than or "
                f"equal to the minimum number of rendezvous nodes ({min_nodes})."
            )

        self.backend = backend
        self.endpoint = endpoint
        self.run_id = run_id
        self.min_nodes = min_nodes
        self.max_nodes = max_nodes
        self.config = kwargs
        self.local_addr = local_addr

    def get(self, key: str, default: Any = None) -> Any:
        """Return the value for ``key`` if ``key`` exists, else ``default``."""
        return self.config.get(key, default)

    def get_as_bool(self, key: str, default: Optional[bool] = None) -> Optional[bool]:
        """Return the value for ``key`` as a ``bool``."""
        value = self.get(key, default)
        # None and real booleans pass straight through.
        if value is None or isinstance(value, bool):
            return value
        if isinstance(value, int):
            if value in (0, 1):
                return bool(value)
        elif isinstance(value, str):
            lowered = value.lower()
            if lowered in ("1", "true", "t", "yes", "y"):
                return True
            if lowered in ("0", "false", "f", "no", "n"):
                return False
        raise ValueError(
            f"The rendezvous configuration option '{key}' does not represent a valid boolean value."
        )

    def get_as_int(self, key: str, default: Optional[int] = None) -> Optional[int]:
        """Return the value for ``key`` as an ``int``."""
        value = self.get(key, default)
        if value is None:
            return None
        try:
            return int(value)
        except ValueError as e:
            raise ValueError(
                f"The rendezvous configuration option '{key}' does not represent a valid integer "
                "value."
            ) from e
RendezvousHandlerCreator = Callable[[RendezvousParameters], RendezvousHandler]
class RendezvousHandlerRegistry:
    """Represent a registry of :py:class:`RendezvousHandler` backends."""

    _registry: Dict[str, RendezvousHandlerCreator]

    def __init__(self) -> None:
        self._registry = {}

    def register(self, backend: str, creator: RendezvousHandlerCreator) -> None:
        """Register a new rendezvous backend.

        Args:
            backend:
                The name of the backend.
            creator:
                The callback to invoke to construct the
                :py:class:`RendezvousHandler`.
        """
        if not backend:
            raise ValueError("The rendezvous backend name must be a non-empty string.")

        # Re-registering a backend is allowed only with the exact same creator.
        current_creator = self._registry.get(backend)
        if current_creator is not None and current_creator != creator:
            raise ValueError(
                f"The rendezvous backend '{backend}' cannot be registered with '{creator}' as it "
                f"is already registered with '{current_creator}'."
            )

        self._registry[backend] = creator

    def create_handler(self, params: RendezvousParameters) -> RendezvousHandler:
        """Create a new :py:class:`RendezvousHandler`."""
        try:
            creator = self._registry[params.backend]
        except KeyError as e:
            raise ValueError(
                f"The rendezvous backend '{params.backend}' is not registered. Did you forget "
                f"to call `{self.register.__name__}`?"
            ) from e

        new_handler = creator(params)

        # Sanity-check that the creator produced a handler for the backend
        # that was actually requested.
        if new_handler.get_backend() != params.backend:
            raise RuntimeError(
                f"The rendezvous backend '{new_handler.get_backend()}' does not match the requested "
                f"backend '{params.backend}'."
            )

        return new_handler
# The default global registry instance used by launcher scripts to instantiate
# rendezvous handlers.
rendezvous_handler_registry = RendezvousHandlerRegistry()

+ 0
- 273
mindnlp/core/distributed/elastic/rendezvous/c10d_rendezvous_backend.py View File

@@ -1,273 +0,0 @@
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import binascii
import logging
import os
import tempfile
from base64 import b64decode, b64encode
from datetime import timedelta
from typing import Any, cast, Optional, Tuple
from core.distributed import FileStore, Store, TCPStore
from core.distributed.elastic.events import construct_and_record_rdzv_event, NodeState
from .api import (
RendezvousConnectionError,
RendezvousError,
RendezvousParameters,
RendezvousStateError,
)
from .dynamic_rendezvous import RendezvousBackend, Token
from .utils import _matches_machine_hostname, parse_rendezvous_endpoint
logger = logging.getLogger(__name__)
# Default port for the TCP store, used when the rendezvous endpoint does not
# specify one explicitly (see `_create_tcp_store`).
DEFAULT_PORT = 29400
class C10dRendezvousBackend(RendezvousBackend):
    """Represents a C10d-backed rendezvous backend.

    Args:
        store:
            The :py:class:`core.distributed.Store` instance to use to
            communicate with the C10d store.
        run_id:
            The run id of the rendezvous.
    """

    # See the explanation in the __init__ method.
    _NULL_SENTINEL = "Y2FuaW1hZGFt"

    # The underlying store handle and the key under which this run's state
    # is kept.
    _store: Store
    _key: str

    def __init__(self, store: Store, run_id: str) -> None:
        if not run_id:
            raise ValueError("The run id must be a non-empty string.")

        self._store = store

        self._key = "core.rendezvous." + run_id

        # The read operation of a store blocks the caller until the specified
        # key becomes available. This behavior makes it tricky to use a store
        # as a regular key-value dictionary.
        #
        # As a workaround we initially set a sentinel value as the rendezvous
        # state. Whenever this value gets returned we treat it as a None.
        self._call_store("compare_set", self._key, "", self._NULL_SENTINEL)

    @property
    def name(self) -> str:
        """See base class."""
        return "c10d"

    def get_state(self) -> Optional[Tuple[bytes, Token]]:
        """See base class."""
        base64_state: bytes = self._call_store("get", self._key)

        return self._decode_state(base64_state)

    def set_state(
        self, state: bytes, token: Optional[Token] = None
    ) -> Optional[Tuple[bytes, Token, bool]]:
        """See base class."""
        base64_state_str: str = b64encode(state).decode()

        if token:
            # Shortcut if we know for sure that the token is not valid.
            if not isinstance(token, bytes):
                result = self.get_state()
                if result is not None:
                    tmp = *result, False
                    # Python 3.6 does not support tuple unpacking in return
                    # statements.
                    return tmp
                return None

            token = token.decode()
        else:
            # No token means "only set if no state exists yet"; an untouched
            # key currently holds the sentinel, so compare against it.
            token = self._NULL_SENTINEL

        base64_state: bytes = self._call_store(
            "compare_set", self._key, token, base64_state_str
        )

        state_token_pair = self._decode_state(base64_state)
        if state_token_pair is None:
            return None

        new_state, new_token = state_token_pair

        # C10d Store's compare_set method does not offer an easy way to find out
        # whether our write attempt was successful. As a brute-force solution we
        # perform a bitwise comparison of our local state and the remote state.
        return new_state, new_token, new_state == state

    def _call_store(self, store_op: str, *args, **kwargs) -> Any:
        # Forward the named operation to the underlying store, translating
        # low-level failures into RendezvousConnectionError.
        try:
            return getattr(self._store, store_op)(*args, **kwargs)
        except (ValueError, RuntimeError, TimeoutError) as exc:
            raise RendezvousConnectionError(
                "The connection to the C10d store has failed. See inner exception for details."
            ) from exc

    def _decode_state(self, base64_state: bytes) -> Optional[Tuple[bytes, Token]]:
        # The sentinel means the key has never been written; report "no state".
        if base64_state == self._NULL_SENTINEL.encode():
            return None

        try:
            state = b64decode(base64_state)
        except binascii.Error as exc:
            raise RendezvousStateError(
                "The state object is corrupt. See inner exception for details."
            ) from exc

        # The raw base64 blob doubles as the fencing token for compare_set.
        return state, base64_state
def _create_tcp_store(params: RendezvousParameters) -> TCPStore:
    """Create a :py:class:`TCPStore` for the C10d rendezvous backend.

    The endpoint in ``params`` supplies the host and (optionally) the port;
    whether this process hosts the store is taken from the ``is_host``
    parameter or, when unset, inferred from the machine's hostname/IP.

    Raises:
        ValueError: If ``read_timeout`` is not a positive integer.
        RendezvousConnectionError: If the store cannot be instantiated.
    """
    host, port = parse_rendezvous_endpoint(params.endpoint, default_port=DEFAULT_PORT)

    cfg_is_host = params.get_as_bool("is_host")
    # If the user has explicitly specified whether our process should host the
    # store, respect it.
    if cfg_is_host is not None:
        is_host = cfg_is_host
    # Otherwise try to determine whether we are the host based on our hostname
    # and IP address.
    else:
        is_host = _matches_machine_hostname(host)

    # The timeout, in seconds, used for all blocking store reads.
    read_timeout = cast(int, params.get_as_int("read_timeout", 60))
    if read_timeout <= 0:
        raise ValueError("The read timeout must be a positive integer.")

    # In specific cases we attempt to instantiate the store twice. For details
    # see the explanation in the except clause below.
    for is_server in [is_host, False]:
        try:
            store = TCPStore(
                host,
                port,
                is_master=is_server,
                multi_tenant=True,
                timeout=timedelta(seconds=read_timeout),
            )

            if is_server:
                msg = f"Process {os.getpid()} hosts the TCP store for the C10d rendezvous backend."
                construct_and_record_rdzv_event(
                    run_id=params.run_id, message=msg, node_state=NodeState.INIT
                )
                logger.info(msg)

            break
        except (ValueError, RuntimeError, TimeoutError) as exc:
            # If we heuristically inferred the value of is_host as True and our
            # first attempt to instantiate the TCP store has failed, try it one
            # more time with is_host set to False. As an edge case there can be
            # more than one process that is part of the same rendezvous on this
            # machine and only one of them will eventually host the store.
            if not is_server or cfg_is_host is not None:
                raise RendezvousConnectionError(
                    "The connection to the C10d store has failed. See inner exception for details."
                ) from exc

    return store  # type: ignore[possibly-undefined]
def _create_file_store(params: RendezvousParameters) -> FileStore:
    """Create a :py:class:`FileStore` for the C10d rendezvous backend.

    If ``params.endpoint`` is set it is treated as the path to the backing
    file; otherwise a fresh temporary file is created.

    Raises:
        RendezvousError: If the temporary file cannot be created.
        RendezvousConnectionError: If the store cannot be instantiated.
    """
    # If a user specifies an endpoint, we treat it as a path to a file.
    if params.endpoint:
        path = params.endpoint
    else:
        try:
            # The temporary file is readable and writable only by the user of
            # this process.
            fd, path = tempfile.mkstemp()
            # mkstemp() also returns an open OS-level file descriptor that we
            # never use; close it immediately so it is not leaked.
            os.close(fd)
        except OSError as exc:
            raise RendezvousError(
                "The file creation for C10d store has failed. See inner exception for details."
            ) from exc

    try:
        store = FileStore(path)
    except (ValueError, RuntimeError) as exc:
        raise RendezvousConnectionError(
            "The connection to the C10d store has failed. See inner exception for details."
        ) from exc

    return store
def create_backend(params: RendezvousParameters) -> Tuple[C10dRendezvousBackend, Store]:
    """Create a new :py:class:`C10dRendezvousBackend` from the specified parameters.

    +--------------+-----------------------------------------------------------+
    | Parameter    | Description                                               |
    +==============+===========================================================+
    | store_type   | The type of the C10d store. The currently supported types |
    |              | are "tcp" and "file" which correspond to                  |
    |              | :py:class:`core.distributed.TCPStore` and                 |
    |              | :py:class:`core.distributed.FileStore`, respectively.     |
    |              | Defaults to "tcp".                                        |
    +--------------+-----------------------------------------------------------+
    | read_timeout | The read timeout, in seconds, for store operations.       |
    |              | Defaults to 60 seconds.                                   |
    |              |                                                           |
    |              | Note this only applies to                                 |
    |              | :py:class:`core.distributed.TCPStore`. It is not relevant |
    |              | to :py:class:`core.distributed.FileStore` which does not  |
    |              | take in timeout as a parameter.                           |
    +--------------+-----------------------------------------------------------+
    | is_host      | A boolean value indicating whether this backend instance  |
    |              | will host the C10d store. If not specified it will be     |
    |              | inferred heuristically by matching the hostname or the IP |
    |              | address of this machine against the specified rendezvous  |
    |              | endpoint. Defaults to ``None``.                           |
    |              |                                                           |
    |              | Note that this configuration option only applies to       |
    |              | :py:class:`core.distributed.TCPStore`. In normal          |
    |              | circumstances you can safely skip it; the only time when  |
    |              | it is needed is if its value cannot be correctly          |
    |              | determined (e.g. the rendezvous endpoint has a CNAME as   |
    |              | the hostname or does not match the FQDN of the machine).  |
    +--------------+-----------------------------------------------------------+
    """
    # As of today we only support TCPStore and FileStore. Other store types do
    # not have the required functionality (e.g. compare_set) yet.
    store_type = params.get("store_type", "tcp").strip().lower()

    store: Store

    try:
        # Dispatch on the normalized store type; anything unrecognized is a
        # configuration error.
        if store_type == "tcp":
            store = _create_tcp_store(params)
        elif store_type == "file":
            store = _create_file_store(params)
        else:
            raise ValueError(
                "Invalid store type given. Currently only supports file and tcp."
            )

        backend = C10dRendezvousBackend(store, params.run_id)
    except Exception as e:
        # Record the failure as a rendezvous event before propagating it.
        construct_and_record_rdzv_event(
            message=f"{type(e).__name__}: {str(e)}",
            run_id=params.run_id,
            node_state=NodeState.FAILED,
        )
        raise

    return backend, store

+ 0
- 1431
mindnlp/core/distributed/elastic/rendezvous/dynamic_rendezvous.py View File

@@ -1,1431 +0,0 @@
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import inspect
import logging
import os
import pickle
import socket
import threading
import time
import weakref
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Set, Tuple
from mindnlp.core import distributed as dist
from core.distributed import Store
from core.distributed.elastic.events import construct_and_record_rdzv_event, NodeState
from .api import (
RendezvousClosedError,
RendezvousError,
RendezvousGracefulExitError,
RendezvousHandler,
RendezvousInfo,
RendezvousParameters,
RendezvousStateError,
RendezvousStoreInfo,
RendezvousTimeoutError,
)
from .utils import _delay, _PeriodicTimer
# Public API of this module; names not listed here are implementation details.
__all__ = [
    "RendezvousBackend",
    "RendezvousTimeout",
    "RendezvousSettings",
    "DynamicRendezvousHandler",
    "create_handler",
]

logger = logging.getLogger(__name__)
def get_method_name(depth=2):
    """Return the name of the function `depth` frames up the call stack.

    Falls back to ``"no_method_name"`` when the stack is shallower than
    `depth`. Used to tag recorded rendezvous events with the calling method.
    """
    # inspect.stack() walks and wraps every frame and is expensive; call it
    # once instead of twice (the original called it for the length check and
    # again for the lookup).
    stack = inspect.stack()
    if len(stack) > depth:
        return stack[depth].function
    return "no_method_name"
# Alias for the opaque fencing-token values returned by
# RendezvousBackend.get_state() and accepted by RendezvousBackend.set_state().
Token = Any
"""Represent an opaque fencing token used by the rendezvous backend."""
class RendezvousBackend(ABC):
    """Represent a backend that holds the rendezvous state.

    Concrete implementations persist an encoded state blob guarded by an
    opaque fencing token (see :py:data:`Token`).
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Get the name of the backend."""

    @abstractmethod
    def get_state(self) -> Optional[Tuple[bytes, Token]]:
        """Get the rendezvous state.

        Returns:
            A tuple of the encoded rendezvous state and its fencing token or
            ``None`` if no state is found in the backend.

        Raises:
            RendezvousConnectionError:
                The connection to the backend has failed.
            RendezvousStateError:
                The rendezvous state is corrupt.
        """

    @abstractmethod
    def set_state(
        self, state: bytes, token: Optional[Token] = None
    ) -> Optional[Tuple[bytes, Token, bool]]:
        """Set the rendezvous state.

        The new rendezvous state is set conditionally:

          - If the specified ``token`` matches the fencing token stored in the
            backend, the state will be updated. The new state will be returned
            to the caller along with its fencing token.
          - If the specified ``token`` does not match the fencing token stored
            in the backend, the state won't be updated; instead the existing
            state along with its fencing token will be returned to the caller.
          - If the specified ``token`` is ``None``, the new state will be set
            only if there is no existing state in the backend. Either the new
            state or the existing state along with its fencing token will be
            returned to the caller.

        Args:
            state:
                The encoded rendezvous state.
            token:
                An optional fencing token that was retrieved by a previous call
                to :py:meth:`get_state` or ``set_state()``.

        Returns:
            A tuple of the serialized rendezvous state, its fencing token, and
            a boolean value indicating whether our set attempt succeeded.

        Raises:
            RendezvousConnectionError:
                The connection to the backend has failed.
            RendezvousStateError:
                The rendezvous state is corrupt.
        """
class RendezvousTimeout:
    """Hold the timeout configuration of a rendezvous.

    Args:
        join:
            The time within which the rendezvous is expected to complete.
        last_call:
            An additional wait amount before completing the rendezvous once the
            rendezvous has the minimum number of required participants.
        close:
            The time within which the rendezvous is expected to close after a
            call to :py:meth:`RendezvousHandler.set_closed` or
            :py:meth:`RendezvousHandler.shutdown`.
        heartbeat:
            The time within which a keep-alive heartbeat is expected to
            complete.
    """

    _ZERO = timedelta(0)

    # Fallback values applied for any timeout the caller leaves as None.
    _DEFAULT_TIMEOUTS = {
        "join": timedelta(seconds=600),
        "last_call": timedelta(seconds=30),
        "close": timedelta(seconds=30),
        "heartbeat": timedelta(seconds=5),
    }

    _join: timedelta
    _last_call: timedelta
    _close: timedelta
    _heartbeat: timedelta

    def __init__(
        self,
        join: Optional[timedelta] = None,
        last_call: Optional[timedelta] = None,
        close: Optional[timedelta] = None,
        heartbeat: Optional[timedelta] = None,
    ) -> None:
        self._set_timeouts(
            join=join, last_call=last_call, close=close, heartbeat=heartbeat
        )

    @property
    def join(self) -> timedelta:
        """Get the join timeout."""
        return self._join

    @property
    def last_call(self) -> timedelta:
        """Get the last call timeout."""
        return self._last_call

    @property
    def close(self) -> timedelta:
        """Get the close timeout."""
        return self._close

    @property
    def heartbeat(self) -> timedelta:
        """Get the keep-alive heartbeat timeout."""
        return self._heartbeat

    def _set_timeouts(self, **timeouts: Optional[timedelta]):
        # Substitute the documented default for every unset timeout, validate
        # it, and store it under the matching private attribute.
        for name, value in timeouts.items():
            effective = self._DEFAULT_TIMEOUTS[name] if value is None else value
            if effective <= self._ZERO:
                raise ValueError(f"The {name} timeout ({effective}) must be positive.")
            setattr(self, f"_{name}", effective)
@dataclass(repr=False, eq=False, frozen=True)
class RendezvousSettings:
    """Hold the settings of the rendezvous.

    Instances are immutable (frozen dataclass); equality is identity-based
    since ``eq=False``.

    Attributes:
        run_id:
            The run id of the rendezvous.
        min_nodes:
            The minimum number of nodes to admit to the rendezvous.
        max_nodes:
            The maximum number of nodes to admit to the rendezvous.
        timeout:
            The timeout configuration of the rendezvous.
        keep_alive_interval:
            The amount of time a node waits before sending a heartbeat to keep
            it alive in the rendezvous.
        keep_alive_max_attempt:
            The maximum number of failed heartbeat attempts after which a node
            is considered dead.
    """

    run_id: str
    min_nodes: int
    max_nodes: int
    timeout: RendezvousTimeout
    keep_alive_interval: timedelta
    keep_alive_max_attempt: int
@dataclass(eq=True, order=True, frozen=True)
class _NodeDesc:
"""Describe a node in the rendezvous.
Attributes:
addr:
The FQDN of the node or user specified local node address.
pid:
The id of the process in which the rendezvous handler runs.
local_id:
A process-wide unique id.
"""
addr: str
pid: int
local_id: int
def __repr__(self) -> str:
return f"{self.addr}_{self.pid}_{self.local_id}"
class _NodeDescGenerator:
    """Generate node descriptors.

    A node descriptor is a combination of an FQDN, a process id, and an auto-
    incremented integer that uniquely identifies a node in the rendezvous.
    """

    _lock: threading.Lock
    _local_id: int

    def __init__(self) -> None:
        self._lock = threading.Lock()
        # Incremented with each call to generate().
        self._local_id = 0

    def generate(self, local_addr: Optional[str] = None) -> _NodeDesc:
        # generate() may be called from several threads at once, so handing
        # out the next id must happen atomically under the lock.
        with self._lock:
            assigned_id = self._local_id
            self._local_id = assigned_id + 1

        return _NodeDesc(local_addr or socket.getfqdn(), os.getpid(), assigned_id)
class _RendezvousState:
    """Hold the state of a rendezvous.

    Instances are pickled and exchanged between nodes through the rendezvous
    backend (see ``_BackendRendezvousStateHolder.sync``).

    Attributes:
        round:
            The current round of the rendezvous.
        complete:
            A boolean value indicating whether the current round of the
            rendezvous is complete.
        deadline:
            The time at which the current round of the rendezvous will be
            considered complete if it is still waiting for nodes to join.
        closed:
            A boolean value indicating whether the rendezvous is closed.
        participants:
            A dictionary of the participants and their corresponding ranks.
        wait_list:
            A set of nodes that are waiting to participate in the next round of
            the rendezvous.
        redundancy_list:
            A set of nodes that are redundant in the current round and can join
            the next rendezvous without triggering re-rendezvous.
        last_heartbeats:
            A dictionary containing each node's last heartbeat time.
    """

    round: int
    complete: bool
    deadline: Optional[datetime]
    closed: bool
    participants: Dict[_NodeDesc, int]
    wait_list: Set[_NodeDesc]
    redundancy_list: Set[_NodeDesc]
    last_heartbeats: Dict[_NodeDesc, datetime]

    def __init__(self) -> None:
        # A fresh state starts at round 0, open, with no members and no
        # completion deadline.
        self.round = 0
        self.complete = False
        self.deadline = None
        self.closed = False
        self.participants = {}
        self.wait_list = set()
        self.redundancy_list = set()
        self.last_heartbeats = {}
def _remove_participant_epilogue(
    state: _RendezvousState, settings: RendezvousSettings
) -> None:
    """Restore state invariants after a participant has been removed.

    If the round was complete and the last participant left, roll over to the
    next round; if the round was still forming and the participant count
    dropped below ``min_nodes``, clear the completion deadline.
    """
    if state.complete:
        # If we do not have any participants left, move to the next round.
        if not state.participants:
            msg = "No participants left in the rendezvous, marking rendezvous as incomplete"
            logger.debug(msg)

            state.complete = False

            state.round += 1
    else:
        if len(state.participants) < settings.min_nodes:
            # NOTE: the original message contained a stray ')', a missing
            # space, and the typo "clearning"; fixed here.
            msg = (
                f"Number of participants {len(state.participants)} less than "
                f"min_nodes {settings.min_nodes}, clearing deadline in state"
            )
            logger.debug(msg)

            state.deadline = None
class _RendezvousStateHolder(ABC):
    """Hold the shared rendezvous state synced with other nodes."""

    @property
    @abstractmethod
    def state(self) -> _RendezvousState:
        """Get the local state."""

    @abstractmethod
    def sync(self) -> Optional[bool]:
        """Read or write the latest state.

        Returns:
            A boolean value indicating whether the local state, in case marked
            as dirty, was successfully synced with other nodes.
        """

    @abstractmethod
    def mark_dirty(self) -> None:
        """Mark the local state as dirty."""
class _BackendRendezvousStateHolder(_RendezvousStateHolder):
    """Hold the rendezvous state synced with other nodes via a backend.

    Args:
        backend:
            The rendezvous backend to use.
        settings:
            The rendezvous settings.
        cache_duration:
            The amount of time, in seconds, to cache the last rendezvous state
            before requesting it from the backend again.
    """

    _backend: RendezvousBackend
    _state: _RendezvousState
    _settings: RendezvousSettings
    _cache_duration: int
    _token: Token
    _dirty: bool
    _last_sync_time: float
    _dead_nodes: List[_NodeDesc]

    def __init__(
        self,
        backend: RendezvousBackend,
        settings: RendezvousSettings,
        cache_duration: int = 1,
    ) -> None:
        self._backend = backend
        self._state = _RendezvousState()
        self._settings = settings
        self._cache_duration = cache_duration
        self._token = None
        self._dirty = False
        self._last_sync_time = -1
        self._dead_nodes = []

    def _record(self, message: str, node_state: NodeState = NodeState.RUNNING):
        # Emit a structured rendezvous event tagged with the calling method.
        construct_and_record_rdzv_event(
            name=f"{self.__class__.__name__}.{get_method_name()}",
            run_id=self._settings.run_id,
            message=message,
            node_state=node_state,
        )

    @property
    def state(self) -> _RendezvousState:
        """See base class."""
        return self._state

    def sync(self) -> Optional[bool]:
        """See base class."""
        state_bits: Optional[bytes] = None

        token = None

        has_set: Optional[bool]

        if self._dirty:
            # Local changes pending: attempt a fenced write to the backend.
            has_set = False

            state_bits = pickle.dumps(self._state)

            set_response = self._backend.set_state(state_bits, self._token)
            if set_response is not None:
                state_bits, token, has_set = set_response
        else:
            has_set = None

            if self._cache_duration > 0:
                # Avoid overloading the backend if we are asked to retrieve the
                # state repeatedly. Try to serve the cached state.
                if self._last_sync_time >= max(
                    time.monotonic() - self._cache_duration, 0
                ):
                    return None

            get_response = self._backend.get_state()
            if get_response is not None:
                state_bits, token = get_response

        if state_bits is not None:
            try:
                self._state = pickle.loads(state_bits)
            except pickle.PickleError as exc:
                raise RendezvousStateError(
                    "The rendezvous state is corrupt. See inner exception for details."
                ) from exc
        else:
            self._state = _RendezvousState()

        if has_set and self._dead_nodes and logger.isEnabledFor(logging.DEBUG):
            node_list = ", ".join(f"'{dead_node}'" for dead_node in self._dead_nodes)

            msg = (
                f"As part of the sync operation the node(s) {node_list} have been removed from the "
                f"rendezvous '{self._settings.run_id}' since they had no heartbeat."
            )
            self._record(message=msg)
            logger.debug(msg)

        self._token = token

        self._dirty = False

        self._last_sync_time = time.monotonic()

        self._sanitize()

        return has_set

    def _sanitize(self) -> None:
        # Drop nodes whose heartbeat is older than the allowed window and keep
        # the state's bookkeeping collections consistent.
        state = self._state

        expire_time = datetime.now(timezone.utc) - (
            self._settings.keep_alive_interval * self._settings.keep_alive_max_attempt
        )

        # Filter out the dead nodes.
        self._dead_nodes = [
            node
            for node, last_heartbeat in state.last_heartbeats.items()
            if last_heartbeat < expire_time
        ]

        participant_removed = False

        for dead_node in self._dead_nodes:
            msg = f"Detected dead node '{dead_node}', removing it from the rendezvous"
            logger.debug(msg)

            del state.last_heartbeats[dead_node]

            try:
                del state.participants[dead_node]

                participant_removed = True
            except KeyError:
                pass

            try:
                state.wait_list.remove(dead_node)
            except KeyError:
                pass

            try:
                state.redundancy_list.remove(dead_node)
            except KeyError:
                pass

        if participant_removed:
            # Common epilogue shared with the _remove_from_participants()
            # function of _DistributedRendezvousOpExecutor.
            _remove_participant_epilogue(state, self._settings)

    def mark_dirty(self) -> None:
        """See base class.

        If the local rendezvous state is dirty, the next sync call will try to
        write the changes back to the backend. However this attempt might fail
        if another node, which had the same state, also made changes and wrote
        them before us.
        """
        self._dirty = True
class _Action(Enum):
    """Specifies the possible actions based on the state of the rendezvous."""

    # Housekeeping: refresh this node's heartbeat.
    KEEP_ALIVE = 1
    # Membership changes on the shared state.
    ADD_TO_PARTICIPANTS = 2
    ADD_TO_WAIT_LIST = 3
    ADD_TO_REDUNDANCY_LIST = 4
    REMOVE_FROM_PARTICIPANTS = 5
    REMOVE_FROM_WAIT_LIST = 6
    REMOVE_FROM_REDUNDANCY_LIST = 7
    # Rendezvous lifecycle transitions.
    MARK_RENDEZVOUS_COMPLETE = 8
    MARK_RENDEZVOUS_CLOSED = 9
    # Re-read the shared state (executed with a one-second delay).
    SYNC = 10
    # Terminal outcomes: raise the matching error, or finish the operation.
    ERROR_CLOSED = 11
    ERROR_TIMEOUT = 12
    FINISH = 13
class _RendezvousContext:
    """Holds the context of the rendezvous.

    A bundle constructed per iteration by ``_DistributedRendezvousOpExecutor.run``
    and handed to the state-handler operations.

    Attributes:
        node:
            The node descriptor associated with the current rendezvous handler
            instance.
        state:
            The current state of the rendezvous.
        settings:
            The rendezvous settings.
    """

    node: _NodeDesc
    state: _RendezvousState
    settings: RendezvousSettings

    def __init__(
        self, node: _NodeDesc, state: _RendezvousState, settings: RendezvousSettings
    ) -> None:
        self.node = node
        self.state = state
        self.settings = settings
class _RendezvousOpExecutor(ABC):
    """Execute rendezvous operations."""

    @abstractmethod
    def run(
        self,
        state_handler: Callable[[_RendezvousContext, float], _Action],
        deadline: float,
        update_deadline: Optional[Callable[[timedelta], float]] = None,
    ) -> None:
        """Execute a rendezvous operation.

        An operation is run inside a state machine and is expected to transition
        the rendezvous from one state to another.

        Args:
            state_handler:
                A callable that is expected to return the next state transition
                action based on the current state of the rendezvous.
            deadline:
                The time, in seconds, at which the operation will be considered
                timed-out.
            update_deadline:
                Function to generate a new operation deadline if the current
                node may participate in the next rendezvous.

        Raises:
            RendezvousClosedError: If the handler returns ``ERROR_CLOSED``.
            RendezvousTimeoutError: If the handler returns ``ERROR_TIMEOUT``.
        """
class _DistributedRendezvousOpExecutor(_RendezvousOpExecutor):
    """Execute rendezvous operations using a shared state.

    Args:
        node:
            The node descriptor associated with the current rendezvous handler
            instance.
        state_holder:
            The ``RendezvousStateHolder`` to use to sync the rendezvous state
            with other nodes.
        settings:
            The rendezvous settings.
    """

    _node: _NodeDesc
    _state: _RendezvousState
    _state_holder: _RendezvousStateHolder
    _settings: RendezvousSettings

    def __init__(
        self,
        node: _NodeDesc,
        state_holder: _RendezvousStateHolder,
        settings: RendezvousSettings,
    ) -> None:
        self._node = node
        self._state_holder = state_holder
        self._settings = settings

    def _record(self, message: str, node_state: NodeState = NodeState.RUNNING) -> None:
        """Record a structured rendezvous event for this node."""
        construct_and_record_rdzv_event(
            name=f"{self.__class__.__name__}.{get_method_name()}",
            run_id=self._settings.run_id,
            message=message,
            node_state=node_state,
            hostname=self._node.addr,
            pid=self._node.pid,
            local_id=self._node.local_id,
        )

    def run(
        self,
        state_handler: Callable[[_RendezvousContext, float], _Action],
        deadline: float,
        update_deadline: Optional[Callable[[timedelta], float]] = None,
    ) -> None:
        """See base class."""
        action = None
        while action != _Action.FINISH:
            # Reads or writes the latest rendezvous state shared by all nodes in
            # the rendezvous. Note that our local changes might get overridden
            # by another node if that node synced its changes before us.
            has_set = self._state_holder.sync()
            if has_set is not None:
                if has_set:
                    msg = (
                        f"The node '{self._node}' has successfully synced its local changes with "
                        f"other nodes in the rendezvous '{self._settings.run_id}'."
                    )
                else:
                    msg = (
                        f"The node '{self._node}' has a stale state and failed to sync its local "
                        f"changes with other nodes in the rendezvous '{self._settings.run_id}'."
                    )

                self._record(message=msg)
                logger.debug(msg)

            self._state = self._state_holder.state

            ctx = _RendezvousContext(self._node, self._state, self._settings)

            # Determine the next action to take based on the current state of
            # the rendezvous.
            action = state_handler(ctx, deadline)

            if action == _Action.FINISH:
                continue

            if action == _Action.ERROR_CLOSED:
                raise RendezvousClosedError

            if action == _Action.ERROR_TIMEOUT:
                raise RendezvousTimeoutError

            if action == _Action.SYNC:
                # Delay the execution by one second to avoid overloading the
                # backend if we are asked to poll for state changes.
                _delay(seconds=1)
            else:
                if action == _Action.KEEP_ALIVE:
                    self._keep_alive()
                elif action == _Action.ADD_TO_PARTICIPANTS:
                    self._add_to_participants()
                elif action == _Action.ADD_TO_WAIT_LIST:
                    self._add_to_wait_list()
                elif action == _Action.ADD_TO_REDUNDANCY_LIST:
                    self._add_to_redundancy_list()
                elif action == _Action.REMOVE_FROM_PARTICIPANTS:
                    self._remove_from_participants()
                elif action == _Action.REMOVE_FROM_WAIT_LIST:
                    self._remove_from_wait_list()
                elif action == _Action.REMOVE_FROM_REDUNDANCY_LIST:
                    self._remove_from_redundancy_list()
                    # update deadline since the node may participate in rendezvous process
                    if update_deadline:
                        deadline = update_deadline(self._settings.timeout.join)
                elif action == _Action.MARK_RENDEZVOUS_COMPLETE:
                    self._mark_rendezvous_complete()
                elif action == _Action.MARK_RENDEZVOUS_CLOSED:
                    self._mark_rendezvous_closed()

                # Attempt to sync our changes back to other nodes.
                self._state_holder.mark_dirty()

    def _keep_alive(self) -> None:
        """Refresh this node's heartbeat timestamp in the shared state."""
        msg = (
            f"The node '{self._node}' updated its keep-alive heartbeat time for the rendezvous "
            f"'{self._settings.run_id}'. Pending sync."
        )
        self._record(message=msg)
        logger.debug(msg)

        self._state.last_heartbeats[self._node] = datetime.now(timezone.utc)

    def _add_to_participants(self) -> None:
        """Move this node from the wait list into the current round's participants."""
        msg = (
            f"The node '{self._node}' added itself to the participants of round "
            f"{self._state.round} of the rendezvous '{self._settings.run_id}'. Pending sync."
        )
        self._record(message=msg)
        logger.debug(msg)

        state = self._state

        try:
            state.wait_list.remove(self._node)
        except KeyError:
            pass

        # The ranks of the participants will be set once the rendezvous is
        # complete.
        state.participants[self._node] = 0

        self._keep_alive()

        if len(state.participants) == self._settings.min_nodes:
            state.deadline = (
                datetime.now(timezone.utc) + self._settings.timeout.last_call
            )

        if len(state.participants) == self._settings.max_nodes:
            self._mark_rendezvous_complete()

    def _add_to_wait_list(self) -> None:
        """Register this node for the next round of the rendezvous."""
        msg = (
            f"The node '{self._node}' added itself to the wait list of round "
            f"{self._state.round + 1} of the rendezvous '{self._settings.run_id}'. Pending sync."
        )
        self._record(message=msg)
        logger.debug(msg)

        if self._node in self._state.redundancy_list:
            self._state.redundancy_list.remove(self._node)
        self._state.wait_list.add(self._node)

        self._keep_alive()

    def _add_to_redundancy_list(self) -> None:
        """Mark this node as redundant for the next round."""
        msg = (
            f"The node '{self._node}' added itself to the redundancy list of round "
            f"{self._state.round + 1} of the rendezvous '{self._settings.run_id}'. Pending sync."
        )
        self._record(message=msg)
        logger.debug(msg)

        self._state.redundancy_list.add(self._node)

        self._keep_alive()

    def _remove_from_participants(self) -> None:
        """Withdraw this node from the current round's participants."""
        msg = (
            f"The node '{self._node}' removed itself from the participants of round "
            f"{self._state.round} of the rendezvous '{self._settings.run_id}'. Pending sync."
        )
        self._record(message=msg)
        logger.debug(msg)

        state = self._state

        del state.participants[self._node]

        del state.last_heartbeats[self._node]

        # Common epilogue shared with the sanitizer() function of
        # _BackendRendezvousStateHolder.
        _remove_participant_epilogue(state, self._settings)

    def _remove_from_wait_list(self) -> None:
        """Withdraw this node from the next round's wait list."""
        msg = (
            f"The node '{self._node}' removed itself from the wait list of round "
            f"{self._state.round + 1} of the rendezvous '{self._settings.run_id}'. Pending sync."
        )
        self._record(message=msg)
        logger.debug(msg)

        self._state.wait_list.remove(self._node)

        del self._state.last_heartbeats[self._node]

    def _remove_from_redundancy_list(self) -> None:
        """Withdraw this node from the redundancy list."""
        # NOTE: fixed the "redunant" typo in the original log message.
        msg = (
            f"The node '{self._node}' removed itself from the redundant list of round "
            f"{self._state.round + 1} of the rendezvous '{self._settings.run_id}'. Pending sync."
        )
        self._record(message=msg)
        logger.debug(msg)

        self._state.redundancy_list.remove(self._node)

        del self._state.last_heartbeats[self._node]

    def _mark_rendezvous_complete(self) -> None:
        """Mark the current round complete and assign ranks to the participants."""
        msg = (
            f"The node '{self._node}' marked round {self._state.round} of the rendezvous "
            f"'{self._settings.run_id}' as complete. Pending sync."
        )
        self._record(message=msg, node_state=NodeState.SUCCEEDED)
        logger.debug(msg)

        state = self._state

        state.complete = True
        state.deadline = None

        # Assign the ranks.
        for rank, node in enumerate(sorted(state.participants)):
            state.participants[node] = rank

    def _mark_rendezvous_closed(self) -> None:
        """Mark the whole rendezvous as closed to new nodes."""
        msg = (
            f"The node '{self._node}' marked the rendezvous '{self._settings.run_id}' as closed. "
            "Pending sync."
        )
        self._record(message=msg, node_state=NodeState.SUCCEEDED)
        logger.debug(msg)

        self._state.closed = True
def _should_keep_alive(ctx: _RendezvousContext) -> bool:
    """Determine whether a keep-alive heartbeat should be sent."""
    last_heartbeat = ctx.state.last_heartbeats.get(ctx.node)
    # A node with no recorded heartbeat is not tracked by the rendezvous.
    if last_heartbeat is None:
        return False
    # A heartbeat is due once a full keep-alive interval has elapsed.
    threshold = datetime.now(timezone.utc) - ctx.settings.keep_alive_interval
    return last_heartbeat <= threshold
class _RendezvousExitOp:
    """Represent a rendezvous exit operation."""

    def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action:
        # A node that is not a participant has nothing to undo.
        if ctx.node not in ctx.state.participants:
            return _Action.FINISH
        # Out of time: surface a timeout instead of mutating state.
        if time.monotonic() > deadline:
            return _Action.ERROR_TIMEOUT
        return _Action.REMOVE_FROM_PARTICIPANTS
class _RendezvousJoinOp:
    """Represent a rendezvous join operation."""

    def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action:
        """Decide the next action for a node attempting to join the rendezvous."""
        state = ctx.state

        # A closed rendezvous means that it no longer accepts new nodes.
        if state.closed:
            if ctx.node in state.redundancy_list:
                msg = f"The rendezvous '{ctx.settings.run_id}' is closed, terminating pending rendezvous."
                raise RendezvousGracefulExitError(msg)
            return _Action.ERROR_CLOSED

        # NOTE: the two debug messages below fix the "redunancy" typo that was
        # present in the original log strings.
        if ctx.node in state.redundancy_list:
            msg = f"The node {ctx.node} is in redundancy list"
            logger.debug(msg)
            # don't apply the timeout logic here, since we want to allow the node to rejoin
            if len(state.participants) == ctx.settings.max_nodes:
                if _should_keep_alive(ctx):
                    return _Action.KEEP_ALIVE
                else:
                    return _Action.SYNC
            else:
                # transition to waiting state that will respect timeouts.
                msg = f"The node {ctx.node} is removed from redundancy list"
                logger.debug(msg)
                return _Action.REMOVE_FROM_REDUNDANCY_LIST

        is_participant = ctx.node in state.participants

        # If we are part of the rendezvous and it is already complete there is
        # no further action to take.
        if state.complete and is_participant:
            return _Action.FINISH

        now = time.monotonic()
        if now > deadline:
            rollback_period = 5  # 5 seconds

            # If we still have time to rollback (a short period on top of the
            # operation deadline), try to remove ourself from the rendezvous.
            # It is okay if we can't though as our keep-alive will eventually
            # expire.
            if now <= deadline + rollback_period:
                # If we are part of the rendezvous, it means we couldn't find
                # enough participants to complete it on time.
                if is_participant:
                    return _Action.REMOVE_FROM_PARTICIPANTS
                # If we are in the wait list, it means we couldn't wait till the
                # next round of the rendezvous.
                if ctx.node in state.wait_list:
                    return _Action.REMOVE_FROM_WAIT_LIST
            return _Action.ERROR_TIMEOUT

        if state.complete:
            # If we are here, it means we are not part of the rendezvous. In
            # case the rendezvous has capacity for additional participants add
            # ourself to the wait list for the next round.
            if len(state.participants) < ctx.settings.max_nodes:
                if ctx.node not in state.wait_list:
                    return _Action.ADD_TO_WAIT_LIST
            elif len(state.participants) >= ctx.settings.max_nodes:
                if (
                    ctx.node not in state.redundancy_list
                    and ctx.node not in state.wait_list
                ):
                    return _Action.ADD_TO_REDUNDANCY_LIST
        elif is_participant:
            # If the rendezvous has enough number of participants including us,
            # check whether we have passed the rendezvous deadline. If yes,
            # complete it.
            if (
                len(state.participants) >= ctx.settings.min_nodes
                and len(state.participants) <= ctx.settings.max_nodes
                and state.deadline is not None
            ):
                if state.deadline < datetime.now(timezone.utc):
                    msg = (
                        f"The node '{ctx.node}' marking the rendezvous complete, "
                        f"quorum established within deadline"
                    )
                    logger.debug(msg)
                    return _Action.MARK_RENDEZVOUS_COMPLETE
                else:
                    msg = f"The node '{ctx.node}' can't complete rendezvous: deadline reached"
                    logger.debug(msg)
            else:
                msg = f"The node '{ctx.node}' can't complete rendezvous: not enough participants"
                logger.debug(msg)
        else:
            # The rendezvous is not complete yet and we are not part of it. Try
            # to join.
            return _Action.ADD_TO_PARTICIPANTS

        if _should_keep_alive(ctx):
            return _Action.KEEP_ALIVE

        # At this point either the rendezvous is not complete, but we are part
        # of it, which means we have to wait for other participants to join; or
        # the rendezvous is complete, but we are not part of it, which means we
        # have to wait for the next round.
        return _Action.SYNC
class _RendezvousCloseOp:
    """State-machine operation that marks the rendezvous as closed."""
    def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action:
        # Nothing to do once the rendezvous is already closed.
        if ctx.state.closed:
            return _Action.FINISH
        # Give up if the close deadline has already passed; otherwise close.
        timed_out = time.monotonic() > deadline
        return _Action.ERROR_TIMEOUT if timed_out else _Action.MARK_RENDEZVOUS_CLOSED
class _RendezvousKeepAliveOp:
    """State-machine operation that refreshes this node's keep-alive lease."""
    def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action:
        # Only heartbeat while this node still needs to advertise liveness.
        if not _should_keep_alive(ctx):
            return _Action.FINISH
        if time.monotonic() > deadline:
            return _Action.ERROR_TIMEOUT
        return _Action.KEEP_ALIVE
class DynamicRendezvousHandler(RendezvousHandler):
    """Represent a handler that sets up a rendezvous among a set of nodes."""
    # Static
    _node_desc_generator = _NodeDescGenerator()
    # Descriptor (addr/pid/local_id) uniquely identifying this node.
    _this_node: _NodeDesc
    _settings: RendezvousSettings
    _backend_name: str
    _store: Store
    # Keeps the local copy of the rendezvous state in sync with the backend.
    _state_holder: _RendezvousStateHolder
    _op_executor: _RendezvousOpExecutor
    # Serializes state access between callers and the heartbeat timer thread.
    _heartbeat_lock: threading.Lock
    _keep_alive_timer: Optional[_PeriodicTimer]
    @classmethod
    def from_backend(
        cls,
        run_id: str,
        store: Store,
        backend: RendezvousBackend,
        min_nodes: int,
        max_nodes: int,
        local_addr: Optional[str] = None,
        timeout: Optional[RendezvousTimeout] = None,
    ):
        """Create a new :py:class:`DynamicRendezvousHandler`.
        Args:
            run_id:
                The run id of the rendezvous.
            store:
                The C10d store to return as part of the rendezvous.
            backend:
                The backend to use to hold the rendezvous state.
            min_nodes:
                The minimum number of nodes to admit to the rendezvous.
            max_nodes:
                The maximum number of nodes to admit to the rendezvous.
            local_addr:
                The local node address.
            timeout:
                The timeout configuration of the rendezvous.
        """
        # We associate each handler instance with a unique node descriptor.
        node = cls._node_desc_generator.generate(local_addr)
        settings = RendezvousSettings(
            run_id,
            min_nodes,
            max_nodes,
            timeout or RendezvousTimeout(),
            keep_alive_interval=timedelta(seconds=5),
            keep_alive_max_attempt=3,
        )
        state_holder = _BackendRendezvousStateHolder(backend, settings)
        return cls(node, settings, backend.name, store, state_holder)
    def __init__(
        self,
        node: _NodeDesc,
        settings: RendezvousSettings,
        backend_name: str,
        store: Store,
        state_holder: _RendezvousStateHolder,
    ) -> None:
        # Validate settings eagerly so misconfiguration fails at construction.
        if not settings.run_id:
            raise ValueError("The run id must be a non-empty string.")
        if settings.min_nodes < 1:
            raise ValueError(
                f"The minimum number of nodes ({settings.min_nodes}) must be greater than zero."
            )
        if settings.max_nodes < settings.min_nodes:
            raise ValueError(
                f"The maximum number of nodes ({settings.max_nodes}) must be greater than or equal "
                f"to the minimum number of nodes ({settings.min_nodes})."
            )
        self._this_node = node
        self._settings = settings
        self._backend_name = backend_name
        self._store = store
        self._state_holder = state_holder
        self._op_executor = _DistributedRendezvousOpExecutor(
            self._this_node, self._state_holder, self._settings
        )
        self._heartbeat_lock = threading.Lock()
        self._keep_alive_timer = None
        # Cached shared store server reference
        self._shared_tcp_store_server: Optional[dist.Store] = None
        self._bootstrap_store_info: Optional[RendezvousStoreInfo] = None
    def _record(
        self,
        message: str,
        node_state: NodeState = NodeState.RUNNING,
        rank: Optional[int] = None,
    ) -> None:
        """Emit a structured rendezvous event tagged with this node's identity."""
        construct_and_record_rdzv_event(
            name=f"{self.__class__.__name__}.{get_method_name()}",
            run_id=self._settings.run_id,
            message=message,
            node_state=node_state,
            hostname=self._this_node.addr,
            pid=self._this_node.pid,
            local_id=self._this_node.local_id,
            rank=rank,
        )
    def _create_tcp_store_server(self, master_addr, master_port) -> dist.TCPStore:
        """Start a multi-tenant TCPStore server bound to the given address/port."""
        return dist.TCPStore(
            host_name=master_addr,
            port=master_port,
            is_master=True,
            multi_tenant=True,
        )
    @property
    def settings(self) -> RendezvousSettings:
        """Get the settings of the rendezvous."""
        return self._settings
    def get_backend(self) -> str:
        """See base class."""
        return self._backend_name
    @property
    def use_agent_store(self) -> bool:
        """See base class."""
        # TCPStore sharing is on by default; the env var is an opt-out switch.
        return os.getenv("TORCH_DISABLE_SHARE_RDZV_TCP_STORE", "0") != "1"
    def next_rendezvous(self) -> RendezvousInfo:
        """See base class."""
        msg = (
            f"The node '{self._this_node}' attempts to join the next round of the rendezvous "
            f"'{self._settings.run_id}'."
        )
        self._record(message=msg)
        logger.info(msg)
        try:
            # Heartbeats must not mutate state while the join is in progress.
            self._stop_heartbeats()
            # Delay the execution for a small random amount of time if this is our
            # first run. This will slightly skew the rendezvous attempts across the
            # nodes and reduce the load on the backend.
            if self._state_holder.state.round == 0:
                _delay(seconds=(0, 0.3))
            exit_op = _RendezvousExitOp()
            join_op = _RendezvousJoinOp()
            deadline = self._get_deadline(self._settings.timeout.join)
            self._op_executor.run(exit_op, deadline)
            self._op_executor.run(join_op, deadline, self._get_deadline)
            self._start_heartbeats()
            rank, world_size = self._get_world()
            store = self._get_store()
        except Exception as e:
            self._record(
                message=f"{type(e).__name__}: {str(e)}",
                node_state=NodeState.FAILED,
            )
            raise
        msg = (
            f"The node '{self._this_node}' has joined round {self._state_holder.state.round} of "
            f"the rendezvous '{self._settings.run_id}' as rank {rank} in a world of size "
            f"{world_size}."
        )
        self._record(message=msg, rank=rank)
        logger.info(msg)
        # opt-out option of TCPStore sharing
        if os.getenv("TORCH_DISABLE_SHARE_RDZV_TCP_STORE", "0") == "1":
            bootstrap_store_info = RendezvousStoreInfo.build(
                rank, store, local_addr=self._this_node.addr
            )
            return RendezvousInfo(
                store,
                rank,
                world_size,
                bootstrap_store_info,
            )
        # This will only be hit when TCPStore sharing is enabled.
        # The bootstrap info is cached so later rounds reuse the same server.
        if self._bootstrap_store_info is None:
            # To avoid race in get_free_port because we release the port after the call,
            # we want to create a TCPStore server soon afterwards.
            server_port = 0
            if rank == 0:
                self._shared_tcp_store_server = self._create_tcp_store_server(
                    self._this_node.addr, server_port
                )
                server_port = self._shared_tcp_store_server.port
            self._bootstrap_store_info = RendezvousStoreInfo.build(
                rank,
                store,
                local_addr=self._this_node.addr,
                server_port=server_port,  # For non-0 rank, this is a no-op
            )
        assert self._bootstrap_store_info is not None
        if rank == 0:
            assert self._shared_tcp_store_server is not None
        return RendezvousInfo(
            store,
            rank,
            world_size,
            self._bootstrap_store_info,  # type: ignore[assignment]
        )
    def is_closed(self) -> bool:
        """See base class."""
        try:
            # Lock out the heartbeat thread while syncing/reading state.
            with self._heartbeat_lock:
                self._state_holder.sync()
                return self._state_holder.state.closed
        except Exception as e:
            self._record(
                message=f"{type(e).__name__}: {str(e)}",
                node_state=NodeState.FAILED,
            )
            raise
    def set_closed(self) -> None:
        """See base class."""
        try:
            with self._heartbeat_lock:
                self._close()
        except Exception as e:
            self._record(
                message=f"{type(e).__name__}: {str(e)}",
                node_state=NodeState.FAILED,
            )
            raise
    def num_nodes_waiting(self) -> int:
        """See base class."""
        try:
            with self._heartbeat_lock:
                self._state_holder.sync()
                return len(self._state_holder.state.wait_list)
        except Exception as e:
            self._record(
                message=f"{type(e).__name__}: {str(e)}",
                node_state=NodeState.FAILED,
            )
            raise
    def get_run_id(self) -> str:
        """See base class."""
        return self._settings.run_id
    def shutdown(self) -> bool:
        """See base class."""
        self._stop_heartbeats()
        try:
            self._close()
            return True
        except RendezvousError as ex:
            # A failed close is reported but swallowed; callers get False.
            msg = (
                f"The node '{self._this_node}' has failed to shutdown the rendezvous "
                f"'{self._settings.run_id}' due to an error of type {type(ex).__name__}."
            )
            self._record(message=msg, node_state=NodeState.FAILED)
            logger.warning(msg)
            return False
        except Exception as e:
            self._record(
                message=f"{type(e).__name__}: {str(e)}",
                node_state=NodeState.FAILED,
            )
            raise
    def _close(self) -> None:
        """Run the close operation against the backend and record the outcome."""
        op = _RendezvousCloseOp()
        deadline = self._get_deadline(self._settings.timeout.close)
        self._op_executor.run(op, deadline)
        msg = f"The node '{self._this_node}' has closed the rendezvous '{self._settings.run_id}'."
        self._record(message=msg, node_state=NodeState.SUCCEEDED)
        logger.info(msg)
    @staticmethod
    def _keep_alive_weak(weak_self) -> None:
        """Timer callback; dereferences the weakref so the periodic timer
        does not keep the handler object alive."""
        self = weak_self()
        if self is not None:
            self._keep_alive()
    def _keep_alive(self) -> None:
        """Send one keep-alive heartbeat; rendezvous errors are logged, not raised."""
        self._heartbeat_lock.acquire()
        op = _RendezvousKeepAliveOp()
        deadline = self._get_deadline(self._settings.timeout.heartbeat)
        try:
            self._op_executor.run(op, deadline)
            msg = (
                f"The node '{self._this_node}' has sent a keep-alive heartbeat to the rendezvous "
                f"'{self._settings.run_id}'."
            )
            self._record(message=msg)
            logger.debug(msg)
        except RendezvousError as ex:
            msg = (
                f"The node '{self._this_node}' has failed to send a keep-alive heartbeat to the "
                f"rendezvous '{self._settings.run_id}' due to an error of type {type(ex).__name__}."
            )
            self._record(message=msg, node_state=NodeState.FAILED)
            logger.warning(msg)
        finally:
            self._heartbeat_lock.release()
    def _start_heartbeats(self) -> None:
        """Start the periodic keep-alive timer for this node."""
        self._keep_alive_timer = _PeriodicTimer(
            self._settings.keep_alive_interval, self._keep_alive_weak, weakref.ref(self)
        )
        self._keep_alive_timer.set_name(
            f"RendezvousKeepAliveTimer_{self._this_node.local_id}"
        )
        self._keep_alive_timer.start()
    def _stop_heartbeats(self) -> None:
        """Cancel the keep-alive timer if one is running."""
        if self._keep_alive_timer is None:
            return
        self._keep_alive_timer.cancel()
    def _get_world(self) -> Tuple[int, int]:
        """Return ``(rank, world_size)`` for this node from the current state."""
        state = self._state_holder.state
        return state.participants[self._this_node], len(state.participants)
    def _wrap_store(self, store: Store) -> Store:
        """Namespace store keys by run id and round so rounds do not collide."""
        key_prefix = (
            f"core.rendezvous.{self._settings.run_id}.{self._state_holder.state.round}"
        )
        return dist.PrefixStore(key_prefix, store)
    def _get_store(self) -> Store:
        """Return the caller-provided store wrapped with the round prefix."""
        return self._wrap_store(self._store)
    def _get_deadline(self, timeout: timedelta) -> float:
        """Convert a timeout into an absolute monotonic-clock deadline."""
        return time.monotonic() + timeout.total_seconds()
def _get_timeout(params: RendezvousParameters, key: str) -> Optional[timedelta]:
    """Read ``<key>_timeout`` from *params* as a ``timedelta``, or ``None`` if unset."""
    seconds = params.get_as_int(f"{key}_timeout")
    return None if seconds is None else timedelta(seconds=seconds)
def create_handler(
    store: Store, backend: RendezvousBackend, params: RendezvousParameters
) -> DynamicRendezvousHandler:
    """Create a new :py:class:`DynamicRendezvousHandler` from the specified parameters.
    Args:
        store:
            The C10d store to return as part of the rendezvous.
        backend:
            The backend to use to hold the rendezvous state.
    Recognized timeout parameters (all expressed in seconds):
    ``join_timeout``
        The total time within which the rendezvous is expected to complete.
        Defaults to 600 seconds.
    ``last_call_timeout``
        An additional wait amount before completing the rendezvous once the
        minimum number of nodes has been reached. Defaults to 30 seconds.
    ``close_timeout``
        The time within which the rendezvous is expected to close after a call
        to :py:meth:`RendezvousHandler.set_closed` or
        :py:meth:`RendezvousHandler.shutdown`. Defaults to 30 seconds.
    """
    try:
        # Unspecified timeouts fall back to the RendezvousTimeout defaults.
        join, last_call, close = (
            _get_timeout(params, name) for name in ("join", "last_call", "close")
        )
        return DynamicRendezvousHandler.from_backend(
            params.run_id,
            store,
            backend,
            params.min_nodes,
            params.max_nodes,
            params.local_addr,
            RendezvousTimeout(join, last_call, close),
        )
    except Exception as e:
        # Record the failure for telemetry before propagating it to the caller.
        construct_and_record_rdzv_event(
            message=f"{type(e).__name__}: {str(e)}",
            run_id=params.run_id,
            node_state=NodeState.FAILED,
        )
        raise

+ 0
- 1077
mindnlp/core/distributed/elastic/rendezvous/etcd_rendezvous.py View File

@@ -1,1077 +0,0 @@
#!/usr/bin/env python3
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import json
import logging
import sys
import threading
import time
from typing import Optional
import etcd # type: ignore[import]
from core.distributed.elastic.rendezvous import (
RendezvousClosedError,
RendezvousError,
RendezvousHandler,
RendezvousInfo,
RendezvousParameters,
RendezvousStoreInfo,
RendezvousTimeoutError,
)
from .etcd_store import cas_delay, EtcdStore
from .utils import parse_rendezvous_endpoint
__all__ = [
    "EtcdRendezvousRetryableFailure",
    "EtcdRendezvousRetryImmediately",
    "EtcdRendezvousHandler",
    "EtcdRendezvous",
    "create_rdzv_handler",
]
# Configure a dedicated stderr logger for this module; propagation is disabled
# so rendezvous progress messages are not duplicated by ancestor handlers.
_log_fmt = logging.Formatter("%(levelname)s %(asctime)s %(message)s")
_log_handler = logging.StreamHandler(sys.stderr)
_log_handler.setFormatter(_log_fmt)
logger = logging.getLogger(__name__)
logger.propagate = False
logger.setLevel(logging.INFO)
logger.addHandler(_log_handler)
# Retryable failure exception means we were too late to make
# a desired state transition (e.g. because of a race condition),
# and should now restart from the beginning.
# A small delay is recommended to avoid spamming Etcd.
class EtcdRendezvousRetryableFailure(Exception):
    """Raised when a state transition was lost (e.g. to a race) and the
    rendezvous should be re-entered from the beginning after a short delay."""
    pass
# Similar to retryable failure, but the new state we observed suggests we
# can re-try immediately, i.e. without a need for "safety delay".
class EtcdRendezvousRetryImmediately(Exception):
    """Raised when the observed state allows re-trying without a safety delay."""
    pass
# Default timeout for the rendezvous.
_DEFAULT_TIMEOUT: int = 600  # 10 minutes
# Additional waiting time after reaching the minimum number of nodes
# in case the rendezvous is elastic (min != max).
_DEFAULT_LAST_CALL_TIMEOUT: int = 30  # 30 seconds
# Various constants used internally in EtcdRendezvous.
# The TTLs below are etcd lease lifetimes (seconds in the etcd API).
CONST_ETCD_SETUP_TTL = 5
CONST_ETCD_FROZEN_TTL = 10
CONST_ETCD_JOINABLE_EPHEMERAL_TTL = 10
# Ephemeral node TTL for worker's keep-alive key:
CONST_WORKER_KEEPALIVE_TTL = 10
# TTL for the ephemeral run_id-specific directory. All rendezvous state data
# for a specific run_id (job instance) is contained within directory.
# Its only role is to clean-up rendezvous data from old runs (for the case when
# etcd server is persistent), and has no effect on correctness, but should be
# larger than any timeouts that a worker process is expected to survive:
CONST_RUNID_SUBROOT_TTL = 7200  # 2 hours
class EtcdRendezvousHandler(RendezvousHandler):
    """
    Implements a
    :py:class:`core.distributed.elastic.rendezvous.RendezvousHandler` interface
    backed by
    :py:class:`core.distributed.elastic.rendezvous.etcd_rendezvous.EtcdRendezvous`.
    ``EtcdRendezvousHandler`` uses a URL to configure the type of rendezvous to
    use and to pass implementation specific configurations to the rendezvous
    module. The basic etcd rendezvous configuration URL looks like the following
    ::
        etcd://<etcd_address>:<port>/<job_id>?min_workers=<min_workers>&max_workers=<max_workers>  # noqa: W605
        -- example --
        etcd://localhost:2379/1234?min_workers=1&max_workers=3
    The URL above is interpreted as follows:
    1. Use the rendezvous handler that is registered with the ``etcd``
       scheme
    2. The ``etcd`` endpoint to use is ``localhost:2379``
    3. ``job_id == 1234`` is used as the prefix in etcd (this allows one to
       share a common etcd server for multiple jobs so long as the
       ``job_ids`` are guaranteed to be unique). Note that the job id can be
       any string (e.g. does not need to be a number) as long as it is
       unique.
    4. ``min_workers=1`` and ``max_workers=3`` specifies a range for
       membership size - Torch Distributed Elastic starts running the job as
       long as the cluster size is greater than or equal to ``min_workers``
       and admits up to ``max_workers`` into the cluster.
    Below are a full list of the parameters that can be passed to etcd
    rendezvous:
    +--------------------------------------------+--------------------------+
    | Parameter                                  | Description              |
    +============================================+==========================+
    | min_workers                                | minimum number of        |
    |                                            | workers for the          |
    |                                            | rendezvous to be valid   |
    +--------------------------------------------+--------------------------+
    | max_workers                                | maximum number of        |
    |                                            | workers to admit         |
    +--------------------------------------------+--------------------------+
    | timeout                                    | total timeout within     |
    |                                            | which next_rendezvous is |
    |                                            | expected to succeed      |
    |                                            | (default 600s)           |
    +--------------------------------------------+--------------------------+
    | last_call_timeout                          | additional wait amount   |
    |                                            | ("last call") after min  |
    |                                            | number of workers has    |
    |                                            | been reached (defaults   |
    |                                            | to 30s)                  |
    +--------------------------------------------+--------------------------+
    | etcd_prefix                                | path prefix (from etcd   |
    |                                            | root), inside which all  |
    |                                            | etcd nodes will be       |
    |                                            | created (defaults to     |
    |                                            | ``/torchelastic/p2p``)   |
    +--------------------------------------------+--------------------------+
    """
    def __init__(self, rdzv_impl: "EtcdRendezvous", local_addr: Optional[str]):
        """
        Args:
            rdzv_impl: the implementation of the rendezvous
            local_addr: the local address of the current node
        """
        self._rdzv_impl = rdzv_impl
        self._local_addr = local_addr
    def __del__(self):
        # TODO: look into using weakref here instead.
        # Dropping the reference lets EtcdRendezvous.__del__ stop its lease
        # renewal threads once no other references remain.
        del self._rdzv_impl
    def get_backend(self) -> str:
        """Return the name of the rendezvous backend ("etcd")."""
        return "etcd"
    def next_rendezvous(self):
        """Run the rendezvous barrier and return the agreed-upon round.

        Returns a :py:class:`RendezvousInfo` wrapping the ``EtcdStore`` that is
        scoped to the completed rendezvous version.
        """
        rdzv_version, rank, world_size = self._rdzv_impl.rendezvous_barrier()
        logger.info("Creating EtcdStore as the c10d::Store implementation")
        store = self._rdzv_impl.setup_kv_store(rdzv_version)
        bootstrap_store_info = RendezvousStoreInfo.build(
            rank, store, local_addr=self._local_addr
        )
        return RendezvousInfo(store, rank, world_size, bootstrap_store_info)
    def is_closed(self):
        try:
            _, state = self._rdzv_impl.get_rdzv_state()
            return state["status"] == "closed"
        except etcd.EtcdKeyNotFound:
            # No rendezvous state, so it cannot be closed.
            return False
    def set_closed(self):
        self._rdzv_impl.set_closed()
    def num_nodes_waiting(self):
        # The waiting counter is only tracked once a rendezvous round has
        # reached the 'final' state.
        try:
            _, state = self._rdzv_impl.get_rdzv_state()
            if state["status"] == "final":
                return state["num_workers_waiting"]
        except etcd.EtcdKeyNotFound:
            pass
        return 0
    def get_run_id(self) -> str:
        return self._rdzv_impl._run_id
    def shutdown(self) -> bool:
        # Best-effort close: any failure is logged and reported as False.
        try:
            self.set_closed()
            return True
        except BaseException as e:
            logger.warning("Shutdown failed. Error occurred: %s", str(e))
            return False
# TODO: we should probably handle a few additional errors,
# like EtcdLeaderElectionInProgress and EtcdWatcherCleared. These are
# only relevant for multi-node Etcd ensemble. A simple retry would work,
# but is verbose to add everywhere. Consider wrapping the client calls
# into auto-retry for these errors?
#
class EtcdRendezvous:
"""A rendezvous implementation that uses `etcd <https://etcd.io/>`__ as the backend store."""
    def __init__(
        self,
        client,
        prefix,
        run_id,
        num_min_workers,
        num_max_workers,
        timeout,
        last_call_timeout,
    ):
        """Set up the per-run_id etcd directories and leases for this rendezvous.

        Args:
            client: the etcd client used for all state reads/writes.
            prefix: path prefix (from etcd root) under which all nodes live.
            run_id: job instance id; all state is kept under a run_id sub-root.
            num_min_workers: minimum number of workers for a valid rendezvous.
            num_max_workers: maximum number of workers admitted.
            timeout: total rendezvous timeout, in seconds.
            last_call_timeout: extra wait (seconds) after min workers is reached.
        """
        self.client = client
        logger.info("Etcd machines: %s", self.client.machines)
        self._prefix = prefix
        self._run_id = run_id
        self._num_min_workers = num_min_workers
        self._num_max_workers = num_max_workers
        self._timeout = timeout
        self._last_call_timeout = last_call_timeout
        # For cleaning up TTL refresher threads (for ephemeral keys)
        self._lease_run_id_stop = None
        self._lease_this_rank_stop = None
        if not self._prefix.endswith("/"):
            self._prefix += "/"
        # Setup a permanent prefix dir, if didn't exist
        if self._prefix != "/":
            self.create_path_if_not_exists(self._prefix)
        # Lease a "sub-root" node specific to this job instance (run_id)
        self.create_path_if_not_exists(self.get_path(""), ttl=CONST_RUNID_SUBROOT_TTL)
        self._lease_run_id_stop = self.setup_lease_renewal(
            self.get_path(""), ttl=CONST_RUNID_SUBROOT_TTL
        )
        # Subdir for all rendezvous work
        self.create_path_if_not_exists(self.get_path("/rdzv"))
        # Create a rendezvous version counter, if doesn't exist
        try:
            self.client.write(
                key=self.get_path("/rdzv/version_counter"), value="0", prevExist=False
            )
        except etcd.EtcdAlreadyExist:
            pass
def __del__(self):
# TODO: look into using weakref here instead.
if self._lease_run_id_stop is not None:
self._lease_run_id_stop.set()
if self._lease_this_rank_stop is not None:
self._lease_this_rank_stop.set()
    def rendezvous_barrier(self):
        """
        Main entry point for next rendezvous.
        This method is blocking until rendezvous succeeds or a timeout occurs.
        Returns:
            ``(rdzv_version, rank, world_size)``
        Raises:
            RendezvousTimeoutError - timeout waiting for rendezvous
            RendezvousClosedError - rendezvous is or was closed while waiting
            RendezvousError - other persistent errors that
             render the rendezvous non-retryable
        """
        # The deadline is also consulted by wait_for_rendezvous_to_free to
        # bound its etcd watch.
        self._rendezvous_deadline = time.time() + self._timeout
        while True:
            if time.time() > self._rendezvous_deadline:
                raise RendezvousTimeoutError
            logger.info("Attempting to join next rendezvous")
            try:
                # Dis-own our lease in the previous rendezvous, if exists
                if self._lease_this_rank_stop is not None:
                    self._lease_this_rank_stop.set()
                return self.init_phase()
            except EtcdRendezvousRetryImmediately:
                # The type of failure suggests we can retry without delay
                pass
            except EtcdRendezvousRetryableFailure:
                # In case of retryable failure, wait a small delay
                # to avoid spamming etcd
                time.sleep(1)
            except RendezvousTimeoutError:
                logger.info("Rendezvous timeout occurred in EtcdRendezvousHandler")
                raise
            except RendezvousClosedError:
                logger.info(
                    "Rendezvous for run_id=%s was observed to be closed", self._run_id
                )
                raise
            except RendezvousError:
                raise
            except Exception as e:
                # In case of a general exception, wait a small delay
                # to avoid spamming etcd
                # FIXME: there are a few things that fall under this like
                # etcd.EtcdKeyNotFound, etc, which could be handled more explicitly.
                logger.info("Rendezvous attempt failed, will retry. Reason: %s", e)
                time.sleep(1)
    def init_phase(self):
        """
        Initially, the rendezvous state is expected to be one of:
        1. empty (non-existent) - in this case we try to create a new one.
        2. joinable - we try to join it.
        3. final - we announce ourselves as waiting, and go into monitoring mode
        Any other state is considered transitional, and will be retried after
        a short delay.
        Returns:
            ``(rdzv_version, rank, world_size)``
        Raises:
            RendezvousClosedError - current rendezvous was/is closed
            EtcdRendezvousRetryableFailure - observed some intermediate
             state, which is best handled by retrying later
        """
        try:
            active_version = self.try_create_rendezvous()
            state = json.loads(active_version.value)
            logger.info("New rendezvous state created: %s", state)
        except etcd.EtcdAlreadyExist:
            active_version, state = self.get_rdzv_state()
            # Note: it is possible for above query to fail (etcd.EtcdKeyNotFound),
            # but this is ok for us - just means we'll restart from beginning.
            logger.info("Observed existing rendezvous state: %s", state)
        if state["status"] == "closed":
            raise RendezvousClosedError
        if state["status"] == "joinable":
            return self.join_phase(state["version"])
        if state["status"] == "final":
            self.handle_existing_rendezvous(state["version"])
            raise EtcdRendezvousRetryImmediately
        # Transitional state (e.g. "setup" or "frozen"): wait for the next
        # state change, then retry from the top.
        self.try_wait_for_state_change(etcd_index=active_version.etcd_index + 1)
        raise EtcdRendezvousRetryableFailure
    def join_phase(self, expected_version):
        """
        We observed a rendezvous state in 'joinable' state, and attempt to join this
        particular version, and then wait for all other peers to join.
        Returns the result of :py:meth:`confirm_phase`:
        ``(rdzv_version, rank, world_size)``.
        """
        # Failure to join will propagate an exception, causing a re-entry.
        active_version, this_rank = self.join_rendezvous(expected_version)
        state = json.loads(active_version.value)
        logger.info(
            "Joined rendezvous version %s as rank %s. Full state: %s",
            state["version"],
            this_rank,
            state,
        )
        # If this worker was first to reach num_min_workers requirement,
        # and rendezvous is still joinable (therefore it is elastic),
        # then this worker will be responsible for waiting out the "last call"
        # timeout and closing (i.e. transitioning to 'frozen') the rendezvous
        # afterwards.
        # As a safety against a potential failure of this worker (during the
        # last call timeout), the rendezvous state is made ephemeral
        # when min_num_workers is reached.
        if this_rank == self._num_min_workers - 1 and state["status"] == "joinable":
            logger.info("Rank %s is responsible for join last call.", this_rank)
            last_call_deadline = time.time() + self._last_call_timeout
            self.handle_join_last_call(expected_version, last_call_deadline)
            logger.info("Rank %s finished join last call.", this_rank)
        # Wait for rendezvous state to be frozen, which means a fixed set of peers
        logger.info("Waiting for remaining peers.")
        active_version = self.wait_for_peers(expected_version)
        state = json.loads(active_version.value)
        assert (
            state["version"] == expected_version
        ), "Logic error: failed to observe version mismatch"
        return self.confirm_phase(expected_version, this_rank)
    def confirm_phase(self, expected_version, this_rank):
        """
        Once the rendezvous state transitions from 'joinable' to 'frozen',
        we have every participant confirm their membership and setup per-member
        keep-alive TTL keys, and then wait for all other participants to confirm,
        which would then successfully conclude this rendezvous.
        Returns:
            ``(rdzv_version, rank, world_size)``
        """
        logger.info("All peers arrived. Confirming membership.")
        self.confirm_membership(expected_version, this_rank)
        logger.info("Waiting for confirmations from all peers.")
        active_version = self.wait_for_final(expected_version)
        state = json.loads(active_version.value)
        logger.info(
            "Rendezvous version %s is complete. Final state: %s",
            state["version"],
            state,
        )
        # Rendezvous version number; our rank in it; world size
        return state["version"], this_rank, len(state["participants"])
    def handle_existing_rendezvous(self, expected_version):
        """
        Handle the case when there's an existing (state 'final') rendezvous already
        in place, and we have to announce ourselves waiting, and wait until
        the next rendezvous opportunity.
        """
        # If state is 'final' -> increment num_workers_waiting
        # Then, observe state changes:
        #   1. if it's no longer final -> bail out and re-try
        #   2. if keep alives are missing, destroy it and bail out.
        active_state = self.announce_self_waiting(expected_version)
        logger.info(
            "Added self to waiting list. Rendezvous full state: %s", active_state.value
        )
        self.wait_for_rendezvous_to_free(expected_version)
        logger.info(
            "Previously existing rendezvous state changed. Will re-try joining."
        )
    def try_create_rendezvous(self):
        """
        Create new rendezvous state or raise an exception that indicates an unexpected state (e.g. already exists).
        Returns:
            the etcd node holding the newly published 'joinable' active_version.
        Raises:
            RendezvousError - on unexpected state
        """
        # Initially active_version is ephemeral - this is to handle the
        # possibility that might fail to complete the setup transaction,
        # i.e. the transition "setup" -> "joinable".
        active_version = self.client.write(
            key=self.get_path("/rdzv/active_version"),
            value=json.dumps({"status": "setup"}),
            prevExist=False,
            ttl=CONST_ETCD_SETUP_TTL,
        )
        try:
            version_counter = self.client.get(self.get_path("/rdzv/version_counter"))
            version_counter.value = str(int(version_counter.value) + 1)
            self.client.update(version_counter)
        except (etcd.EtcdKeyNotFound, etcd.EtcdCompareFailed) as e:
            raise RendezvousError(
                "Unexpected state of EtcdRendezvousHandler, worker needs to die."
            ) from e
        # Any failure below results in declaring a retryable rendezvous failure.
        # The ephemeral /rdzv/active_version will expire and someone can then
        # re-try the setup process.
        # Create directory node for participant data
        self.client.write(
            key=self.get_path(f"/rdzv/v_{version_counter.value}"),
            value=None,
            dir=True,
            prevExist=False,
        )
        # Publish rendezvous version and signal it is ready-to-be-joined.
        # If rendezvous was set closed just before this, a retry will happen,
        # where the closed condition will be handled.
        return self.client.test_and_set(
            key=self.get_path("/rdzv/active_version"),
            value=json.dumps(
                {
                    "status": "joinable",
                    "version": version_counter.value,
                    "participants": [],
                }
            ),
            prev_value=active_version.value,
        )
    def join_rendezvous(self, expected_version):
        """Helper method for the join phase.
        Returns ``(active_version, this_rank)`` once the compare-and-swap
        join succeeds.
        """
        # Use compare-and-swap to add self to rendezvous state:
        while True:
            cas_delay()
            active_version, state = self.get_rdzv_state()
            if state["status"] != "joinable":
                raise EtcdRendezvousRetryableFailure(
                    "Rendezvous state became non-joinable before we could join. "
                    "Must join next one."
                )
            if state["version"] != expected_version:
                raise EtcdRendezvousRetryImmediately(
                    "Rendezvous version changed. Must try join the new one."
                )
            assert (
                len(state["participants"]) < self._num_max_workers
            ), "Logic error: joinable rendezvous should always have space left"
            this_rank = len(state["participants"])
            state["participants"].append(this_rank)
            # When reaching min workers, or changing state to frozen, we'll set
            # the active_version node to be ephemeral.
            set_ttl: Optional[int] = None
            if len(state["participants"]) == self._num_max_workers:
                state["status"] = "frozen"
                state["keep_alives"] = []
                set_ttl = CONST_ETCD_FROZEN_TTL
            elif len(state["participants"]) >= self._num_min_workers:
                set_ttl = CONST_ETCD_JOINABLE_EPHEMERAL_TTL
            try:
                # Compare-and-swap.
                active_version = self.client.test_and_set(
                    key=self.get_path("/rdzv/active_version"),
                    value=json.dumps(state),
                    prev_value=active_version.value,
                    ttl=set_ttl,
                )
                # We succeeded joining.
                return active_version, this_rank
            except etcd.EtcdCompareFailed:
                logger.info("Join rendezvous CAS unsuccessful, retrying")
def wait_for_peers(self, expected_version):
"""Helper method for the join phase."""
active_version, state = self.get_rdzv_state()
while True:
if state["status"] == "frozen" and state["version"] == expected_version:
# Success, all peers arrived.
return active_version
elif state["status"] == "joinable" and state["version"] == expected_version:
# Continue waiting for any interesting events.
active_version, state = self.try_wait_for_state_change(
etcd_index=active_version.etcd_index + 1
)
else:
# No valid transition possible at this point
raise EtcdRendezvousRetryableFailure(
"Rendezvous state transition no longer possible. Must re-enter."
)
    def confirm_membership(self, expected_version, this_rank):
        """Helper method for the confirm phase.
        Returns the updated active_version node after this rank's keep-alive
        key has been created and its lease renewal thread started.
        """
        # Compare-and-swap loop
        while True:
            cas_delay()
            active_version, state = self.get_rdzv_state()
            if state["status"] != "frozen":
                raise EtcdRendezvousRetryImmediately(
                    "Rendezvous no longer frozen, before we confirmed. "
                    "Must join next one"
                )
            if state["version"] != expected_version:
                raise EtcdRendezvousRetryImmediately(
                    "Rendezvous version changed. Must try join the new one."
                )
            this_lease_key = self.get_path(
                f"/rdzv/v_{expected_version}/rank_{this_rank}"
            )
            self.client.set(this_lease_key, value=None, ttl=CONST_WORKER_KEEPALIVE_TTL)
            state["keep_alives"].append(this_lease_key)
            if len(state["keep_alives"]) == len(state["participants"]):
                # Everyone confirmed (this rank is last to do so)
                state["status"] = "final"
                state["num_workers_waiting"] = 0
                finalize = True
            else:
                finalize = False
            try:
                # Compare-and-swap. If new state is still frozen, keep it ephemeral.
                active_version = self.client.test_and_set(
                    key=self.get_path("/rdzv/active_version"),
                    value=json.dumps(state),
                    prev_value=active_version.value,
                    ttl=None if finalize else CONST_ETCD_FROZEN_TTL,
                )
                self._lease_this_rank_stop = self.setup_lease_renewal(
                    this_lease_key, ttl=CONST_WORKER_KEEPALIVE_TTL
                )
                return active_version
            except etcd.EtcdCompareFailed:
                logger.info("Confirm membership CAS unsuccessful, retrying")
def wait_for_final(self, expected_version):
"""Helper method for the confirm phase."""
active_version, state = self.get_rdzv_state()
while True:
if state["status"] == "final" and state["version"] == expected_version:
# Success. This rendezvous is final, and we accept it.
return active_version
elif state["status"] == "frozen" and state["version"] == expected_version:
# Continue waiting for any interesting events.
active_version, state = self.try_wait_for_state_change(
etcd_index=active_version.etcd_index + 1
)
else:
# No valid transition possible at this point
raise EtcdRendezvousRetryableFailure(
"Rendezvous state transition no longer possible. Must re-enter."
)
    def announce_self_waiting(self, expected_version):
        """
        Announce this worker is waiting (via num_workers_waiting counter) to join next
        rendezvous, but only if state and version match.
        Returns the updated active_version node on success.
        """
        while True:
            cas_delay()
            active_version, state = self.get_rdzv_state()
            if state["status"] != "final" or state["version"] != expected_version:
                raise EtcdRendezvousRetryImmediately
            # Increment counter to signal an additional waiting worker.
            state["num_workers_waiting"] += 1
            try:
                active_version = self.client.test_and_set(
                    key=self.get_path("/rdzv/active_version"),
                    value=json.dumps(state),
                    prev_value=active_version.value,
                )
                return active_version
            except etcd.EtcdCompareFailed:
                logger.info("Announce self as waiting CAS unsuccessful, retrying")
    def wait_for_rendezvous_to_free(self, expected_version):
        """
        When there's an existing valid rendezvous in state 'final', we have to wait until the next opportunity to join.
        Such opportunity may come from:
        1. rendezvous state changed by someone else, in which case we unblock and retry.
        2. rendezvous becomes invalid because at least one member failed to renew their
        leased keep_alive node. We detect this, and destroy the rendezvous.

        Args:
            expected_version: version of the currently-final rendezvous.

        Returns:
            ``None``. Returning (for any reason) signals the caller to retry
            joining from scratch.

        Raises:
            RendezvousTimeoutError: if the overall rendezvous deadline passes.
            etcd.EtcdCompareFailed: (propagated from the compare-and-delete) if
                the rendezvous changed under us while we tried to destroy it.
        """
        active_version, state = self.get_rdzv_state()
        while True:
            if state["status"] != "final" or state["version"] != expected_version:
                return
            # Check if current rendezvous state is valid, in the sense that all
            # its members are alive (renewing their lease).
            # If not, try destroy this rendezvous, so a new one can be created.
            alive_members = self.client.get(
                self.get_path(f"/rdzv/v_{expected_version}")
            )
            keep_alive_keys = [ch.key for ch in alive_members.children]
            for key in state["keep_alives"]:
                if key not in keep_alive_keys:
                    # This participant didn't renew their lease. We'll declare this
                    # rendezvous version as dead (but only if it hadn't changed)
                    logger.info("Keep-alive key %s is not renewed.", key)
                    logger.info(
                        "Rendezvous version %s is incomplete. ", expected_version
                    )
                    logger.info("Attempting to destroy it.")
                    # Compare-and-delete operation. Throws if compare failed,
                    # which means rendezvous was already destroyed/re-created/closed,
                    # and we can try to re-enter the barrier.
                    self.client.delete(
                        key=self.get_path("/rdzv/active_version"),
                        prevValue=active_version.value,
                    )
                    logger.info(
                        "Destroyed rendezvous version %s successfully.",
                        expected_version,
                    )
                    # We can return (and retry) immediately
                    return
            # Existing rendezvous seems valid, no reason to destroy it.
            # We just have to wait until something changes and re-check.
            try:
                # Cap the watch so we never sleep (much) past the overall
                # rendezvous deadline; +1s of oversleep is acceptable.
                overall_timeout = (
                    max(self._rendezvous_deadline - time.time(), 0.0) + 1.0
                )
                self.client.watch(
                    key=self.get_path("/rdzv"),
                    index=active_version.etcd_index + 1,
                    recursive=True,
                    timeout=overall_timeout,
                )
            except (etcd.EtcdEventIndexCleared, etcd.EtcdWatchTimedOut):
                # Both are benign here: fall through and re-read the state.
                pass
            if time.time() > self._rendezvous_deadline:
                raise RendezvousTimeoutError
            active_version, state = self.get_rdzv_state()
    def handle_join_last_call(self, expected_version, deadline):
        """
        After we reach min number of workers, one particular worker takes on the
        responsibility of waiting an additional timeout before closing the join window.
        If the worker responsible for this fails, the rendezvous will be destroyed due
        to expiring TTL, and the other participants will re-rendezvous.
        Here we expect to see state <joinable, expected_version>
        Exit gracefully if either:
        1. state becomes <frozen, expected_version>
        2. timeout happens (reaching deadline), in which case
        we try the transition to <frozen, expected_version>
        Exit with exception otherwise.

        Args:
            expected_version: the rendezvous version we are the last-call
                responsible worker for.
            deadline: wall-clock time (``time.time()`` units) when the join
                window closes.

        Raises:
            EtcdRendezvousRetryableFailure: if the state/version changed in an
                incompatible way; the caller must re-enter the barrier.
        """
        active_version, state = self.get_rdzv_state()
        while True:
            if state["status"] == "frozen" and state["version"] == expected_version:
                # Worker set became frozen before last-call timeout. This is possible
                # when num_max_workers is reached before the timeout.
                return
            if state["status"] != "joinable" or state["version"] != expected_version:
                raise EtcdRendezvousRetryableFailure(
                    "Rendezvous state transition no longer possible. Must re-enter."
                )
            # If timeout occurred, attempt a state transition (joinable -> frozen)
            if time.time() >= deadline:
                state["status"] = "frozen"
                state["keep_alives"] = []
                try:
                    active_version = self.client.test_and_set(
                        key=self.get_path("/rdzv/active_version"),
                        value=json.dumps(state),
                        prev_value=active_version.value,
                        ttl=CONST_ETCD_FROZEN_TTL,
                    )
                    # We successfully made this rendezvous frozen.
                    return
                except etcd.EtcdCompareFailed:
                    # Someone else raced our CAS; back off and re-read state.
                    logger.info(
                        "Join last-call transition CAS unsuccessful. Will retry"
                    )
                    cas_delay()
                    active_version, state = self.get_rdzv_state()
                    continue
            # Timeout did not occur, so we must refresh TTL, and wait for
            # further changes. Note: we only want TTL to be refreshed if
            # state is still joinable, hence we use CAS for that here,
            # even though we don't change any of the data.
            try:
                active_version = self.client.test_and_set(
                    key=self.get_path("/rdzv/active_version"),
                    value=active_version.value,
                    prev_value=active_version.value,
                    ttl=CONST_ETCD_JOINABLE_EPHEMERAL_TTL,
                )
                # Minimize "oversleeping":
                timeout = min(
                    CONST_ETCD_JOINABLE_EPHEMERAL_TTL / 2,
                    deadline - time.time() + 1.0,  # Oversleeping by 1s is ok.
                )
                active_version, state = self.try_wait_for_state_change(
                    etcd_index=active_version.etcd_index + 1, timeout=timeout
                )
            except etcd.EtcdCompareFailed:
                logger.info("Join last-call TTL refresh CAS unsuccessful, will retry")
                cas_delay()
                active_version, state = self.get_rdzv_state()
    def set_closed(self):
        """
        Mark rendezvous 'closed' for current run_id, which is used to signal other
        participants to not attempt to perform (re-)rendezvous. This is useful
        when one of the workers decides the job is complete.

        Idempotent: returns immediately if another worker already closed the
        rendezvous. Otherwise retries the CAS write until it succeeds.
        """
        while True:
            active_version, state = self.get_rdzv_state()
            if state["status"] == "closed":
                # Already closed by someone else.
                return
            state["status"] = "closed"
            try:
                self.client.test_and_set(
                    key=self.get_path("/rdzv/active_version"),
                    value=json.dumps(state),
                    prev_value=active_version.value,
                )
                return
            except etcd.EtcdCompareFailed:
                # Lost the race; jitter and re-read the state before retrying.
                logger.info("Set closed CAS unsuccessful, retrying")
                cas_delay()
def get_rdzv_state(self):
active_version = self.client.get(key=self.get_path("/rdzv/active_version"))
return active_version, json.loads(active_version.value)
    def try_wait_for_state_change(self, etcd_index, timeout=None):
        """Watch the ``active_version`` node for a change, then re-fetch it.

        Args:
            etcd_index: etcd index to start watching from (typically the last
                observed index + 1).
            timeout: optional max seconds to wait; always clamped so we never
                sleep more than ~1s past the overall rendezvous deadline.

        Returns:
            A fresh ``(active_version, state)`` pair from :meth:`get_rdzv_state`.

        Raises:
            RendezvousTimeoutError: if the overall rendezvous deadline passed.
        """
        # Don't sleep past the overall deadline (at least more than by 1s)
        overall_timeout = max(self._rendezvous_deadline - time.time(), 0.0) + 1.0
        timeout = overall_timeout if timeout is None else min(timeout, overall_timeout)
        try:
            self.client.watch(
                self.get_path("/rdzv/active_version"), index=etcd_index, timeout=timeout
            )
        except (etcd.EtcdEventIndexCleared, etcd.EtcdWatchTimedOut):
            # Treat a stale index or a watch timeout the same: just re-read.
            pass
        if time.time() > self._rendezvous_deadline:
            raise RendezvousTimeoutError
        # Unfortunately, we have to do another fetch in order to get last etcd_index.
        return self.get_rdzv_state()
def get_path(self, path):
if not path.startswith("/"):
path = "/" + path
return f"{self._prefix}run_{self._run_id}{path}"
    def create_path_if_not_exists(self, full_path, ttl=None):
        """Create an etcd directory node at ``full_path`` if it doesn't exist.

        Args:
            full_path: absolute etcd path of the directory node.
            ttl: optional TTL (seconds) for the created node.

        Silently succeeds if the node already exists (prevExist=False write
        raises EtcdAlreadyExist, which we swallow).
        """
        try:
            self.client.write(
                key=full_path, value=None, dir=True, prevExist=False, ttl=ttl
            )
        except etcd.EtcdAlreadyExist:
            pass
    def setup_lease_renewal(self, full_path, ttl):
        """Start a daemon thread that periodically refreshes an ephemeral key's TTL.

        Args:
            full_path: absolute etcd path of the key to keep alive.
            ttl: TTL (seconds) to set on each refresh; the thread refreshes
                roughly twice per TTL period.

        Returns:
            A ``threading.Event``; setting it stops the renewal thread.
        """
        # NOTE: For ephemeral key TTL renewal (~lease) to work correctly,
        # make sure you don't call any long-blocking methods that do not
        # release the Python's GIL! An example of this is calling a pybind11
        # extension function that is blocking / long-running, but is not
        # doing a scoped release of the GIL.
        def lease_worker(client, path, ttl, stop_event):
            while True:
                try:
                    client.refresh(path, ttl=ttl)
                except etcd.EtcdKeyNotFound:
                    # Key disappeared (rendezvous destroyed/expired) — stop.
                    break
                except ConnectionRefusedError:
                    # This error usually occurs during test when the server already got terminated but the
                    # python garbage collector have not yet invoked the __del__ method.
                    break
                # Wake up at ttl/2 to refresh well before expiry; a set stop
                # event ends the loop.
                if stop_event.wait(timeout=ttl / 2):
                    break
        lease_stop_event = threading.Event()
        lease_thread = threading.Thread(
            target=lease_worker, args=(self.client, full_path, ttl, lease_stop_event)
        )
        # Daemon thread: must not keep the process alive on shutdown.
        lease_thread.daemon = True
        lease_thread.start()
        return lease_stop_event
def store_extra_data(self, rdzv_version, key, value):
node = self.get_path(f"/rdzv/v_{rdzv_version}/extra_data")
try:
# If first time we are storing anything:
extra_data = self.client.write(
key=node, value=json.dumps({key: value}), prevExist=False
)
return
except etcd.EtcdAlreadyExist:
pass
# CAS loop, to make sure we don't lose concurrent stores.
while True:
# We never delete extra_data. Failure here should be fatal, no special handling.
extra_data = self.client.get(node)
new_extra_data_value = json.loads(extra_data.value)
new_extra_data_value[key] = value
try:
extra_data = self.client.test_and_set(
key=node,
value=json.dumps(new_extra_data_value),
prev_value=extra_data.value,
)
return
except etcd.EtcdCompareFailed:
logger.info("Store extra_data CAS unsuccessful, retrying")
time.sleep(0.1)
    def load_extra_data(self, rdzv_version, key, timeout=None):
        """Block until ``key`` appears in the extra_data node for ``rdzv_version`` and return its value.

        Args:
            rdzv_version: rendezvous version whose extra_data to read.
            key: entry name to wait for.
            timeout: currently unimplemented (see TODO below); waits indefinitely.

        Returns:
            The JSON-decoded value stored under ``key``.
        """
        # 'extra_data' node itself, and the directory it is located in:
        node = self.get_path(f"/rdzv/v_{rdzv_version}/extra_data")
        node_dir = self.get_path(f"/rdzv/v_{rdzv_version}")
        # TODO: implement timeout
        # https://github.com/pytorch/elastic/issues/12
        while True:
            # Combined wait for the node itself, and the key inside it.
            root = self.client.get(node_dir)
            # Find the extra_data node, if it exists
            extra_data = [n for n in root.children if n.key == node]
            assert len(extra_data) <= 1
            # Node for extra_data exists, check the desired key inside it.
            if len(extra_data) == 1:
                extra_data_dict = json.loads(extra_data[0].value)
                if key in extra_data_dict:
                    return extra_data_dict[key]
            # The 'extra_data' node doesn't exist, or the key isn't published yet.
            # Wait for interesting events on the extra_data node and retry.
            try:
                self.client.watch(node, index=root.etcd_index + 1)
            except (etcd.EtcdEventIndexCleared, etcd.EtcdWatchTimedOut):
                pass
def setup_kv_store(self, rdzv_version):
store_path = self.get_path(f"/rdzv/v_{rdzv_version}/kv")
self.create_path_if_not_exists(store_path)
return EtcdStore(etcd_client=self.client, etcd_store_prefix=store_path)
def _create_etcd_client(params: RendezvousParameters) -> etcd.Client:
    """Create a new ``etcd.Client`` from the specified ``RendezvousParameters``."""
    hostname, port = parse_rendezvous_endpoint(params.endpoint, 2379)
    # The communication protocol: default to plain HTTP, otherwise validate.
    protocol = params.config.get("protocol")
    if protocol is None:
        protocol = "http"
    elif protocol != "http" and protocol != "https":
        raise ValueError("The etcd protocol must be HTTP or HTTPS.")
    # The SSL client certificate, optionally paired with its private key.
    ssl_cert = params.config.get("cert")
    if ssl_cert is not None:
        cert_key = params.config.get("key")
        if cert_key is not None:
            # The etcd client expects the certificate key as the second element
            # of the `cert` tuple.
            ssl_cert = (ssl_cert, cert_key)
    # The root certificate
    ca_cert = params.config.get("cacert")
    client_kwargs = {
        "protocol": protocol,
        "cert": ssl_cert,
        "ca_cert": ca_cert,
        "allow_reconnect": True,
    }
    return etcd.Client(hostname, port, **client_kwargs)
# Handler for core.distributed "static" registration
def create_rdzv_handler(params: RendezvousParameters) -> RendezvousHandler:
    """
    Create an etcd-backed :py:class:`RendezvousHandler` from ``params``.

    Usage:
    ::
    rdzv_params = RendezvousParameters(
                        backend="etcd",
                        endpoint="192.168.0.42:2379",
                        run_id="123",
                        min_nodes=4,
                        max_nodes=8,
                        timeout=300,
                        last_call_timeout=30,
                        etcd_prefix="custom_prefix",
                        protocol="https",
                        cacert="/etc/kubernetes/certs/ca.crt",
                        cert="/etc/kubernetes/certs/client.crt",
                        key="/etc/kubernetes/certs/client.key")
    # -- or --
    rdzv_params = RendezvousParameters(
                        backend="etcd",
                        endpoint="192.168.0.42:2379",
                        run_id="123",
                        min_nodes=4,
                        max_nodes=8)
    etcd_rdzv_handler = create_rdzv_handler(rdzv_params)
    Where:
    run_id - unique id for this training job instance,
    min_nodes - min number of workers expected to join the rendezvous,
    max_nodes - max number of workers allowed to join the rendezvous,
                defaults to min_nodes if not specified.
    timeout - total timeout within which next_rendezvous is expected to
              succeed; a RendezvousTimeoutError is raised otherwise;
              Defaults is 600 (10 minutes).
    last_call_timeout - additional wait amount ("last call") after
                        min number of workers has been reached.
                        Defaults to 30 seconds.
    etcd_prefix - path prefix (from etcd root), inside which all
                  etcd nodes will be created.
                  Default is "/torchelastic/p2p".
    protocol - http (default) or https to access etcd.
    cacert - CA cert to access etcd, only makes sense with https.
    cert - client cert to access etcd, only makes sense with https.
    key - client key to access etcd, only makes sense with https.
    """
    client = _create_etcd_client(params)
    etcd_prefix = params.get("etcd_prefix", "/torchelastic/p2p")
    rdzv = EtcdRendezvous(
        client=client,
        prefix=etcd_prefix,
        run_id=params.run_id,
        num_min_workers=params.min_nodes,
        num_max_workers=params.max_nodes,
        timeout=params.get_as_int("timeout", _DEFAULT_TIMEOUT),
        last_call_timeout=params.get_as_int(
            "last_call_timeout", _DEFAULT_LAST_CALL_TIMEOUT
        ),
    )
    return EtcdRendezvousHandler(
        rdzv_impl=rdzv,
        local_addr=params.local_addr,
    )

+ 0
- 217
mindnlp/core/distributed/elastic/rendezvous/etcd_rendezvous_backend.py View File

@@ -1,217 +0,0 @@
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import binascii
from base64 import b64decode, b64encode
from typing import cast, Optional, Tuple
import urllib3.exceptions # type: ignore[import]
from etcd import ( # type: ignore[import]
Client as EtcdClient,
EtcdAlreadyExist,
EtcdCompareFailed,
EtcdException,
EtcdKeyNotFound,
EtcdResult,
)
from core.distributed import Store
from .api import RendezvousConnectionError, RendezvousParameters, RendezvousStateError
from .dynamic_rendezvous import RendezvousBackend, Token
from .etcd_store import EtcdStore
from .utils import parse_rendezvous_endpoint
class EtcdRendezvousBackend(RendezvousBackend):
    """Represents an etcd-based rendezvous backend.
    Args:
        client:
            The ``etcd.Client`` instance to use to communicate with etcd.
        run_id:
            The run id of the rendezvous.
        key_prefix:
            The path under which to store the rendezvous state in etcd.
        ttl:
            The TTL of the rendezvous state. If not specified, defaults to two hours.
    """

    _DEFAULT_TTL = 7200  # 2 hours

    _client: EtcdClient
    _key: str
    _ttl: int

    def __init__(
        self,
        client: EtcdClient,
        run_id: str,
        key_prefix: Optional[str] = None,
        ttl: Optional[int] = None,
    ) -> None:
        if not run_id:
            raise ValueError("The run id must be a non-empty string.")
        self._client = client
        if key_prefix:
            self._key = key_prefix + "/" + run_id
        else:
            self._key = run_id
        # Fall back to the default TTL for missing or non-positive values.
        if ttl and ttl > 0:
            self._ttl = ttl
        else:
            self._ttl = self._DEFAULT_TTL

    @property
    def name(self) -> str:
        """See base class."""
        return "etcd-v2"

    def get_state(self) -> Optional[Tuple[bytes, Token]]:
        """See base class."""
        try:
            result = self._client.read(self._key)
        except EtcdKeyNotFound:
            # No state has been set yet for this run.
            return None
        except (EtcdException, urllib3.exceptions.TimeoutError) as exc:
            raise RendezvousConnectionError(
                "The connection to etcd has failed. See inner exception for details."
            ) from exc
        return self._decode_state(result)

    def set_state(
        self, state: bytes, token: Optional[Token] = None
    ) -> Optional[Tuple[bytes, Token, bool]]:
        """See base class."""
        base64_state = b64encode(state).decode()
        kwargs = {}

        def get_state():
            # Re-read the authoritative state after a failed write, so the
            # caller can observe what the backend currently holds.
            result = self.get_state()
            if result is not None:
                return (*result, False)
            return None

        if token:
            try:
                token = int(token)
            except ValueError:
                # A malformed token can never match; report the current state.
                return get_state()
        if token:
            # Conditional write: only succeed if the state hasn't changed.
            kwargs["prevIndex"] = token
        else:
            # First write: only succeed if no state exists yet.
            kwargs["prevExist"] = False
        try:
            result = self._client.write(self._key, base64_state, self._ttl, **kwargs)
        except (EtcdAlreadyExist, EtcdCompareFailed):
            # Lost the race against another writer.
            result = None
        except (EtcdException, urllib3.exceptions.TimeoutError) as exc:
            raise RendezvousConnectionError(
                "The connection to etcd has failed. See inner exception for details."
            ) from exc
        if result is None:
            return get_state()
        return (*self._decode_state(result), True)

    def _decode_state(self, result: EtcdResult) -> Tuple[bytes, Token]:
        """Decode a raw etcd result into ``(state_bytes, token)``."""
        base64_state = result.value.encode()
        try:
            state = b64decode(base64_state)
        except binascii.Error as exc:
            raise RendezvousStateError(
                "The state object is corrupt. See inner exception for details."
            ) from exc
        # The node's modifiedIndex serves as the fencing token.
        return state, result.modifiedIndex
def _create_etcd_client(params: RendezvousParameters) -> EtcdClient:
    """Build an ``EtcdClient`` from rendezvous parameters, validating the
    read timeout and protocol, and wiring up optional TLS material."""
    host, port = parse_rendezvous_endpoint(params.endpoint, default_port=2379)
    # The timeout
    read_timeout = cast(int, params.get_as_int("read_timeout", 60))
    if read_timeout <= 0:
        raise ValueError("The read timeout must be a positive integer.")
    # The communication protocol
    protocol = params.get("protocol", "http").strip().lower()
    if protocol not in ("http", "https"):
        raise ValueError("The protocol must be HTTP or HTTPS.")
    # The SSL client certificate
    client_cert = params.get("ssl_cert")
    if client_cert:
        client_cert_key = params.get("ssl_cert_key")
        if client_cert_key:
            # The etcd client expects the certificate key as the second element
            # of the `cert` tuple.
            client_cert = (client_cert, client_cert_key)
    # The root certificate
    root_cert = params.get("ca_cert")
    try:
        return EtcdClient(
            host,
            port,
            read_timeout=read_timeout,
            protocol=protocol,
            cert=client_cert,
            ca_cert=root_cert,
            allow_reconnect=True,
        )
    except (EtcdException, urllib3.exceptions.TimeoutError) as exc:
        raise RendezvousConnectionError(
            "The connection to etcd has failed. See inner exception for details."
        ) from exc
def create_backend(params: RendezvousParameters) -> Tuple[EtcdRendezvousBackend, Store]:
    """Create a new :py:class:`EtcdRendezvousBackend` from the specified parameters.
    +--------------+-----------------------------------------------------------+
    | Parameter    | Description                                               |
    +==============+===========================================================+
    | read_timeout | The read timeout, in seconds, for etcd operations.        |
    |              | Defaults to 60 seconds.                                   |
    +--------------+-----------------------------------------------------------+
    | protocol     | The protocol to use to communicate with etcd. Valid       |
    |              | values are "http" and "https". Defaults to "http".        |
    +--------------+-----------------------------------------------------------+
    | ssl_cert     | The path to the SSL client certificate to use along with  |
    |              | HTTPS. Defaults to ``None``.                              |
    +--------------+-----------------------------------------------------------+
    | ssl_cert_key | The path to the private key of the SSL client certificate |
    |              | to use along with HTTPS. Defaults to ``None``.            |
    +--------------+-----------------------------------------------------------+
    | ca_cert      | The path to the root SSL authority certificate. Defaults  |
    |              | to ``None``.                                              |
    +--------------+-----------------------------------------------------------+

    Returns:
        A ``(backend, store)`` pair sharing one etcd client: the rendezvous
        backend and an ``EtcdStore`` for c10d-style key/value use.
    """
    client = _create_etcd_client(params)
    backend = EtcdRendezvousBackend(
        client, params.run_id, key_prefix="/torch/elastic/rendezvous"
    )
    store = EtcdStore(client, "/torch/elastic/store")
    return backend, store

+ 0
- 248
mindnlp/core/distributed/elastic/rendezvous/etcd_server.py View File

@@ -1,248 +0,0 @@
#!/usr/bin/env python3
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import atexit
import logging
import os
import shlex
import shutil
import socket
import subprocess
import tempfile
import time
from typing import Optional, TextIO, Union
try:
import etcd # type: ignore[import]
except ModuleNotFoundError:
pass
logger = logging.getLogger(__name__)
def find_free_port():
    """
    Find a free port and binds a temporary socket to it so that the port can be "reserved" until used.
    .. note:: the returned socket must be closed before using the port,
              otherwise a ``address already in use`` error will happen.
              The socket should be held and closed as close to the
              consumer of the port as possible since otherwise, there
              is a greater chance of race-condition where a different
              process may see the port as being free and take it.

    Returns: a socket bound to the reserved free port

    Raises:
        RuntimeError: if no address family yields a bindable socket.

    Usage::
    sock = find_free_port()
    port = sock.getsockname()[1]
    sock.close()
    use_port(port)
    """
    addrs = socket.getaddrinfo(
        host="localhost", port=None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM
    )
    for addr in addrs:
        # NOTE: 'sock_type' avoids shadowing the builtin `type`.
        family, sock_type, proto, _, _ = addr
        s = None
        try:
            s = socket.socket(family, sock_type, proto)
            s.bind(("localhost", 0))
            s.listen(0)
            return s
        except OSError as e:
            # `s` may be unbound if socket() itself raised; guard the close.
            if s is not None:
                s.close()
            logger.warning("Socket creation attempt failed: %s", e)
    raise RuntimeError("Failed to create a socket")
def stop_etcd(subprocess, data_dir: Optional[str] = None):
    """Terminate a running etcd server process and optionally delete its data dir.

    Args:
        subprocess: the ``subprocess.Popen`` handle of the etcd server; may be
            ``None``. NOTE(review): this parameter shadows the stdlib
            ``subprocess`` module; kept as-is since the name is part of the
            public signature.
        data_dir: if given, the etcd data directory to remove after shutdown.
    """
    if subprocess and subprocess.poll() is None:
        logger.info("stopping etcd server")
        subprocess.terminate()
        subprocess.wait()
    if data_dir:
        logger.info("deleting etcd data dir: %s", data_dir)
        shutil.rmtree(data_dir, ignore_errors=True)
class EtcdServer:
    """
    .. note:: tested on etcd server v3.4.3.
    Starts and stops a local standalone etcd server on a random free
    port. Useful for single node, multi-worker launches or testing,
    where a sidecar etcd server is more convenient than having to
    separately setup an etcd server.
    This class registers a termination handler to shutdown the etcd
    subprocess on exit. This termination handler is NOT a substitute for
    calling the ``stop()`` method.
    The following fallback mechanism is used to find the etcd binary:
    1. Uses env var TORCHELASTIC_ETCD_BINARY_PATH
    2. Uses ``<this file root>/bin/etcd`` if one exists
    3. Uses ``etcd`` from ``PATH``
    Usage
    ::
    server = EtcdServer("/usr/bin/etcd", 2379, "/tmp/default.etcd")
    server.start()
    client = server.get_client()
    # use client
    server.stop()
    Args:
    etcd_binary_path: path of etcd server binary (see above for fallback path)
    """
    def __init__(self, data_dir: Optional[str] = None):
        # Port is assigned lazily in _start(); -1 means "not started".
        self._port = -1
        self._host = "localhost"
        root = os.path.dirname(__file__)
        default_etcd_bin = os.path.join(root, "bin/etcd")
        self._etcd_binary_path = os.environ.get(
            "TORCHELASTIC_ETCD_BINARY_PATH", default_etcd_bin
        )
        # Fallback 3: rely on `etcd` being resolvable from PATH.
        if not os.path.isfile(self._etcd_binary_path):
            self._etcd_binary_path = "etcd"
        self._base_data_dir = (
            data_dir if data_dir else tempfile.mkdtemp(prefix="torchelastic_etcd_data")
        )
        self._etcd_cmd = None
        self._etcd_proc: Optional[subprocess.Popen] = None
    def _get_etcd_server_process(self) -> subprocess.Popen:
        # Guard against use-before-start.
        if not self._etcd_proc:
            raise RuntimeError(
                "No etcd server process started. Call etcd_server.start() first"
            )
        else:
            return self._etcd_proc
    def get_port(self) -> int:
        """Return the port the server is running on."""
        return self._port
    def get_host(self) -> str:
        """Return the host the server is running on."""
        return self._host
    def get_endpoint(self) -> str:
        """Return the etcd server endpoint (host:port)."""
        return f"{self._host}:{self._port}"
    def start(
        self,
        timeout: int = 60,
        num_retries: int = 3,
        stderr: Union[int, TextIO, None] = None,
    ) -> None:
        """
        Start the server, and waits for it to be ready. When this function returns the sever is ready to take requests.
        Args:
        timeout: time (in seconds) to wait for the server to be ready
        before giving up.
        num_retries: number of retries to start the server. Each retry
        will wait for max ``timeout`` before considering it as failed.
        stderr: the standard error file handle. Valid values are
        `subprocess.PIPE`, `subprocess.DEVNULL`, an existing file
        descriptor (a positive integer), an existing file object, and
        `None`.
        Raises:
        TimeoutError: if the server is not ready within the specified timeout
        """
        curr_retries = 0
        while True:
            try:
                # Each retry gets its own data dir so a corrupt one can't
                # poison the next attempt.
                data_dir = os.path.join(self._base_data_dir, str(curr_retries))
                os.makedirs(data_dir, exist_ok=True)
                return self._start(data_dir, timeout, stderr)
            except Exception as e:
                curr_retries += 1
                stop_etcd(self._etcd_proc)
                logger.warning(
                    "Failed to start etcd server, got error: %s, retrying", str(e)
                )
                if curr_retries >= num_retries:
                    shutil.rmtree(self._base_data_dir, ignore_errors=True)
                    raise
        # NOTE(review): unreachable — every path above returns or raises, so
        # this atexit handler is never registered; confirm against upstream.
        atexit.register(stop_etcd, self._etcd_proc, self._base_data_dir)
    def _start(
        self, data_dir: str, timeout: int = 60, stderr: Union[int, TextIO, None] = None
    ) -> None:
        # Reserve two free ports (client + peer); see find_free_port for the
        # inherent race window between close() and etcd binding them.
        sock = find_free_port()
        sock_peer = find_free_port()
        self._port = sock.getsockname()[1]
        peer_port = sock_peer.getsockname()[1]
        etcd_cmd = shlex.split(
            " ".join(
                [
                    self._etcd_binary_path,
                    "--enable-v2",
                    "--data-dir",
                    data_dir,
                    "--listen-client-urls",
                    f"http://{self._host}:{self._port}",
                    "--advertise-client-urls",
                    f"http://{self._host}:{self._port}",
                    "--listen-peer-urls",
                    f"http://{self._host}:{peer_port}",
                ]
            )
        )
        logger.info("Starting etcd server: [%s]", etcd_cmd)
        # Release the reserved ports just before spawning etcd.
        sock.close()
        sock_peer.close()
        self._etcd_proc = subprocess.Popen(etcd_cmd, close_fds=True, stderr=stderr)
        self._wait_for_ready(timeout)
    def get_client(self):
        """Return an etcd client object that can be used to make requests to this server."""
        return etcd.Client(
            host=self._host, port=self._port, version_prefix="/v2", read_timeout=10
        )
    def _wait_for_ready(self, timeout: int = 60) -> None:
        # Short read_timeout so each probe fails fast while polling.
        client = etcd.Client(
            host=f"{self._host}", port=self._port, version_prefix="/v2", read_timeout=5
        )
        max_time = time.time() + timeout
        while time.time() < max_time:
            if self._get_etcd_server_process().poll() is not None:
                # etcd server process finished
                exitcode = self._get_etcd_server_process().returncode
                raise RuntimeError(
                    f"Etcd server process exited with the code: {exitcode}"
                )
            try:
                # Accessing `version` performs a request; success == ready.
                logger.info("etcd server ready. version: %s", client.version)
                return
            except Exception:
                time.sleep(1)
        raise TimeoutError("Timed out waiting for etcd server to be ready!")
    def stop(self) -> None:
        """Stop the server and cleans up auto generated resources (e.g. data dir)."""
        logger.info("EtcdServer stop method called")
        stop_etcd(self._etcd_proc, self._base_data_dir)

+ 0
- 212
mindnlp/core/distributed/elastic/rendezvous/etcd_store.py View File

@@ -1,212 +0,0 @@
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import datetime
import random
import time
from base64 import b64decode, b64encode
from typing import Optional
import etcd # type: ignore[import]
# pyre-ignore[21]: Could not find name `Store` in `core.distributed`.
from core.distributed import Store
# Delay (sleep) for a small random amount to reduce CAS failures.
# This does not affect correctness, but will reduce requests to etcd server.
def cas_delay():
    """Sleep a uniformly-random 0–100 ms to spread out CAS retries across workers."""
    jitter = random.uniform(0, 0.1)
    time.sleep(jitter)
# pyre-fixme[11]: Annotation `Store` is not defined as a type.
class EtcdStore(Store):
    """
    Implement a c10 Store interface by piggybacking on the rendezvous etcd instance.
    This is the store object returned by ``EtcdRendezvous``.
    """
    def __init__(
        self,
        etcd_client,
        etcd_store_prefix,
        # Default timeout same as in c10d/Store.hpp
        timeout: Optional[datetime.timedelta] = None,
    ):
        super().__init__()  # required for pybind trampoline.
        self.client = etcd_client
        self.prefix = etcd_store_prefix
        if timeout is not None:
            self.set_timeout(timeout)
        # Normalize: all key operations assume a trailing slash on the prefix.
        if not self.prefix.endswith("/"):
            self.prefix += "/"
    def set(self, key, value):
        """
        Write a key/value pair into ``EtcdStore``.
        Both key and value may be either Python ``str`` or ``bytes``.
        """
        self.client.set(key=self.prefix + self._encode(key), value=self._encode(value))
    def get(self, key) -> bytes:
        """
        Get a value by key, possibly doing a blocking wait.
        If key is not immediately present, will do a blocking wait
        for at most ``timeout`` duration or until the key is published.
        Returns:
            value ``(bytes)``
        Raises:
            LookupError - If key still not published after timeout
        """
        b64_key = self.prefix + self._encode(key)
        kvs = self._try_wait_get([b64_key])
        if kvs is None:
            raise LookupError(f"Key {key} not found in EtcdStore")
        return self._decode(kvs[b64_key])
    def add(self, key, num: int) -> int:
        """
        Atomically increment a value by an integer amount.
        The integer is represented as a string using base 10. If key is not present,
        a default value of ``0`` will be assumed.
        Returns:
            the new (incremented) value
        """
        b64_key = self._encode(key)
        # c10d Store assumes value is an integer represented as a decimal string
        try:
            # Assume default value "0", if this key didn't yet:
            node = self.client.write(
                key=self.prefix + b64_key,
                value=self._encode(str(num)),  # i.e. 0 + num
                prevExist=False,
            )
            return int(self._decode(node.value))
        except etcd.EtcdAlreadyExist:
            pass
        while True:
            # Note: c10d Store does not have a method to delete keys, so we
            # can be sure it's still there.
            node = self.client.get(key=self.prefix + b64_key)
            new_value = self._encode(str(int(self._decode(node.value)) + num))
            try:
                node = self.client.test_and_set(
                    key=node.key, value=new_value, prev_value=node.value
                )
                return int(self._decode(node.value))
            except etcd.EtcdCompareFailed:
                # Another worker incremented concurrently; jitter and retry.
                cas_delay()
    def wait(self, keys, override_timeout: Optional[datetime.timedelta] = None):
        """
        Wait until all of the keys are published, or until timeout.
        Raises:
            LookupError - if timeout occurs
        """
        b64_keys = [self.prefix + self._encode(key) for key in keys]
        kvs = self._try_wait_get(b64_keys, override_timeout)
        if kvs is None:
            raise LookupError("Timeout while waiting for keys in EtcdStore")
        # No return value on success
    def check(self, keys) -> bool:
        """Check if all of the keys are immediately present (without waiting)."""
        b64_keys = [self.prefix + self._encode(key) for key in keys]
        kvs = self._try_wait_get(
            b64_keys,
            override_timeout=datetime.timedelta(microseconds=1),  # as if no wait
        )
        return kvs is not None
    #
    # Encode key/value data in base64, so we can store arbitrary binary data
    # in EtcdStore. Input can be `str` or `bytes`.
    # In case of `str`, utf-8 encoding is assumed.
    #
    def _encode(self, value) -> str:
        # isinstance (rather than type()==) also accepts subclasses.
        if isinstance(value, bytes):
            return b64encode(value).decode()
        elif isinstance(value, str):
            return b64encode(value.encode()).decode()
        raise ValueError("Value must be of type str or bytes")
    #
    # Decode a base64 string (of type `str` or `bytes`).
    # Return type is `bytes`, which is more convenient with the Store interface.
    #
    def _decode(self, value) -> bytes:
        if isinstance(value, bytes):
            return b64decode(value)
        elif isinstance(value, str):
            return b64decode(value.encode())
        raise ValueError("Value must be of type str or bytes")
    #
    # Get all of the (base64-encoded) etcd keys at once, or wait until all the keys
    # are published or timeout occurs.
    # This is a helper method for the public interface methods.
    #
    # On success, a dictionary of {etcd key -> etcd value} is returned.
    # On timeout, None is returned.
    #
    def _try_wait_get(self, b64_keys, override_timeout=None):
        timeout = self.timeout if override_timeout is None else override_timeout  # type: ignore[attr-defined]
        deadline = time.time() + timeout.total_seconds()
        while True:
            # Read whole directory (of keys), filter only the ones waited for
            all_nodes = None
            try:
                all_nodes = self.client.get(key=self.prefix)
                req_nodes = {
                    node.key: node.value
                    for node in all_nodes.children
                    if node.key in b64_keys
                }
                if len(req_nodes) == len(b64_keys):
                    # All keys are available
                    return req_nodes
            except etcd.EtcdKeyNotFound:
                # Directory doesn't exist yet; fall through to the watch.
                pass
            watch_timeout = deadline - time.time()
            if watch_timeout <= 0:
                return None
            try:
                # Index 0 (directory missing) means "watch from the beginning".
                index = all_nodes.etcd_index + 1 if all_nodes else 0
                self.client.watch(
                    key=self.prefix,
                    recursive=True,
                    timeout=watch_timeout,
                    index=index,
                )
            except etcd.EtcdWatchTimedOut:
                if time.time() >= deadline:
                    return None
                else:
                    continue
            except etcd.EtcdEventIndexCleared:
                # Server pruned history before our index; just re-read.
                continue

+ 0
- 96
mindnlp/core/distributed/elastic/rendezvous/registry.py View File

@@ -1,96 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
import sys
from .api import (
rendezvous_handler_registry as handler_registry,
RendezvousHandler,
RendezvousParameters,
)
from .dynamic_rendezvous import create_handler
if sys.version_info < (3, 10):
from importlib_metadata import entry_points
else:
from importlib.metadata import entry_points
log = logging.getLogger(__name__)
__all__ = ["get_rendezvous_handler"]
def _create_static_handler(params: RendezvousParameters) -> RendezvousHandler:
    """Lazily build a handler backed by the static TCP rendezvous implementation."""
    from .static_tcp_rendezvous import create_rdzv_handler as _static_create

    return _static_create(params)
def _create_etcd_handler(params: RendezvousParameters) -> RendezvousHandler:
    """Lazily build a handler backed by the legacy etcd rendezvous implementation."""
    from .etcd_rendezvous import create_rdzv_handler as _etcd_create

    return _etcd_create(params)
def _create_etcd_v2_handler(params: RendezvousParameters) -> RendezvousHandler:
    """Build a dynamic-rendezvous handler on top of the etcd-v2 backend."""
    from .etcd_rendezvous_backend import create_backend

    etcd_backend, etcd_store = create_backend(params)
    return create_handler(etcd_store, etcd_backend, params)
def _create_c10d_handler(params: RendezvousParameters) -> RendezvousHandler:
    """Build a dynamic-rendezvous handler on top of the c10d backend."""
    from .c10d_rendezvous_backend import create_backend

    c10d_backend, c10d_store = create_backend(params)
    return create_handler(c10d_store, c10d_backend, params)
def _register_default_handlers() -> None:
    """Register the built-in rendezvous backends with the global registry."""
    # Registration order matches the historical one (insertion-ordered dict).
    default_creators = {
        "etcd": _create_etcd_handler,
        "etcd-v2": _create_etcd_v2_handler,
        "c10d": _create_c10d_handler,
        "static": _create_static_handler,
    }
    for backend_name, creator in default_creators.items():
        handler_registry.register(backend_name, creator)
def _register_out_of_tree_handlers() -> None:
    """Discover and register third-party rendezvous handlers.

    Plugins advertise themselves through the ``torchrun.handlers``
    entry-point group. Each entry point is expected to resolve to a
    zero-argument factory whose return value is registered as the
    handler-creation callable for that backend name. A failing plugin is
    logged and skipped so it cannot prevent other plugins (or the default
    handlers) from registering.
    """
    discovered_handler_generators = entry_points(group="torchrun.handlers")

    for handler_generator in discovered_handler_generators:
        try:
            # Load this entry point directly instead of re-looking it up by
            # name: name-based indexing returns the *first* entry with that
            # name, which may be a different object when duplicates exist.
            get_handler = handler_generator.load()
            handler_registry.register(handler_generator.name, get_handler())
        except Exception:
            log.warning(
                "Exception while registering out of tree plugin %s: ",
                handler_generator.name,
                exc_info=True,
            )
def get_rendezvous_handler(params: RendezvousParameters) -> RendezvousHandler:
    """
    Obtain a reference to a :py:class:`RendezvousHandler`.

    The handler is created by the factory registered for the backend that
    ``params`` requests.

    Custom rendezvous handlers can be registered by
    ::

      from core.distributed.elastic.rendezvous import rendezvous_handler_registry
      from core.distributed.elastic.rendezvous.registry import get_rendezvous_handler

      def create_my_rdzv(params: RendezvousParameters):
          return MyCustomRdzv(params)

      rendezvous_handler_registry.register("my_rdzv_backend_name", create_my_rdzv)

      my_rdzv_handler = get_rendezvous_handler(my_rdzv_params)
    """
    return handler_registry.create_handler(params)

+ 0
- 128
mindnlp/core/distributed/elastic/rendezvous/static_tcp_rendezvous.py View File

@@ -1,128 +0,0 @@
#!/usr/bin/env python3
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import datetime
import logging
from typing import cast, Optional
from core.distributed import PrefixStore, Store, TCPStore
from core.distributed.elastic.rendezvous import (
RendezvousHandler,
RendezvousInfo,
RendezvousParameters,
RendezvousStoreInfo,
)
from core.distributed.elastic.rendezvous.utils import parse_rendezvous_endpoint
__all__ = ["StaticTCPRendezvous", "create_rdzv_handler"]
logger = logging.getLogger(__name__)
_default_timeout_seconds = 600
class StaticTCPRendezvous(RendezvousHandler):
    """
    Static rendezvous that is a wrapper around the TCPStore.

    Creates TCPStore based on the input parameters with the
    listener on the agent with group_rank=0
    """

    def __init__(
        self,
        master_addr: str,
        master_port: int,
        rank: int,
        world_size: int,
        run_id: str,
        timeout: int,
    ):
        # Endpoint of the agent that hosts the TCPStore listener (rank 0).
        self.master_addr = master_addr
        self.master_port = master_port
        self.rank = rank
        self.world_size = world_size
        self.run_id = run_id
        self.timeout = datetime.timedelta(seconds=timeout)
        # Created lazily on the first next_rendezvous() call, then reused.
        self._store: Optional[Store] = None

    def get_backend(self) -> str:
        """Return the name of this rendezvous backend."""
        return "static"

    @property
    def use_agent_store(self) -> bool:
        # The TCPStore created by this agent is shared with trainer processes.
        return True

    def next_rendezvous(self) -> RendezvousInfo:
        """Create the TCPStore (once) and return the rendezvous info.

        The member with rank 0 hosts the store; every member wraps it in a
        PrefixStore keyed by the run id so concurrent runs do not collide.
        """
        logger.info("Creating TCPStore as the c10d::Store implementation")
        is_master = self.rank == 0
        if not self._store:
            self._store = TCPStore(  # type: ignore[call-arg]
                self.master_addr,
                self.master_port,
                self.world_size,
                is_master,
                self.timeout,
                multi_tenant=True,
            )
        store = PrefixStore(self.run_id, self._store)
        # TCPStore server instance is used by trainer code
        bootstrap_store_info = RendezvousStoreInfo(self.master_addr, self.master_port)
        return RendezvousInfo(
            store,
            self.rank,
            self.world_size,
            bootstrap_store_info,
        )

    def is_closed(self):
        # Static membership never closes.
        return False

    def set_closed(self):
        # No-op: closing is meaningless for a static membership.
        pass

    def num_nodes_waiting(self):
        # There is no waiting room for a static rendezvous; always zero.
        return 0

    def get_run_id(self) -> str:
        """Return the run id this rendezvous was created for."""
        return self.run_id

    def shutdown(self) -> bool:
        # Nothing to tear down; report success.
        return True
def create_rdzv_handler(params: RendezvousParameters) -> RendezvousHandler:
    """Create a :py:class:`StaticTCPRendezvous` from ``params``.

    Requires ``rank`` in ``params.config`` and a ``host:port`` endpoint;
    ``world_size`` is taken from ``params.max_nodes`` and ``timeout``
    (seconds) from the config, defaulting to ``_default_timeout_seconds``.

    Raises:
        ValueError: If the rank, the endpoint, or the endpoint's port is
            missing.
    """
    if "rank" not in params.config:
        # Fix: the two literals were concatenated without a separator,
        # producing "...RendezvousParameters.Try add...".
        raise ValueError(
            "rank is absent in RendezvousParameters. "
            "Try adding --node-rank to the cmd request"
        )
    endpoint = params.endpoint.strip()
    if not endpoint:
        raise ValueError(
            "endpoint is absent in RendezvousParameters. "
            "Try adding --master-port and --master-addr to the cmd request"
        )
    master_addr, master_port = parse_rendezvous_endpoint(endpoint, -1)
    if master_port == -1:
        raise ValueError(
            f"Port is absent in endpoint: {endpoint}. Try launching with --master-port"
        )
    world_size = params.max_nodes
    rank = cast(int, params.config.get("rank"))
    run_id = params.run_id
    # int() is applied either way, so a config value given as a string and
    # the integer default are both handled.
    timeout = int(params.config.get("timeout", _default_timeout_seconds))
    return StaticTCPRendezvous(
        master_addr, master_port, rank, world_size, run_id, timeout
    )

+ 0
- 284
mindnlp/core/distributed/elastic/rendezvous/utils.py View File

@@ -1,284 +0,0 @@
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import ipaddress
import random
import re
import socket
import time
import weakref
from datetime import timedelta
from threading import Event, Thread
from typing import Any, Callable, Dict, Optional, Tuple, Union
__all__ = ["parse_rendezvous_endpoint"]
def _parse_rendezvous_config(config_str: str) -> Dict[str, str]:
"""Extract key-value pairs from a rendezvous configuration string.
Args:
config_str:
A string in format <key1>=<value1>,...,<keyN>=<valueN>.
"""
config: Dict[str, str] = {}
config_str = config_str.strip()
if not config_str:
return config
key_values = config_str.split(",")
for kv in key_values:
key, *values = kv.split("=", 1)
key = key.strip()
if not key:
raise ValueError(
"The rendezvous configuration string must be in format "
"<key1>=<value1>,...,<keyN>=<valueN>."
)
value: Optional[str]
if values:
value = values[0].strip()
else:
value = None
if not value:
raise ValueError(
f"The rendezvous configuration option '{key}' must have a value specified."
)
config[key] = value
return config
def _try_parse_port(port_str: str) -> Optional[int]:
"""Try to extract the port number from ``port_str``."""
if port_str and re.match(r"^[0-9]{1,5}$", port_str):
return int(port_str)
return None
def parse_rendezvous_endpoint(
    endpoint: Optional[str], default_port: int
) -> Tuple[str, int]:
    """Extract the hostname and the port number from a rendezvous endpoint.

    Args:
        endpoint:
            A string in format <hostname>[:<port>].
        default_port:
            The port number to use if the endpoint does not include one.

    Returns:
        A tuple of hostname and port number.

    Raises:
        ValueError: If the port is not a valid 16-bit integer or the
            hostname contains characters outside ``[\\w.:-]``.
    """
    endpoint = endpoint.strip() if endpoint is not None else ""
    if not endpoint:
        return ("localhost", default_port)

    # A fully bracketed endpoint such as "[::1]" is a bare IPv6 address with
    # no port suffix; anything else may end in ":<port>".
    if endpoint.startswith("[") and endpoint.endswith("]"):
        host = endpoint
        rest: List[str] = []
    else:
        host, *rest = endpoint.rsplit(":", 1)

    # Sanitize the IPv6 address.
    if len(host) > 1 and host.startswith("[") and host.endswith("]"):
        host = host[1:-1]

    if rest:
        # Inlined port parsing: one to five decimal digits, below 2**16.
        port_str = rest[0]
        port: Optional[int] = None
        if port_str and re.match(r"^[0-9]{1,5}$", port_str):
            port = int(port_str)
        if port is None or port >= 2**16:
            raise ValueError(
                f"The port number of the rendezvous endpoint '{endpoint}' must be an integer "
                "between 0 and 65536."
            )
    else:
        port = default_port

    if not re.match(r"^[\w\.:-]+$", host):
        raise ValueError(
            f"The hostname of the rendezvous endpoint '{endpoint}' must be a dot-separated list of "
            "labels, an IPv4 address, or an IPv6 address."
        )

    return host, port
def _matches_machine_hostname(host: str) -> bool:
"""Indicate whether ``host`` matches the hostname of this machine.
This function compares ``host`` to the hostname as well as to the IP
addresses of this machine. Note that it may return a false negative if this
machine has CNAME records beyond its FQDN or IP addresses assigned to
secondary NICs.
"""
if host == "localhost":
return True
try:
addr = ipaddress.ip_address(host)
except ValueError:
addr = None
if addr and addr.is_loopback:
return True
try:
host_addr_list = socket.getaddrinfo(
host, None, proto=socket.IPPROTO_TCP, flags=socket.AI_CANONNAME
)
except (ValueError, socket.gaierror) as _:
host_addr_list = []
host_ip_list = [host_addr_info[4][0] for host_addr_info in host_addr_list]
this_host = socket.gethostname()
if host == this_host:
return True
addr_list = socket.getaddrinfo(
this_host, None, proto=socket.IPPROTO_TCP, flags=socket.AI_CANONNAME
)
for addr_info in addr_list:
# If we have an FQDN in the addr_info, compare it to `host`.
if addr_info[3] and addr_info[3] == host:
return True
# Otherwise if `host` represents an IP address, compare it to our IP
# address.
if addr and addr_info[4][0] == str(addr):
return True
# If the IP address matches one of the provided host's IP addresses
if addr_info[4][0] in host_ip_list:
return True
return False
def _delay(seconds: Union[float, Tuple[float, float]]) -> None:
"""Suspend the current thread for ``seconds``.
Args:
seconds:
Either the delay, in seconds, or a tuple of a lower and an upper
bound within which a random delay will be picked.
"""
if isinstance(seconds, tuple):
seconds = random.uniform(*seconds)
# Ignore delay requests that are less than 10 milliseconds.
if seconds >= 0.01:
time.sleep(seconds)
class _PeriodicTimer:
    """Represent a timer that periodically runs a specified function.

    Args:
        interval:
            The interval, in seconds, between each run.
        function:
            The function to run.
    """

    # The state of the timer is hold in a separate context object to avoid a
    # reference cycle between the timer and the background thread.
    class _Context:
        interval: float
        function: Callable[..., None]
        args: Tuple[Any, ...]
        kwargs: Dict[str, Any]
        stop_event: Event

    _name: Optional[str]
    _thread: Optional[Thread]
    _finalizer: Optional[weakref.finalize]

    # The context that is shared between the timer and the background thread.
    _ctx: _Context

    def __init__(
        self,
        interval: timedelta,
        function: Callable[..., None],
        *args: Any,
        **kwargs: Any,
    ) -> None:
        self._name = None

        # Everything the background thread needs lives on the context object,
        # never on `self`, so the thread holds no reference back to the timer.
        self._ctx = self._Context()
        self._ctx.interval = interval.total_seconds()
        self._ctx.function = function  # type: ignore[assignment]
        self._ctx.args = args or ()
        self._ctx.kwargs = kwargs or {}
        self._ctx.stop_event = Event()

        self._thread = None
        self._finalizer = None

    @property
    def name(self) -> Optional[str]:
        """Get the name of the timer."""
        return self._name

    def set_name(self, name: str) -> None:
        """Set the name of the timer.

        The specified name will be assigned to the background thread and serves
        for debugging and troubleshooting purposes.

        Raises:
            RuntimeError: If the timer has already started.
        """
        if self._thread:
            raise RuntimeError("The timer has already started.")

        self._name = name

    def start(self) -> None:
        """Start the timer.

        Raises:
            RuntimeError: If the timer has already started.
        """
        if self._thread:
            raise RuntimeError("The timer has already started.")

        self._thread = Thread(
            target=self._run,
            name=self._name or "PeriodicTimer",
            args=(self._ctx,),
            daemon=True,
        )

        # We avoid using a regular finalizer (a.k.a. __del__) for stopping the
        # timer as joining a daemon thread during the interpreter shutdown can
        # cause deadlocks. The weakref.finalize is a superior alternative that
        # provides a consistent behavior regardless of the GC implementation.
        self._finalizer = weakref.finalize(
            self, self._stop_thread, self._thread, self._ctx.stop_event
        )

        # We do not attempt to stop our background thread during the interpreter
        # shutdown. At that point we do not even know whether it still exists.
        self._finalizer.atexit = False

        self._thread.start()

    def cancel(self) -> None:
        """Stop the timer at the next opportunity."""
        # Calling the finalizer runs _stop_thread exactly once; it is a no-op
        # if the timer was never started.
        if self._finalizer:
            self._finalizer()

    @staticmethod
    def _run(ctx) -> None:
        # wait() doubles as the interval sleep and the stop signal: it returns
        # False on timeout (run the function again) and True once the stop
        # event is set.
        while not ctx.stop_event.wait(ctx.interval):
            ctx.function(*ctx.args, **ctx.kwargs)

    @staticmethod
    def _stop_thread(thread, stop_event):
        # Static so the finalizer does not capture `self` and keep it alive.
        stop_event.set()

        thread.join()

+ 0
- 54
mindnlp/core/distributed/elastic/timer/__init__.py View File

@@ -1,54 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Expiration timers are set up on the same process as the agent and
used from your script to deal with stuck workers. When you go into
a code-block that has the potential to get stuck you can acquire
an expiration timer, which instructs the timer server to kill the
process if it does not release the timer by the self-imposed expiration
deadline.
Usage::
from mindnlp.core.distributed.elastic import timer
from mindnlp.core.distributed.elastic.agent import server as agent
def main():
start_method = "spawn"
message_queue = mp.get_context(start_method).Queue()
server = timer.LocalTimerServer(message_queue, max_interval=0.01)
server.start() # non-blocking
spec = WorkerSpec(
fn=trainer_func,
args=(message_queue,),
...<OTHER_PARAMS...>)
agent = agent.LocalElasticAgent(spec, start_method)
agent.run()
def trainer_func(message_queue):
timer.configure(timer.LocalTimerClient(message_queue))
with timer.expires(after=60): # 60 second expiry
# do some work
In the example above if ``trainer_func`` takes more than 60 seconds to
complete, then the worker process is killed and the agent retries the worker group.
"""
from .api import ( # noqa: F401
configure,
expires,
TimerClient,
TimerRequest,
TimerServer,
)
from .file_based_local_timer import ( # noqa: F401
FileTimerClient,
FileTimerRequest,
FileTimerServer,
)
from .local_timer import LocalTimerClient, LocalTimerServer # noqa: F401

+ 0
- 283
mindnlp/core/distributed/elastic/timer/api.py View File

@@ -1,283 +0,0 @@
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import abc
import logging
import threading
import time
from contextlib import contextmanager
from inspect import getframeinfo, stack
from typing import Any, Dict, List, Optional, Set
__all__ = [
"TimerRequest",
"TimerClient",
"RequestQueue",
"TimerServer",
"configure",
"expires",
]
logger = logging.getLogger(__name__)
class TimerRequest:
    """
    Data object representing a countdown timer acquisition and release
    that is used between the ``TimerClient`` and ``TimerServer``.
    A negative ``expiration_time`` should be interpreted as a "release"
    request.

    .. note:: the type of ``worker_id`` is implementation specific.
              It is whatever the TimerServer and TimerClient implementations
              have on to uniquely identify a worker.
    """

    __slots__ = ["worker_id", "scope_id", "expiration_time"]

    def __init__(self, worker_id: Any, scope_id: str, expiration_time: float):
        self.worker_id = worker_id
        self.scope_id = scope_id
        self.expiration_time = expiration_time

    def __eq__(self, other):
        # Equal only to another TimerRequest whose three fields all match.
        if not isinstance(other, TimerRequest):
            return False
        return (self.worker_id, self.scope_id, self.expiration_time) == (
            other.worker_id,
            other.scope_id,
            other.expiration_time,
        )
class TimerClient(abc.ABC):
    """
    Client library to acquire and release countdown timers by communicating
    with the TimerServer.

    Implementations define the transport (e.g. a queue or a named pipe);
    this base class only fixes the acquire/release contract.
    """

    @abc.abstractmethod
    def acquire(self, scope_id: str, expiration_time: float) -> None:
        """
        Acquires a timer for the worker that holds this client object
        given the scope_id and expiration_time. Typically registers
        the timer with the TimerServer.
        """

    @abc.abstractmethod
    def release(self, scope_id: str):
        """
        Releases the timer for the ``scope_id`` on the worker this
        client represents. After this method is
        called, the countdown timer on the scope is no longer in effect.
        """
class RequestQueue(abc.ABC):
    """
    Consumer queue holding timer acquisition/release requests
    (the ``TimerServer`` reads from it; clients are the producers).
    """

    @abc.abstractmethod
    def size(self) -> int:
        """
        Returns the size of the queue at the time this method is called.
        Note that by the time ``get`` is called the size of the queue
        may have increased. The size of the queue should not decrease
        until the ``get`` method is called. That is, the following assertion
        should hold:

        size = q.size()
        res = q.get(size, timeout=0)
        assert size == len(res)

        -- or --

        size = q.size()
        res = q.get(size * 2, timeout=1)
        assert size <= len(res) <= size * 2
        """

    @abc.abstractmethod
    def get(self, size: int, timeout: float) -> List[TimerRequest]:
        """
        Gets up to ``size`` number of timer requests in a blocking fashion
        (no more than ``timeout`` seconds).
        """
class TimerServer(abc.ABC):
    """
    Entity that monitors active timers and expires them
    in a timely fashion. This server is responsible for
    reaping workers that have expired timers.

    A background "watchdog" thread drains the request queue, registers the
    timers, and reaps workers whose timers have expired; subclasses supply
    the storage and the reaping mechanism.
    """

    def __init__(
        self, request_queue: RequestQueue, max_interval: float, daemon: bool = True
    ):
        """
        :param request_queue: Consumer ``RequestQueue``
        :param max_interval: max time (in seconds) to wait
                             for an item in the request_queue
        :param daemon: whether to run the watchdog thread as a daemon
        """
        super().__init__()
        self._request_queue = request_queue
        self._max_interval = max_interval
        self._daemon = daemon
        # Created by start(); None while the server is not running.
        self._watchdog_thread: Optional[threading.Thread] = None
        # Set by stop(); the watchdog loop polls it between iterations.
        self._stop_signaled = False

    @abc.abstractmethod
    def register_timers(self, timer_requests: List[TimerRequest]) -> None:
        """
        Processes the incoming timer requests and registers them with the server.
        The timer request can either be a acquire-timer or release-timer request.
        Timer requests with a negative expiration_time should be interpreted
        as a release-timer request.
        """

    @abc.abstractmethod
    def clear_timers(self, worker_ids: Set[Any]) -> None:
        """
        Clears all timers for the given ``worker_ids``.
        """

    @abc.abstractmethod
    def get_expired_timers(self, deadline: float) -> Dict[str, List[TimerRequest]]:
        """
        Returns all expired timers for each worker_id. An expired timer
        is a timer for which the expiration_time is less than or equal to
        the provided deadline.
        """

    @abc.abstractmethod
    def _reap_worker(self, worker_id: Any) -> bool:
        """
        Reaps the given worker. Returns True if the worker has been
        successfully reaped, False otherwise. If any uncaught exception
        is thrown from this method, the worker is considered reaped
        and all associated timers will be removed.
        """

    def _reap_worker_no_throw(self, worker_id: Any) -> bool:
        """
        Wraps ``_reap_worker(worker_id)``, if an uncaught exception is
        thrown, then it considers the worker as reaped.
        """
        try:
            return self._reap_worker(worker_id)
        except Exception:
            logger.exception(
                "Uncaught exception thrown from _reap_worker(), "
                "check that the implementation correctly catches exceptions",
            )
            return True

    def _watchdog_loop(self):
        # Keep running until stop() flips the flag; individual iteration
        # failures are logged and swallowed so the watchdog never dies early.
        while not self._stop_signaled:
            try:
                self._run_watchdog()
            except Exception:
                logger.exception("Error running watchdog")

    def _run_watchdog(self):
        """One watchdog iteration: drain the queue, register, reap, clear."""
        # Ask for at least one item so the queue read blocks up to
        # max_interval instead of spinning when the queue is empty.
        batch_size = max(1, self._request_queue.size())
        timer_requests = self._request_queue.get(batch_size, self._max_interval)
        self.register_timers(timer_requests)
        now = time.time()
        reaped_worker_ids = set()
        for worker_id, expired_timers in self.get_expired_timers(now).items():
            logger.info(
                "Reaping worker_id=[%s]." " Expired timers: %s",
                worker_id,
                self._get_scopes(expired_timers),
            )
            if self._reap_worker_no_throw(worker_id):
                logger.info("Successfully reaped worker=[%s]", worker_id)
                reaped_worker_ids.add(worker_id)
            else:
                # Timers are kept, so the next iteration retries the reap.
                logger.error(
                    "Error reaping worker=[%s]. Will retry on next watchdog.", worker_id
                )
        self.clear_timers(reaped_worker_ids)

    def _get_scopes(self, timer_requests):
        # Scope ids only, for concise logging.
        return [r.scope_id for r in timer_requests]

    def start(self) -> None:
        """Start the watchdog thread (non-blocking)."""
        logger.info(
            "Starting %s..." " max_interval=%s," " daemon=%s",
            type(self).__name__,
            self._max_interval,
            self._daemon,
        )
        self._watchdog_thread = threading.Thread(
            target=self._watchdog_loop, daemon=self._daemon
        )
        logger.info("Starting watchdog thread...")
        self._watchdog_thread.start()

    def stop(self) -> None:
        """Signal the watchdog to stop and wait up to max_interval for it."""
        logger.info("Stopping %s", type(self).__name__)
        self._stop_signaled = True
        if self._watchdog_thread:
            logger.info("Stopping watchdog thread...")
            # Bounded join: the loop wakes up at most max_interval later.
            self._watchdog_thread.join(self._max_interval)
            self._watchdog_thread = None
        else:
            logger.info("No watchdog thread running, doing nothing")
# Process-wide timer client installed via ``configure``; ``expires`` falls
# back to this when no explicit client is passed.
_timer_client: Optional[TimerClient] = None


def configure(timer_client: TimerClient):
    """
    Configures a timer client. Must be called before using ``expires``.
    """
    global _timer_client
    _timer_client = timer_client
    logger.info("Timer client configured to: %s", type(_timer_client).__name__)
@contextmanager
def expires(
    after: float, scope: Optional[str] = None, client: Optional[TimerClient] = None
):
    """
    Acquires a countdown timer that expires in ``after`` seconds from now,
    unless the code-block that it wraps is finished within the timeframe.
    When the timer expires, this worker is eligible to be reaped. The
    exact meaning of "reaped" depends on the client implementation. In
    most cases, reaping means to terminate the worker process.
    Note that the worker is NOT guaranteed to be reaped at exactly
    ``time.now() + after``, but rather the worker is "eligible" for being
    reaped and the ``TimerServer`` that the client talks to will ultimately
    make the decision when and how to reap the workers with expired timers.

    Usage::

        core.distributed.elastic.timer.configure(LocalTimerClient())
        with expires(after=10):
            core.distributed.all_reduce(...)
    """
    # Fall back to the process-wide client configured via ``configure``.
    effective_client = client
    if effective_client is None:
        if _timer_client is None:
            raise RuntimeError("Configure timer client before using countdown timers.")
        effective_client = _timer_client

    if scope is None:
        # Default the scope to the caller's "<file>#<lineno>" location.
        caller = getframeinfo(stack()[1][0])
        scope = f"{caller.filename}#{caller.lineno}"

    expiration = time.time() + after
    effective_client.acquire(scope, expiration)
    try:
        yield
    finally:
        effective_client.release(scope)

+ 0
- 25
mindnlp/core/distributed/elastic/timer/debug_info_logging.py View File

@@ -1,25 +0,0 @@
#!/usr/bin/env python3
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, List
from core.distributed.elastic.utils.logging import get_logger
logger = get_logger(__name__)
__all__ = ["log_debug_info_for_expired_timers"]
def log_debug_info_for_expired_timers(
    run_id: str,
    expired_timers: Dict[int, List[str]],
):
    """Log the expired timers for ``run_id``; no-op when there are none."""
    if not expired_timers:
        return
    logger.info("Timers expired for run:[%s] [%s].", run_id, expired_timers)

+ 0
- 396
mindnlp/core/distributed/elastic/timer/file_based_local_timer.py View File

@@ -1,396 +0,0 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import io
import json
import os
import select
import signal
import sys
import threading
import time
from typing import Callable, Dict, List, Optional, Set, Tuple
from core.distributed.elastic.timer.api import TimerClient, TimerRequest
from core.distributed.elastic.timer.debug_info_logging import (
log_debug_info_for_expired_timers,
)
from core.distributed.elastic.utils.logging import get_logger
__all__ = ["FileTimerClient", "FileTimerRequest", "FileTimerServer"]
logger = get_logger(__name__)
class FileTimerRequest(TimerRequest):
    """
    Data object representing a countdown timer acquisition and release
    that is used between the ``FileTimerClient`` and ``FileTimerServer``.
    A negative ``expiration_time`` should be interpreted as a "release"
    request.
    ``signal`` is the signal to reap the worker process from the server
    process.
    """

    __slots__ = ["version", "worker_pid", "scope_id", "expiration_time", "signal"]

    def __init__(
        self, worker_pid: int, scope_id: str, expiration_time: float, signal: int = 0
    ) -> None:
        self.version = 1
        self.worker_pid = worker_pid
        self.scope_id = scope_id
        self.expiration_time = expiration_time
        self.signal = signal

    def __eq__(self, other) -> bool:
        # Equal only to another FileTimerRequest with all five fields matching.
        if not isinstance(other, FileTimerRequest):
            return False
        return (
            self.version,
            self.worker_pid,
            self.scope_id,
            self.expiration_time,
            self.signal,
        ) == (
            other.version,
            other.worker_pid,
            other.scope_id,
            other.expiration_time,
            other.signal,
        )

    def to_json(self) -> str:
        """Serialize this request as a JSON object for the named pipe."""
        payload = {
            "version": self.version,
            "pid": self.worker_pid,
            "scope_id": self.scope_id,
            "expiration_time": self.expiration_time,
            "signal": self.signal,
        }
        return json.dumps(payload)
class FileTimerClient(TimerClient):
    """
    Client side of ``FileTimerServer``. This client is meant to be used
    on the same host that the ``FileTimerServer`` is running on and uses
    pid to uniquely identify a worker.
    This client uses a named_pipe to send timer requests to the
    ``FileTimerServer``. This client is a producer while the
    ``FileTimerServer`` is a consumer. Multiple clients can work with
    the same ``FileTimerServer``.

    Args:

        file_path: str, the path of a FIFO special file. ``FileTimerServer``
                        must have created it by calling os.mkfifo().

        signal: signal, the signal to use to kill the process. Using a
                        negative or zero signal will not kill the process.
    """

    def __init__(
        self,
        file_path: str,
        signal=(signal.SIGKILL if sys.platform != "win32" else signal.CTRL_C_EVENT),  # type: ignore[attr-defined]
    ) -> None:
        super().__init__()
        self._file_path = file_path
        # Signal the server should deliver to this process on expiry.
        self.signal = signal

    def _open_non_blocking(self) -> Optional[io.TextIOWrapper]:
        """Open the FIFO for writing without blocking; None if no reader exists."""
        try:
            fd = os.open(self._file_path, os.O_WRONLY | os.O_NONBLOCK)
            return os.fdopen(fd, "wt")
        except Exception:
            # Opening a FIFO with no reader raises; treat any failure as
            # "server unavailable" and let the caller decide.
            return None

    def _send_request(self, request: FileTimerRequest) -> None:
        # The server may have crashed or may haven't started yet.
        # In such case, calling open() in blocking model blocks the client.
        # To avoid such issue, open it in non-blocking mode, and an OSError will
        # be raised if the server is not there.
        file = self._open_non_blocking()
        if file is None:
            raise BrokenPipeError(
                "Could not send the FileTimerRequest because FileTimerServer is not available."
            )
        with file:
            json_request = request.to_json()
            # Write request with no greater than select.PIPE_BUF is guarantee to be atomic.
            if len(json_request) > select.PIPE_BUF:
                raise RuntimeError(
                    f"FileTimerRequest larger than {select.PIPE_BUF} bytes "
                    f"is not supported: {json_request}"
                )
            file.write(json_request + "\n")

    def acquire(self, scope_id: str, expiration_time: float) -> None:
        """Register a countdown timer for this process with the server."""
        self._send_request(
            request=FileTimerRequest(
                worker_pid=os.getpid(),
                scope_id=scope_id,
                expiration_time=expiration_time,
                signal=self.signal,
            ),
        )

    def release(self, scope_id: str) -> None:
        """Release the timer; expiration_time=-1 marks a release request."""
        self._send_request(
            request=FileTimerRequest(
                worker_pid=os.getpid(), scope_id=scope_id, expiration_time=-1, signal=0
            ),
        )
class FileTimerServer:
"""
Server that works with ``FileTimerClient``. Clients are expected to be
running on the same host as the process that is running this server.
Each host in the job is expected to start its own timer server locally
and each server instance manages timers for local workers (running on
processes on the same host).
Args:
file_path: str, the path of a FIFO special file to be created.
max_interval: float, max interval in seconds for each watchdog loop.
daemon: bool, running the watchdog thread in daemon mode or not.
A daemon thread will not block a process to stop.
log_event: Callable[[Dict[str, str]], None], an optional callback for
logging the events in JSON format.
"""
def __init__(
self,
file_path: str,
run_id: str,
max_interval: float = 10,
daemon: bool = True,
log_event: Optional[Callable[[str, Optional[FileTimerRequest]], None]] = None,
) -> None:
self._file_path = file_path
self._run_id = run_id
self._max_interval = max_interval
self._daemon = daemon
self._timers: Dict[Tuple[int, str], FileTimerRequest] = {}
self._stop_signaled = False
self._watchdog_thread: Optional[threading.Thread] = None
self._is_client_started = False
if os.path.exists(self._file_path):
os.remove(self._file_path)
os.mkfifo(self._file_path)
# For test only. Count the number of requests received.
self._request_count = 0
# For test only. Process all requests and stop the server.
self._run_once = False
self._log_event = (
log_event if log_event is not None else lambda name, request: None
)
self._last_progress_time = int(time.time())
def start(self) -> None:
logger.info(
"Starting %s... max_interval=%s, daemon=%s, file_path=%s",
type(self).__name__,
self._max_interval,
self._daemon,
self._file_path,
)
self._watchdog_thread = threading.Thread(
target=self._watchdog_loop, daemon=self._daemon
)
logger.info("Starting watchdog thread...")
self._watchdog_thread.start()
self._log_event("watchdog started", None)
def stop(self) -> None:
logger.info("Stopping %s", type(self).__name__)
self._stop_signaled = True
if self._watchdog_thread:
logger.info("Stopping watchdog thread...")
self._watchdog_thread.join(self._max_interval)
self._watchdog_thread = None
else:
logger.info("No watchdog thread running, doing nothing")
if os.path.exists(self._file_path):
os.remove(self._file_path)
self._log_event("watchdog stopped", None)
def run_once(self) -> None:
self._run_once = True
if self._watchdog_thread:
logger.info("Stopping watchdog thread...")
self._watchdog_thread.join()
self._watchdog_thread = None
else:
logger.info("No watchdog thread running, doing nothing")
if os.path.exists(self._file_path):
os.remove(self._file_path)
@staticmethod
def is_process_running(pid: int):
"""
function to check process is running or not
"""
try:
# Check if the process exists and we can send signals to it
os.kill(pid, 0)
return True
except OSError:
return False
def _watchdog_loop(self) -> None:
# Open the pipe in blocking mode blocks the server thread.
# This is fine for the following reasons:
# 1. No client case usually does not happen.
# 2. We are running the watchdog loop in a separate daemon
# thread, which will not block the process to stop.
with open(self._file_path) as fd:
self._is_client_started = True
while not self._stop_signaled:
try:
run_once = self._run_once
self._run_watchdog(fd)
if run_once:
break
self._last_progress_time = int(time.time())
except Exception:
logger.exception("Error running watchdog")
def _run_watchdog(self, fd: io.TextIOWrapper) -> None:
timer_requests = self._get_requests(fd, self._max_interval)
self.register_timers(timer_requests)
now = time.time()
reaped_worker_pids = set()
all_expired_timers = self.get_expired_timers(now)
log_debug_info_for_expired_timers(
self._run_id,
{
pid: [expired_timer.to_json() for expired_timer in expired_timers]
for pid, expired_timers in all_expired_timers.items()
},
)
for worker_pid, expired_timers in all_expired_timers.items():
logger.info(
"Reaping worker_pid=[%s]. Expired timers: %s",
worker_pid,
self._get_scopes(expired_timers),
)
reaped_worker_pids.add(worker_pid)
# In case we have multiple expired timers, we find the first timer
# with a valid signal (>0) in the expiration time order.
expired_timers.sort(key=lambda timer: timer.expiration_time)
signal = 0
expired_timer = None
for timer in expired_timers:
self._log_event("timer expired", timer)
if timer.signal > 0:
signal = timer.signal
expired_timer = timer
break
if signal <= 0:
logger.info(
"No signal specified with worker=[%s]. Do not reap it.", worker_pid
)
continue
if self._reap_worker(worker_pid, signal):
logger.info(
"Successfully reaped worker=[%s] with signal=%s", worker_pid, signal
)
self._log_event("kill worker process", expired_timer)
else:
logger.error(
"Error reaping worker=[%s]. Will retry on next watchdog.",
worker_pid,
)
self.clear_timers(reaped_worker_pids)
def _get_scopes(self, timer_requests: List[FileTimerRequest]) -> List[str]:
return [r.scope_id for r in timer_requests]
def _get_requests(
self, fd: io.TextIOWrapper, max_interval: float
) -> List[FileTimerRequest]:
start = time.time()
requests = []
while not self._stop_signaled or self._run_once:
# For named pipe, readline() is blocking when at least one writer opens.
# It returns only when flush() is called at the writer side.
# Note that flush() is automatically called inside close().
# After the last writer closes, readline() is not blocking.
# It will return an empty string when it's at end-of-file.
# Since the client side always opens the pipe, writes a message and closes
# the pipe immediately, the readline() call below is not blocking for long.
json_request = fd.readline()
if len(json_request) == 0:
if self._run_once:
break
time.sleep(min(max_interval, 1))
else:
request = json.loads(json_request)
pid = request["pid"]
scope_id = request["scope_id"]
expiration_time = request["expiration_time"]
signal = request["signal"]
requests.append(
FileTimerRequest(
worker_pid=pid,
scope_id=scope_id,
expiration_time=expiration_time,
signal=signal,
)
)
now = time.time()
if now - start > max_interval:
break
return requests
def register_timers(self, timer_requests: List[FileTimerRequest]) -> None:
    """Apply a batch of acquire/release requests to the in-memory timer table.

    A request with a negative expiration time is a release: it removes any
    timer registered under the same (pid, scope_id) key.
    """
    for request in timer_requests:
        self._request_count += 1
        key = (request.worker_pid, request.scope_id)
        # negative expiration is a proxy for a release call
        if request.expiration_time < 0:
            self._timers.pop(key, None)
        else:
            self._timers[key] = request
def clear_timers(self, worker_pids: Set[int]) -> None:
    """Drop timers owned by the given pids, plus any whose process has died."""
    stale_keys = [
        (pid, scope_id)
        for pid, scope_id in self._timers
        if pid in worker_pids or not FileTimerServer.is_process_running(pid)
    ]
    for key in stale_keys:
        del self._timers[key]
def get_expired_timers(self, deadline: float) -> Dict[int, List[FileTimerRequest]]:
    """Return timers whose expiration is at or before ``deadline``, grouped by worker pid."""
    expired: Dict[int, List[FileTimerRequest]] = {}
    for request in self._timers.values():
        if request.expiration_time <= deadline:
            expired.setdefault(request.worker_pid, []).append(request)
    return expired
def _reap_worker(self, worker_pid: int, signal: int) -> bool:
    """Deliver ``signal`` to ``worker_pid``.

    Returns True when the signal was sent or the process is already gone;
    False when delivery failed and should be retried by the watchdog.
    """
    try:
        os.kill(worker_pid, signal)
    except ProcessLookupError:
        # Already exited; nothing left to reap, count as success.
        logger.info("Process with pid=%s does not exist. Skipping", worker_pid)
    except Exception:
        logger.exception("Error terminating pid=%s", worker_pid)
        return False
    return True
def get_last_progress_time(self) -> int:
    """Return the last watchdog progress timestamp, or "now" before the client starts."""
    if self._is_client_started:
        return self._last_progress_time
    return int(time.time())

+ 0
- 128
mindnlp/core/distributed/elastic/timer/local_timer.py View File

@@ -1,128 +0,0 @@
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
import multiprocessing as mp
import os
import signal
import time
from queue import Empty
from typing import Any, Dict, List, Set, Tuple
from .api import RequestQueue, TimerClient, TimerRequest, TimerServer
__all__ = ["LocalTimerClient", "MultiprocessingRequestQueue", "LocalTimerServer"]
logger = logging.getLogger(__name__)
class LocalTimerClient(TimerClient):
    """
    Client side of ``LocalTimerServer``. Intended to run on the same host as
    the server; the worker's pid uniquely identifies it. This is particularly
    useful when one subprocess (trainer) is spawned per GPU on a host with
    multiple GPU devices.
    """

    def __init__(self, mp_queue):
        super().__init__()
        self._mp_queue = mp_queue

    def acquire(self, scope_id, expiration_time):
        # An acquire is a timer request with a real (non-negative) expiration.
        self._mp_queue.put(TimerRequest(os.getpid(), scope_id, expiration_time))

    def release(self, scope_id):
        # A negative expiration time marks the request as a release.
        self._mp_queue.put(TimerRequest(os.getpid(), scope_id, -1))
class MultiprocessingRequestQueue(RequestQueue):
    """A ``RequestQueue`` backed by a python ``multiprocessing.Queue``."""

    def __init__(self, mp_queue: mp.Queue):
        super().__init__()
        self._mp_queue = mp_queue

    def size(self) -> int:
        return self._mp_queue.qsize()

    def get(self, size, timeout: float) -> List[TimerRequest]:
        """Dequeue up to ``size`` requests, waiting at most ``timeout`` seconds total."""
        collected: List[TimerRequest] = []
        remaining = timeout
        for _ in range(size):
            began = time.time()
            try:
                item = self._mp_queue.get(block=True, timeout=remaining)
            except Empty:
                break
            collected.append(item)
            # Charge the time spent blocking against the overall budget.
            remaining -= time.time() - began
            if remaining <= 0:
                break
        return collected
class LocalTimerServer(TimerServer):
    """
    Server counterpart of ``LocalTimerClient``. Clients are expected to be
    subprocesses of the process running this server. Each host in the job runs
    its own server instance, managing timers for the workers on that host.
    """

    def __init__(
        self, mp_queue: mp.Queue, max_interval: float = 60, daemon: bool = True
    ):
        super().__init__(MultiprocessingRequestQueue(mp_queue), max_interval, daemon)
        self._timers: Dict[Tuple[Any, str], TimerRequest] = {}

    def register_timers(self, timer_requests: List[TimerRequest]) -> None:
        """Apply a batch of acquire/release requests to the timer table."""
        for request in timer_requests:
            key = (request.worker_id, request.scope_id)
            # negative expiration is a proxy for a release call
            if request.expiration_time < 0:
                self._timers.pop(key, None)
            else:
                self._timers[key] = request

    def clear_timers(self, worker_ids: Set[int]) -> None:
        """Remove every timer registered by the given worker ids."""
        for key in [k for k in self._timers if k[0] in worker_ids]:
            self._timers.pop(key)

    def get_expired_timers(self, deadline: float) -> Dict[Any, List[TimerRequest]]:
        """Group timers expiring at or before ``deadline`` by worker id."""
        expired: Dict[Any, List[TimerRequest]] = {}
        for request in self._timers.values():
            if request.expiration_time <= deadline:
                expired.setdefault(request.worker_id, []).append(request)
        return expired

    def _reap_worker(self, worker_id: int) -> bool:
        """SIGKILL the worker; True when it is (or already was) gone."""
        try:
            os.kill(worker_id, signal.SIGKILL)
        except ProcessLookupError:
            logger.info("Process with pid=%s does not exist. Skipping", worker_id)
        except Exception:
            logger.exception("Error terminating pid=%s", worker_id)
            return False
        return True

+ 0
- 9
mindnlp/core/distributed/elastic/utils/__init__.py View File

@@ -1,9 +0,0 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from .api import get_env_variable_or_raise, get_socket_with_port, macros # noqa: F401

+ 0
- 62
mindnlp/core/distributed/elastic/utils/api.py View File

@@ -1,62 +0,0 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
import socket
from string import Template
from typing import Any, List
def get_env_variable_or_raise(env_name: str) -> str:
    r"""
    Look up ``env_name`` in the process environment.

    Args:
        env_name (str): Name of the env variable

    Raises:
        ValueError: if the variable is not set.
    """
    try:
        return os.environ[env_name]
    except KeyError:
        raise ValueError(f"Environment variable {env_name} expected, but not set") from None
def get_socket_with_port() -> socket.socket:
    """Bind and return a listening socket on an ephemeral localhost port.

    Raises:
        RuntimeError: when no resolved address family yields a bindable socket.
    """
    candidates = socket.getaddrinfo(
        host="localhost", port=None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM
    )
    for family, sock_type, proto, _, _ in candidates:
        sock = socket.socket(family, sock_type, proto)
        try:
            sock.bind(("localhost", 0))
            sock.listen(0)
            return sock
        except OSError:
            sock.close()
    raise RuntimeError("Failed to create a socket")
class macros:
    """
    Defines simple macros for caffe2.distributed.launch cmd args substitution
    """

    local_rank = "${local_rank}"

    @staticmethod
    def substitute(args: List[Any], local_rank: str) -> List[str]:
        """Replace ``${local_rank}`` inside every string argument; non-strings pass through."""
        return [
            Template(arg).safe_substitute(local_rank=local_rank)
            if isinstance(arg, str)
            else arg
            for arg in args
        ]

+ 0
- 184
mindnlp/core/distributed/elastic/utils/distributed.py View File

@@ -1,184 +0,0 @@
#!/usr/bin/env python3
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import datetime
import os
import socket
from contextlib import closing
from typing import Optional
from mindnlp import core.distributed as dist
from core.distributed.elastic.utils.logging import get_logger
from core.distributed.elastic.utils.store import barrier
__all__ = ["create_c10d_store", "get_free_port", "get_socket_with_port"]
logger = get_logger(__name__)
_ADDRESS_IN_USE = "Address already in use"
_SOCKET_TIMEOUT = "Socket Timeout"
_TCP_STORE_INIT = "_tcp_store/num_members"
def create_c10d_store(
    is_server: bool,
    server_addr: str,
    server_port: int = -1,
    world_size: int = 1,
    timeout: float = (60 * 10),  # 10 min
    wait_for_workers: bool = True,
    retries=3,
    use_libuv: Optional[bool] = None,
):
    """Create a c10d ``TCPStore`` (server or client side).

    Args:
        is_server: whether this process hosts the TCPStore daemon.
        server_addr: address of the store server.
        server_port: fixed server port; ``-1`` picks a free port
            (only valid when ``world_size == 1``).
        world_size: total number of members expected to join the store.
        timeout: store operation timeout, in seconds.
        wait_for_workers: block until all ``world_size`` members check in.
        retries: bind attempts when the port is dynamic (``server_port == -1``).
        use_libuv: deprecated and ignored; control libuv via the ``USE_LIBUV``
            env variable instead.

    Raises:
        ValueError: if ``world_size > 1`` but no explicit ``server_port`` was given.
        RuntimeError: if the port is already in use and retries are exhausted.
    """
    if use_libuv is not None:
        logger.warning(
            "argument use_libuv is deprecated and ignored. Set USE_LIBUV environment "
            'variable to "0" to disable libuv, or "1" to enable it. If the env var '
            "is not set, libuv will be used by default."
        )

    # check os.environ for use_libuv
    use_libuv = os.environ.get("USE_LIBUV", "1") == "1"  # libuv is the default option

    if server_port == -1 and world_size > 1:
        raise ValueError(
            f"server_port must be specified when world_size > 1, got server_port={server_port}, world_size={world_size}"
        )

    if server_port != -1:
        # Fix: message previously read "sever_port".
        logger.info("server_port: %s, specified, ignoring retries", server_port)

    # only retry when server_port is NOT static
    # Fix: the initialization was inverted (`retries if server_port == -1 else 1`),
    # so dynamic ports started at `retries` and never retried, while static
    # ports — whose retries are explicitly "ignored" per the log above — did.
    attempt = 1 if server_port == -1 else retries
    while True:
        if server_port != -1:
            port = server_port
        else:
            port = get_free_port()

        logger.info(
            "Creating c10d store on %s:%s\n"
            " world_size : %s\n"
            " is_server : %s\n"
            " timeout(sec): %s\n"
            " use_libuv : %s\n",
            server_addr,
            port,
            world_size,
            is_server,
            timeout,
            use_libuv,
        )

        try:
            store = dist.TCPStore(
                host_name=server_addr,
                port=port,
                world_size=world_size,
                is_master=is_server,
                timeout=datetime.timedelta(seconds=timeout),
                wait_for_workers=wait_for_workers,
                use_libuv=use_libuv,
            )
            # skips full rank check when we don't have to wait for all workers
            if wait_for_workers:
                _check_full_rank(store, world_size, timeout=timeout)
            logger.info("Successfully created c10d store")
            return store
        except RuntimeError as e:
            # this is brittle, but the underlying exception type is not properly pybinded
            # so we parse the error msg for now, interestingly this is how torch itself
            # detects timeouts and port conflicts in their own unittests
            # see - caffe2/torch/testing/_internal/common_utils.py
            # TODO properly map the exceptions in pybind (c10d/init.cpp)
            if str(e) == _ADDRESS_IN_USE:  # this will only happen on the server
                if attempt < retries:
                    logger.warning(
                        "port: %s already in use, attempt: [%s/%s]",
                        port,
                        attempt,
                        retries,
                    )
                    attempt += 1
                else:
                    raise RuntimeError(
                        f"on {server_addr}, port: {port} already in use"
                    ) from e
            else:
                raise
def _check_full_rank(store, world_size, timeout):
    """Barrier on the store; translate a socket timeout into ``TimeoutError``."""
    try:
        barrier(store, world_size, key_prefix=_TCP_STORE_INIT, barrier_timeout=timeout)
    except RuntimeError as e:
        if str(e) != _SOCKET_TIMEOUT:
            raise
        raise TimeoutError(
            f"timed out waiting for all {world_size} members to join"
        ) from e
def get_free_port():
    """
    Returns an unused port on localhost.

    The port is found by binding a temporary socket and closing it right away.

    Returns:
        int: an unused port on localhost

    Example:
        >>> # xdoctest: +SKIP("Nondeterministic")
        >>> get_free_port()
        63976

    ..note:
        The port returned by :func:`get_free_port` is not reserved and may be
        taken by another process after this function returns.
    """
    sock = get_socket_with_port()
    try:
        return sock.getsockname()[1]
    finally:
        sock.close()
def get_socket_with_port() -> socket.socket:
    """
    Returns a free port on localhost that is "reserved" by binding a temporary
    socket on it. Close the socket before passing the port to the entity
    that requires it. Usage example

    ::

    sock = _get_socket_with_port()
    with closing(sock):
        port = sock.getsockname()[1]
    sock.close()
    # there is still a race-condition that some other process
    # may grab this port before func() runs
    func(port)
    """
    candidates = socket.getaddrinfo(
        host="localhost", port=None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM
    )
    for family, sock_type, proto, _, _ in candidates:
        sock = socket.socket(family, sock_type, proto)
        try:
            sock.bind(("localhost", 0))
            sock.listen(0)
            return sock
        except OSError as e:
            sock.close()
            logger.warning("Socket creation attempt failed.", exc_info=e)
    raise RuntimeError("Failed to create a socket")

+ 0
- 14
mindnlp/core/distributed/elastic/utils/log_level.py View File

@@ -1,14 +0,0 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
def get_log_level() -> str:
    """Return the default log level used by the elastic utilities."""
    return "WARNING"

+ 0
- 70
mindnlp/core/distributed/elastic/utils/logging.py View File

@@ -1,70 +0,0 @@
#!/usr/bin/env python3
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import inspect
import logging
import os
import warnings
from typing import Optional
from core.distributed.elastic.utils.log_level import get_log_level
def get_logger(name: Optional[str] = None):
    """
    Set up a simple logger that writes into stderr.

    The log level is fetched from the ``LOGLEVEL`` env variable, defaulting to
    WARNING. When ``name`` is empty or omitted, the caller's module name is
    derived from the call stack.

    Args:
        name: Name of the logger. If no name provided, the name will
            be derived from the call stack.
    """
    if not name:
        # depth=2: skip this function's own frame when walking the stack.
        name = _derive_module_name(depth=2)
    return _setup_logger(name)
def _setup_logger(name: Optional[str] = None):
    """Fetch the named logger and apply the LOGLEVEL env override (default WARNING)."""
    log = logging.getLogger(name)
    log.setLevel(os.environ.get("LOGLEVEL", get_log_level()))
    return log
def _derive_module_name(depth: int = 1) -> Optional[str]:
"""
Derives the name of the caller module from the stack frames.
Args:
depth: The position of the frame in the stack.
"""
try:
stack = inspect.stack()
assert depth < len(stack)
# FrameInfo is just a named tuple: (frame, filename, lineno, function, code_context, index)
frame_info = stack[depth]
module = inspect.getmodule(frame_info[0])
if module:
module_name = module.__name__
else:
# inspect.getmodule(frame_info[0]) does NOT work (returns None) in
# binaries built with @mode/opt
# return the filename (minus the .py extension) as modulename
filename = frame_info[1]
module_name = os.path.splitext(os.path.basename(filename))[0]
return module_name
except Exception as e:
warnings.warn(
f"Error deriving logger module name, using <None>. Exception: {e}",
RuntimeWarning,
)
return None

+ 0
- 225
mindnlp/core/distributed/elastic/utils/store.py View File

@@ -1,225 +0,0 @@
#!/usr/bin/env python3
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from contextlib import contextmanager
from datetime import timedelta
from typing import Callable, Iterable, List, Optional
from mindnlp import core
DistStoreError = core._C._DistStoreError
_NUM_MEMBERS = "/num_members"
_LAST_MEMBER_CHECKIN = "/last_member"
_TRACE = "/TRACE"
_TRACING_GATE = "/TRACING_GATE"
_MAX_TRACE_MISSING_RANKS = 16
__all__ = ["store_timeout", "get_all", "synchronize", "barrier"]
@contextmanager
def store_timeout(store, timeout: float):
    """
    This sets the timeout and then restores the old timeout when the context
    manager exits.

    The previous timeout is restored even when the managed block raises, so a
    failure inside the block cannot leak the temporary timeout to later users
    of the store.

    Args:
        store: the store to set the timeout on
        timeout: the timeout to set
    """
    old_timeout = store.timeout
    store.set_timeout(timedelta(seconds=timeout))
    try:
        yield
    finally:
        # Restore unconditionally: without try/finally an exception in the
        # body left the store stuck with the temporary timeout.
        store.set_timeout(old_timeout)
def get_all(store, rank: int, prefix: str, world_size: int):
    r"""
    Given a store and a prefix, the method goes through the array of keys
    of the following format: ``{prefix}{idx}``, where idx is in a range
    from 0 to size, and tries to retrieve the data.

    The Rank0 process waits at the end to make sure all other processes
    finished the procedure before exiting.

    Usage

    ::

    values = get_all(store, rank, 'torchelastic/data', 3)
    value1 = values[0] # retrieves the data for key torchelastic/data0
    value2 = values[1] # retrieves the data for key torchelastic/data1
    value3 = values[2] # retrieves the data for key torchelastic/data2
    """
    # Fetch every participant's payload in a single multi_get round-trip.
    data_arr = store.multi_get([f"{prefix}{idx}" for idx in range(world_size)])
    # Non-blocking barrier check-in; returns the key the last member will set.
    barrier_key = _barrier_nonblocking(
        store=store,
        world_size=world_size,
        key_prefix=f"{prefix}/finished",
    )
    if rank == 0:
        # Rank0 runs the TCPStore daemon, as a result it needs to exit last.
        # Otherwise, the barrier may timeout if rank0 process finished the work
        # before other processes finished `get_all` method
        store.wait([barrier_key])

    return data_arr
def synchronize(
    store,
    data: bytes,
    rank: int,
    world_size: int,
    key_prefix: str,
    timeout: float = 300,
) -> List[bytes]:
    """
    Synchronizes ``world_size`` agents between each other using the underlying c10d store.

    The ``data`` will be available on each of the agents.

    Note: The data on the path is not deleted, as a result there can be stale data if
    you use the same key_prefix twice.

    Time complexity: O(N) per worker, O(N^2) globally.
    """
    with store_timeout(store, timeout):
        # Publish this rank's payload, then gather everyone's payloads.
        store.set(f"{key_prefix}{rank}", data)
        return get_all(store, rank, key_prefix, world_size)
def _try_detecting_missing_ranks(
    store,
    world_size: int,
    key_prefix: str,
    rank: int,
    rank_decoder: Callable[[int], str],
    trace_timeout: float,
) -> Optional[Iterable[str]]:
    """Best-effort identification of ranks that never checked in after a barrier timeout.

    Every rank first writes its own TRACE key. Rank 0 then probes the other
    ranks' TRACE keys and collects the ones that are missing; non-zero ranks
    wait on a gate key set by rank 0 and point the reader at rank 0 for the
    details. Returns an iterable of human-readable rank descriptions, or None
    when no extra information could be gathered.
    """
    store.set(f"{key_prefix}{rank}{_TRACE}", "<val_ignored>")

    def _find_missing_ranks():
        # Runs on rank 0 only: probe every other rank's TRACE key.
        missing_rank_info = set()
        ranks_missing = 0
        for i in range(1, world_size):
            # reduce noise, assuming in general 8 ranks per node
            # It is valuable to know that 1 or >1 nodes have timed-out.
            if ranks_missing >= _MAX_TRACE_MISSING_RANKS:
                break
            try:
                if ranks_missing == 0:
                    store.wait(
                        [f"{key_prefix}{i}{_TRACE}"], timedelta(seconds=trace_timeout)
                    )
                else:
                    # use a shortest timeout, some ranks have failed to check-in
                    store.wait([f"{key_prefix}{i}{_TRACE}"], timedelta(milliseconds=1))
            except DistStoreError:
                ranks_missing += 1
                missing_rank_info.add(rank_decoder(i))
        return missing_rank_info

    def _checkin():
        # Runs on non-zero ranks: wait for rank 0 to open the tracing gate.
        try:
            store.wait([f"{key_prefix}{_TRACING_GATE}"])
            return [f"[<check rank 0 ({rank_decoder(0)}) for missing rank info>]"]
        except DistStoreError:
            # in case rank0 is the source of the timeout, original exception will be raised
            return None

    if rank == 0:
        missing_rank_info = _find_missing_ranks()
        store.set(f"{key_prefix}{_TRACING_GATE}", "<val_ignored>")
        return missing_rank_info
    else:
        return _checkin()
def _barrier_nonblocking(store, world_size: int, key_prefix: str) -> str:
    """
    Does all the non-blocking operations for a barrier and returns the final key
    that can be waited on.

    The last arriving member (the ``world_size``-th check-in) sets the key.
    """
    counter_key = key_prefix + _NUM_MEMBERS
    ready_key = key_prefix + _LAST_MEMBER_CHECKIN

    arrival_index = store.add(counter_key, 1)
    if arrival_index == world_size:
        store.set(ready_key, "<val_ignored>")

    return ready_key
def barrier(
    store,
    world_size: int,
    key_prefix: str,
    barrier_timeout: float = 300,
    rank: Optional[int] = None,
    rank_tracing_decoder: Optional[Callable[[int], str]] = None,
    trace_timeout: float = 10,
) -> None:
    """
    A global lock between agents. This will pause all workers until at least
    ``world_size`` workers respond.

    This uses a fast incrementing index to assign waiting ranks and a success
    flag set by the last worker.

    Time complexity: O(1) per worker, O(N) globally.

    Optionally, passing rank will enable tracing of missing ranks on timeouts.
    `rank_tracing_decoder` lambda arg can be used to convert rank data
    into a more meaningful information at an app level (e.g. hostname).

    Note: Since the data is not removed from the store, the barrier can be used
    once per unique ``key_prefix``.
    """
    if rank is None:
        assert rank_tracing_decoder is None, "Tracing requires rank information"

    with store_timeout(store, barrier_timeout):
        last_member_key = _barrier_nonblocking(
            store=store, world_size=world_size, key_prefix=key_prefix
        )
        try:
            store.wait([last_member_key])
        except DistStoreError as e:
            if rank is None:
                # No rank info available, nothing to trace: propagate as-is.
                raise e
            else:
                # Try to enrich the timeout with the set of ranks that never arrived.
                missing_ranks = _try_detecting_missing_ranks(
                    store,
                    world_size,
                    key_prefix,
                    rank,
                    rank_tracing_decoder or (lambda x: str(x)),
                    trace_timeout,
                )
                if missing_ranks is not None:
                    raise DistStoreError(
                        "Timed out waiting on barrier on "
                        "rank {}, for key prefix: {} (world_size={}, missing_ranks={}, timeout={})".format(
                            rank,
                            key_prefix,
                            world_size,
                            f"[{', '.join(missing_ranks)}]",
                            barrier_timeout,
                        )
                    ) from None
                else:
                    raise e

+ 0
- 14
mindnlp/core/distributed/launcher/__init__.py View File

@@ -1,14 +0,0 @@
#!/usr/bin/env/python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from core.distributed.launcher.api import ( # noqa: F401
elastic_launch,
launch_agent,
LaunchConfig,
)

+ 0
- 289
mindnlp/core/distributed/launcher/api.py View File

@@ -1,289 +0,0 @@
#!/usr/bin/env python3
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import sys
import uuid
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from mindnlp import core.distributed.elastic.rendezvous.registry as rdzv_registry
from core.distributed.elastic import events, metrics
from core.distributed.elastic.agent.server.api import WorkerSpec
from core.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent
from core.distributed.elastic.multiprocessing import (
DefaultLogsSpecs,
LogsSpecs,
SignalException,
)
from core.distributed.elastic.multiprocessing.errors import ChildFailedError
from core.distributed.elastic.rendezvous import RendezvousParameters
from core.distributed.elastic.rendezvous.utils import parse_rendezvous_endpoint
from core.distributed.elastic.utils.logging import get_logger
__all__ = ["LaunchConfig", "elastic_launch", "launch_agent"]
logger = get_logger(__name__)
@dataclass
class LaunchConfig:
    """
    Creates a rendezvous config.

    Args:
        min_nodes: Minimum amount of nodes that the user function will
                        be launched on. Elastic agent ensures that the user
                        function start only when the min_nodes amount enters
                        the rendezvous.
        max_nodes: Maximum amount of nodes that the user function
                        will be launched on.
        nproc_per_node: On each node the elastic agent will launch
                            this amount of workers that will execute user
                            defined function.
        rdzv_backend: rdzv_backend to use in the rendezvous (zeus-adapter, etcd).
        rdzv_endpoint: The endpoint of the rdzv sync. storage.
        rdzv_configs: Key, value pair that specifies rendezvous specific configuration.
        rdzv_timeout: Legacy argument that specifies timeout for the rendezvous. It is going
            to be removed in future versions, see the note below. The default timeout is 900 seconds.
        run_id: The unique run id of the job (if not passed a unique one will be
                deduced from run environment - flow workflow id in flow - or auto generated).
        role: User defined role of the worker (defaults to "trainer").
        max_restarts: The maximum amount of restarts that elastic agent will conduct
                    on workers before failure.
        monitor_interval: The interval in seconds that is used by the elastic_agent
                        as a period of monitoring workers.
        start_method: The method is used by the elastic agent to start the
                    workers (spawn, fork, forkserver).
        metrics_cfg: configuration to initialize metrics.
        local_addr: address of the local node if any. If not set, a lookup on the local
                machine's FQDN will be performed.
        local_ranks_filter: ranks for which to show logs in console. If not set, show from all.

    ..note:
        `rdzv_timeout` is a legacy argument that will be removed in future.
        Set the timeout via `rdzv_configs['timeout']`
    """

    min_nodes: int
    max_nodes: int
    nproc_per_node: int
    logs_specs: Optional[LogsSpecs] = None
    run_id: str = ""
    role: str = "default_role"
    rdzv_endpoint: str = ""
    rdzv_backend: str = "etcd"
    rdzv_configs: Dict[str, Any] = field(default_factory=dict)
    rdzv_timeout: int = -1
    max_restarts: int = 3
    monitor_interval: float = 0.1
    start_method: str = "spawn"
    log_line_prefix_template: Optional[str] = None
    metrics_cfg: Dict[str, str] = field(default_factory=dict)
    local_addr: Optional[str] = None

    def __post_init__(self):
        # The legacy rdzv_timeout argument overrides rdzv_configs["timeout"];
        # otherwise fall back to a 900 second default.
        default_timeout = 900
        if self.rdzv_timeout != -1:
            self.rdzv_configs["timeout"] = self.rdzv_timeout
        elif "timeout" not in self.rdzv_configs:
            self.rdzv_configs["timeout"] = default_timeout

        # Post-processing to enable refactoring to introduce logs_specs due to non-torchrun API usage
        if self.logs_specs is None:
            self.logs_specs = DefaultLogsSpecs()
class elastic_launch:
    """
    Launches an torchelastic agent on the container that invoked the entrypoint.

    1. Pass the ``entrypoint`` arguments as non ``kwargs`` (e.g. no named parameters)/
       ``entrypoint`` can be a function or a command.
    2. The return value is a map of each worker's output mapped
       by their respective global rank.

    Usage

    ::

    def worker_fn(foo):
        # ...

    def main():
        # entrypoint is a function.
        outputs = elastic_launch(LaunchConfig, worker_fn)(foo)
        # return rank 0's output
        return outputs[0]

        # entrypoint is a command and ``script.py`` is the python module.
        outputs = elastic_launch(LaunchConfig, "script.py")(args)
        outputs = elastic_launch(LaunchConfig, "python")("script.py")
    """

    def __init__(
        self,
        config: LaunchConfig,
        entrypoint: Union[Callable, str, None],
    ):
        # Only stores the configuration and entrypoint; the agent is started
        # lazily when the instance is called.
        self._config = config
        self._entrypoint = entrypoint

    def __call__(self, *args):
        # Positional args are forwarded verbatim to every worker's entrypoint.
        return launch_agent(self._config, self._entrypoint, list(args))
def _get_entrypoint_name(
entrypoint: Union[Callable, str, None], args: List[Any]
) -> str:
"""Retrieve entrypoint name with the rule:
1. If entrypoint is a function, use ``entrypoint.__qualname__``.
2. If entrypoint is a string, check its value:
2.1 if entrypoint equals to ``sys.executable`` (like "python"), use the first element from ``args``
which does not start with hifen letter (for example, "-u" will be skipped).
2.2 otherwise, use ``entrypoint`` value.
3. Otherwise, return empty string.
"""
if isinstance(entrypoint, Callable): # type: ignore[arg-type]
return entrypoint.__name__ # type: ignore[union-attr]
elif isinstance(entrypoint, str):
if entrypoint == sys.executable:
return next((arg for arg in args if arg[0] != "-"), "")
else:
return entrypoint
else:
return ""
def _get_addr_and_port(
    rdzv_parameters: RendezvousParameters,
) -> Tuple[Optional[str], Optional[int]]:
    """Return (master_addr, master_port) for a static rendezvous; (None, None) otherwise."""
    if rdzv_parameters.backend != "static":
        return (None, None)

    endpoint = rdzv_parameters.endpoint.strip()
    if not endpoint:
        raise ValueError(
            "Endpoint is missing in endpoint. Try to add --master-addr and --master-port"
        )

    master_addr, master_port = parse_rendezvous_endpoint(endpoint, default_port=-1)
    if master_port == -1:
        raise ValueError(
            f"port is missing in endpoint: {endpoint}. Try to specify --master-port"
        )
    return (master_addr, master_port)
def launch_agent(
    config: LaunchConfig,
    entrypoint: Union[Callable, str, None],
    args: List[Any],
) -> Dict[int, Any]:
    """Run a local elastic agent for ``entrypoint`` with ``args`` per ``config``.

    Returns a map of each worker's return value keyed by global rank; raises
    ``ChildFailedError`` when any worker fails.
    """
    if not config.run_id:
        # A run id is required for rendezvous; fabricate one if the caller didn't.
        run_id = str(uuid.uuid4().int)
        logger.warning("config has no run_id, generated a random run_id: %s", run_id)
        config.run_id = run_id

    entrypoint_name = _get_entrypoint_name(entrypoint, args)

    logger.info(
        "Starting elastic_operator with launch configs:\n"
        " entrypoint : %(entrypoint)s\n"
        " min_nodes : %(min_nodes)s\n"
        " max_nodes : %(max_nodes)s\n"
        " nproc_per_node : %(nproc_per_node)s\n"
        " run_id : %(run_id)s\n"
        " rdzv_backend : %(rdzv_backend)s\n"
        " rdzv_endpoint : %(rdzv_endpoint)s\n"
        " rdzv_configs : %(rdzv_configs)s\n"
        " max_restarts : %(max_restarts)s\n"
        " monitor_interval : %(monitor_interval)s\n"
        " log_dir : %(log_dir)s\n"
        " metrics_cfg : %(metrics_cfg)s\n",
        {
            "entrypoint": entrypoint_name,
            "min_nodes": config.min_nodes,
            "max_nodes": config.max_nodes,
            "nproc_per_node": config.nproc_per_node,
            "run_id": config.run_id,
            "rdzv_backend": config.rdzv_backend,
            "rdzv_endpoint": config.rdzv_endpoint,
            "rdzv_configs": config.rdzv_configs,
            "max_restarts": config.max_restarts,
            "monitor_interval": config.monitor_interval,
            "log_dir": config.logs_specs.root_log_dir,  # type: ignore[union-attr]
            "metrics_cfg": config.metrics_cfg,
        },
    )

    rdzv_parameters = RendezvousParameters(
        backend=config.rdzv_backend,
        endpoint=config.rdzv_endpoint,
        run_id=config.run_id,
        min_nodes=config.min_nodes,
        max_nodes=config.max_nodes,
        local_addr=config.local_addr,
        **config.rdzv_configs,
    )

    # master addr/port are only resolved for the "static" rendezvous backend.
    master_addr, master_port = _get_addr_and_port(rdzv_parameters)

    spec = WorkerSpec(
        role=config.role,
        local_world_size=config.nproc_per_node,
        entrypoint=entrypoint,
        args=tuple(args),
        rdzv_handler=rdzv_registry.get_rendezvous_handler(rdzv_parameters),
        max_restarts=config.max_restarts,
        monitor_interval=config.monitor_interval,
        master_addr=master_addr,
        master_port=master_port,
        local_addr=config.local_addr,
    )

    agent = LocalElasticAgent(
        spec=spec,
        logs_specs=config.logs_specs,  # type: ignore[arg-type]
        start_method=config.start_method,
        log_line_prefix_template=config.log_line_prefix_template,
    )

    shutdown_rdzv = True
    try:
        metrics.initialize_metrics(metrics.MetricsConfig(config.metrics_cfg))

        result = agent.run()
        # records that agent.run() has succeeded NOT that workers have succeeded
        events.record(agent.get_event_succeeded())

        if result.is_failed():
            # ChildFailedError is treated specially by @record
            # if the error files for the failed children exist
            # @record will copy the first error (root cause)
            # to the error file of the launcher process.
            raise ChildFailedError(
                name=entrypoint_name,
                failures=result.failures,
            )

        return result.return_values
    except ChildFailedError:
        raise
    except SignalException:
        # when the agent dies with a signal do NOT shutdown the rdzv_handler
        # since this closes the rendezvous on this rdzv_id permanently and
        # prevents any additional scaling events
        shutdown_rdzv = False
        events.record(agent.get_event_failed())
        raise
    except Exception:
        events.record(agent.get_event_failed())
        raise
    finally:
        if shutdown_rdzv:
            spec.rdzv_handler.shutdown()

+ 0
- 7
mindnlp/core/distributed/pipelining/README.md View File

@@ -1,7 +0,0 @@
# Pipeline Parallelism for PyTorch
`core.distributed.pipelining` is a package for implementing pipeline parallelism on your model.
Our documentation is available [here](https://pytorch.org/docs/main/distributed.pipelining.html).
![pipeline_diagram_web](https://github.com/pytorch/PiPPy/assets/6676466/c93e2fe7-1cd4-49a2-9fd8-231ec9905e0c)

+ 0
- 20
mindnlp/core/distributed/rpc/_testing/__init__.py View File

@@ -1,20 +0,0 @@
# mypy: allow-untyped-defs
from mindnlp import core
def is_available():
    # True when the C extension exposing the faulty RPC test agent was built in.
    return hasattr(core._C, "_faulty_agent_init")


# Initialize the faulty-agent test extension eagerly at import time; failure
# here means the extension is present but could not be set up.
if is_available() and not core._C._faulty_agent_init():
    raise RuntimeError("Failed to initialize core.distributed.rpc._testing")

if is_available():
    # Registers FAULTY_TENSORPIPE RPC backend.
    from core._C._distributed_rpc_testing import (
        FaultyTensorPipeAgent,
        FaultyTensorPipeRpcBackendOptions,
    )

    from . import faulty_agent_backend_registry

+ 0
- 922
mindnlp/core/distributed/run.py View File

@@ -1,922 +0,0 @@
#!/usr/bin/env python3
# mypy: allow-untyped-defs
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Superset of ``core.distributed.launch``.
``torchrun`` provides a superset of the functionality as ``core.distributed.launch``
with the following additional functionalities:
1. Worker failures are handled gracefully by restarting all workers.
2. Worker ``RANK`` and ``WORLD_SIZE`` are assigned automatically.
3. Number of nodes is allowed to change between minimum and maximum sizes (elasticity).
.. note:: ``torchrun`` is a python
`console script <https://packaging.python.org/en/latest/specifications/entry-points/#use-for-scripts>`_
to the main module
`core.distributed.run <https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py>`_
declared in the ``entry_points`` configuration in
`setup.py <https://github.com/pytorch/pytorch/blob/master/setup.py>`_.
It is equivalent to invoking ``python -m core.distributed.run``.
Transitioning from core.distributed.launch to torchrun
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
``torchrun`` supports the same arguments as ``core.distributed.launch`` **except**
for ``--use-env`` which is now deprecated. To migrate from ``core.distributed.launch``
to ``torchrun`` follow these steps:
1. If your training script is already reading ``local_rank`` from the ``LOCAL_RANK`` environment variable.
Then you need simply omit the ``--use-env`` flag, e.g.:
+--------------------------------------------------------------------+--------------------------------------------+
| ``core.distributed.launch`` | ``torchrun`` |
+====================================================================+============================================+
| | |
| .. code-block:: shell-session | .. code-block:: shell-session |
| | |
| $ python -m core.distributed.launch --use-env train_script.py | $ torchrun train_script.py |
| | |
+--------------------------------------------------------------------+--------------------------------------------+
2. If your training script reads local rank from a ``--local-rank`` cmd argument.
Change your training script to read from the ``LOCAL_RANK`` environment variable as
demonstrated by the following code snippet:
+-------------------------------------------------------+----------------------------------------------------+
| ``core.distributed.launch`` | ``torchrun`` |
+=======================================================+====================================================+
| | |
| .. code-block:: python | .. code-block:: python |
| | |
| | |
| import argparse | import os |
| parser = argparse.ArgumentParser() | local_rank = int(os.environ["LOCAL_RANK"]) |
| parser.add_argument("--local-rank", type=int) | |
| args = parser.parse_args() | |
| | |
| local_rank = args.local_rank | |
| | |
+-------------------------------------------------------+----------------------------------------------------+
.. versionchanged:: 2.0.0
The launcher will pass the ``--local-rank=<rank>`` argument to your script.
From PyTorch 2.0.0 onwards, the dashed ``--local-rank`` is preferred over the
previously used underscored ``--local_rank``.
For backward compatibility, it may be necessary for users to handle both
cases in their argument parsing code. This means including both ``"--local-rank"``
and ``"--local_rank"`` in the argument parser. If only ``"--local_rank"`` is
provided, the launcher will trigger an error: "error: unrecognized arguments:
--local-rank=<rank>". For training code that only supports PyTorch 2.0.0+,
including ``"--local-rank"`` should be sufficient.
::
>>> # xdoctest: +SKIP
>>> import argparse
>>> parser = argparse.ArgumentParser()
>>> parser.add_argument("--local-rank", "--local_rank", type=int)
>>> args = parser.parse_args()
The aforementioned changes suffice to migrate from ``core.distributed.launch`` to ``torchrun``.
To take advantage of new features such as elasticity, fault-tolerance, and error reporting of ``torchrun``
please refer to:
* :ref:`elastic_train_script` for more information on authoring training scripts that are ``torchrun`` compliant.
* the rest of this page for more information on the features of ``torchrun``.
Usage
--------
Single-node multi-worker
++++++++++++++++++++++++++++++
::
torchrun
--standalone
--nnodes=1
--nproc-per-node=$NUM_TRAINERS
YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)
Stacked single-node multi-worker
+++++++++++++++++++++++++++++++++++
To run multiple instances (separate jobs) of single-node, multi-worker on the
same host, we need to make sure that each instance (job) is
setup on different ports to avoid port conflicts (or worse, two jobs being merged
as a single job). To do this you have to run with ``--rdzv-backend=c10d``
and specify a different port by setting ``--rdzv-endpoint=localhost:$PORT_k``.
For ``--nnodes=1``, it's often convenient to let ``torchrun`` pick a free random
port automatically instead of manually assigning different ports for each run.
::
torchrun
--rdzv-backend=c10d
--rdzv-endpoint=localhost:0
--nnodes=1
--nproc-per-node=$NUM_TRAINERS
YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)
Fault tolerant (fixed sized number of workers, no elasticity, tolerates 3 failures)
++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
::
torchrun
--nnodes=$NUM_NODES
--nproc-per-node=$NUM_TRAINERS
--max-restarts=3
--rdzv-id=$JOB_ID
--rdzv-backend=c10d
--rdzv-endpoint=$HOST_NODE_ADDR
YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)
``HOST_NODE_ADDR``, in form <host>[:<port>] (e.g. node1.example.com:29400), specifies the node and
the port on which the C10d rendezvous backend should be instantiated and hosted. It can be any
node in your training cluster, but ideally you should pick a node that has a high bandwidth.
.. note::
If no port number is specified ``HOST_NODE_ADDR`` defaults to 29400.
Elastic (``min=1``, ``max=4``, tolerates up to 3 membership changes or failures)
+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
::
torchrun
--nnodes=1:4
--nproc-per-node=$NUM_TRAINERS
--max-restarts=3
--rdzv-id=$JOB_ID
--rdzv-backend=c10d
--rdzv-endpoint=$HOST_NODE_ADDR
YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)
``HOST_NODE_ADDR``, in form <host>[:<port>] (e.g. node1.example.com:29400), specifies the node and
the port on which the C10d rendezvous backend should be instantiated and hosted. It can be any
node in your training cluster, but ideally you should pick a node that has a high bandwidth.
.. note::
If no port number is specified ``HOST_NODE_ADDR`` defaults to 29400.
Note on rendezvous backend
------------------------------
For multi-node training you need to specify:
1. ``--rdzv-id``: A unique job id (shared by all nodes participating in the job)
2. ``--rdzv-backend``: An implementation of
:py:class:`core.distributed.elastic.rendezvous.RendezvousHandler`
3. ``--rdzv-endpoint``: The endpoint where the rendezvous backend is running; usually in form
``host:port``.
Currently ``c10d`` (recommended), ``etcd-v2``, and ``etcd`` (legacy) rendezvous backends are
supported out of the box. To use ``etcd-v2`` or ``etcd``, setup an etcd server with the ``v2`` api
enabled (e.g. ``--enable-v2``).
.. warning::
``etcd-v2`` and ``etcd`` rendezvous use etcd API v2. You MUST enable the v2 API on the etcd
server. Our tests use etcd v3.4.3.
.. warning::
For etcd-based rendezvous we recommend using ``etcd-v2`` over ``etcd`` which is functionally
equivalent, but uses a revised implementation. ``etcd`` is in maintenance mode and will be
removed in a future version.
Definitions
--------------
1. ``Node`` - A physical instance or a container; maps to the unit that the job manager works with.
2. ``Worker`` - A worker in the context of distributed training.
3. ``WorkerGroup`` - The set of workers that execute the same function (e.g. trainers).
4. ``LocalWorkerGroup`` - A subset of the workers in the worker group running on the same node.
5. ``RANK`` - The rank of the worker within a worker group.
6. ``WORLD_SIZE`` - The total number of workers in a worker group.
7. ``LOCAL_RANK`` - The rank of the worker within a local worker group.
8. ``LOCAL_WORLD_SIZE`` - The size of the local worker group.
9. ``rdzv_id`` - A user-defined id that uniquely identifies the worker group for a job. This id is
used by each node to join as a member of a particular worker group.
10. ``rdzv_backend`` - The backend of the rendezvous (e.g. ``c10d``). This is typically a strongly
consistent key-value store.
11. ``rdzv_endpoint`` - The rendezvous backend endpoint; usually in form ``<host>:<port>``.
A ``Node`` runs ``LOCAL_WORLD_SIZE`` workers which comprise a ``LocalWorkerGroup``. The union of
all ``LocalWorkerGroups`` in the nodes in the job comprise the ``WorkerGroup``.
Environment Variables
----------------------
The following environment variables are made available to you in your script:
1. ``LOCAL_RANK`` - The local rank.
2. ``RANK`` - The global rank.
3. ``GROUP_RANK`` - The rank of the worker group. A number between 0 and ``max_nnodes``. When
running a single worker group per node, this is the rank of the node.
4. ``ROLE_RANK`` - The rank of the worker across all the workers that have the same role. The role
of the worker is specified in the ``WorkerSpec``.
5. ``LOCAL_WORLD_SIZE`` - The local world size (e.g. number of workers running locally); equals to
``--nproc-per-node`` specified on ``torchrun``.
6. ``WORLD_SIZE`` - The world size (total number of workers in the job).
7. ``ROLE_WORLD_SIZE`` - The total number of workers that was launched with the same role specified
in ``WorkerSpec``.
8. ``MASTER_ADDR`` - The FQDN of the host that is running worker with rank 0; used to initialize
the Torch Distributed backend.
9. ``MASTER_PORT`` - The port on the ``MASTER_ADDR`` that can be used to host the C10d TCP store.
10. ``TORCHELASTIC_RESTART_COUNT`` - The number of worker group restarts so far.
11. ``TORCHELASTIC_MAX_RESTARTS`` - The configured maximum number of restarts.
12. ``TORCHELASTIC_RUN_ID`` - Equal to the rendezvous ``run_id`` (e.g. unique job id).
13. ``PYTHON_EXEC`` - System executable override. If provided, the python user script will
use the value of ``PYTHON_EXEC`` as executable. The `sys.executable` is used by default.
Deployment
------------
1. (Not needed for the C10d backend) Start the rendezvous backend server and get the endpoint (to be
passed as ``--rdzv-endpoint`` to the launcher script)
2. Single-node multi-worker: Start the launcher on the host to start the agent process which
creates and monitors a local worker group.
3. Multi-node multi-worker: Start the launcher with the same arguments on all the nodes
participating in training.
When using a job/cluster manager the entry point command to the multi-node job should be this
launcher.
Failure Modes
---------------
1. Worker failure: For a training job with ``n`` workers, if ``k<=n`` workers fail all workers
are stopped and restarted up to ``max_restarts``.
2. Agent failure: An agent failure results in a local worker group failure. It is up to the job
manager to fail the entire job (gang semantics) or attempt to replace the node. Both behaviors
are supported by the agent.
3. Node failure: Same as agent failure.
Membership Changes
--------------------
1. Node departure (scale-down): The agent is notified of the departure, all existing workers are
stopped, a new ``WorkerGroup`` is formed, and all workers are started with a new ``RANK`` and
``WORLD_SIZE``.
2. Node arrival (scale-up): The new node is admitted to the job, all existing workers are stopped,
a new ``WorkerGroup`` is formed, and all workers are started with a new ``RANK`` and
``WORLD_SIZE``.
Important Notices
--------------------
1. This utility and multi-process distributed (single-node or
multi-node) GPU training currently only achieves the best performance using
the NCCL distributed backend. Thus NCCL backend is the recommended backend to
use for GPU training.
2. The environment variables necessary to initialize a Torch process group are provided to you by
this module, no need for you to pass ``RANK`` manually. To initialize a process group in your
training script, simply run:
::
>>> # xdoctest: +SKIP("stub")
>>> from mindnlp.core import distributed as dist
>>> dist.init_process_group(backend="gloo|nccl")
3. In your training program, you can either use regular distributed functions
or use :func:`core.nn.parallel.DistributedDataParallel` module. If your
training program uses GPUs for training and you would like to use
:func:`core.nn.parallel.DistributedDataParallel` module,
here is how to configure it.
::
local_rank = int(os.environ["LOCAL_RANK"])
model = core.nn.parallel.DistributedDataParallel(model,
device_ids=[local_rank],
output_device=local_rank)
Please ensure that ``device_ids`` argument is set to be the only GPU device id
that your code will be operating on. This is generally the local rank of the
process. In other words, the ``device_ids`` needs to be ``[int(os.environ("LOCAL_RANK"))]``,
and ``output_device`` needs to be ``int(os.environ("LOCAL_RANK"))`` in order to use this
utility
4. On failures or membership changes ALL surviving workers are killed immediately. Make sure to
checkpoint your progress. The frequency of checkpoints should depend on your job's tolerance
for lost work.
5. This module only supports homogeneous ``LOCAL_WORLD_SIZE``. That is, it is assumed that all
nodes run the same number of local workers (per role).
6. ``RANK`` is NOT stable. Between restarts, the local workers on a node can be assigned a
different range of ranks than before. NEVER hard code any assumptions about the stable-ness of
ranks or some correlation between ``RANK`` and ``LOCAL_RANK``.
7. When using elasticity (``min_size!=max_size``) DO NOT hard code assumptions about
``WORLD_SIZE`` as the world size can change as nodes are allowed to leave and join.
8. It is recommended for your script to have the following structure:
::
def main():
load_checkpoint(checkpoint_path)
initialize()
train()
def train():
for batch in iter(dataset):
train_step(batch)
if should_checkpoint:
save_checkpoint(checkpoint_path)
9. (Recommended) On worker errors, this tool will summarize the details of the error
(e.g. time, rank, host, pid, traceback, etc). On each node, the first error (by timestamp)
is heuristically reported as the "Root Cause" error. To get tracebacks as part of this
error summary print out, you must decorate your main entrypoint function in your
training script as shown in the example below. If not decorated, then the summary
will not include the traceback of the exception and will only contain the exitcode.
For details on torchelastic error handling see: https://pycore.org/docs/stable/elastic/errors.html
::
from core.distributed.elastic.multiprocessing.errors import record
@record
def main():
# do train
pass
if __name__ == "__main__":
main()
"""
import os
import sys
import uuid
from argparse import ArgumentParser, REMAINDER
from importlib import metadata
from typing import Callable, List, Optional, Set, Tuple, Type, Union
from mindnlp import core
from core.distributed.argparse_util import check_env, env
from core.distributed.elastic.multiprocessing import DefaultLogsSpecs, LogsSpecs, Std
from core.distributed.elastic.multiprocessing.errors import record
from core.distributed.elastic.rendezvous.utils import _parse_rendezvous_config
from core.distributed.elastic.utils import macros
from core.distributed.elastic.utils.logging import get_logger
from core.distributed.launcher.api import elastic_launch, LaunchConfig
from core.utils.backend_registration import _get_custom_mod_func
logger = get_logger(__name__)
def get_args_parser() -> ArgumentParser:
    """Build the ``ArgumentParser`` for the torchrun command line.

    Returns:
        ArgumentParser: parser covering worker/node sizing, rendezvous
        configuration, launch behavior, logging, backward-compatible
        ``caffe2.distributed.launch`` options, and the positional training
        script plus its pass-through arguments.
    """
    parser = ArgumentParser(description="Torch Distributed Elastic Training Launcher")

    #
    # Worker/node size related arguments.
    #

    parser.add_argument(
        "--nnodes",
        # NOTE(review): `env` action presumably lets the value be supplied via an
        # environment variable as well as the CLI — confirm in argparse_util.
        action=env,
        type=str,
        default="1:1",
        help="Number of nodes, or the range of nodes in form <minimum_nodes>:<maximum_nodes>.",
    )
    parser.add_argument(
        "--nproc-per-node",
        "--nproc_per_node",
        action=env,
        type=str,
        default="1",
        help="Number of workers per node; supported values: [auto, cpu, gpu, int].",
    )

    #
    # Rendezvous related arguments
    #

    parser.add_argument(
        "--rdzv-backend",
        "--rdzv_backend",
        action=env,
        type=str,
        default="static",
        help="Rendezvous backend.",
    )
    parser.add_argument(
        "--rdzv-endpoint",
        "--rdzv_endpoint",
        action=env,
        type=str,
        default="",
        help="Rendezvous backend endpoint; usually in form <host>:<port>.",
    )
    parser.add_argument(
        "--rdzv-id",
        "--rdzv_id",
        action=env,
        type=str,
        default="none",
        help="User-defined group id.",
    )
    parser.add_argument(
        "--rdzv-conf",
        "--rdzv_conf",
        action=env,
        type=str,
        default="",
        help="Additional rendezvous configuration (<key1>=<value1>,<key2>=<value2>,...).",
    )
    parser.add_argument(
        "--standalone",
        # `check_env` is a boolean-style action (no value argument).
        action=check_env,
        help="Start a local standalone rendezvous backend that is represented by a C10d TCP store "
        "on a free port. Useful when launching single-node, multi-worker job. If specified "
        "--rdzv-backend, --rdzv-endpoint, --rdzv-id are auto-assigned and any explicitly set values "
        "are ignored.",
    )

    #
    # User-code launch related arguments.
    #

    parser.add_argument(
        "--max-restarts",
        "--max_restarts",
        action=env,
        type=int,
        default=0,
        help="Maximum number of worker group restarts before failing.",
    )
    parser.add_argument(
        "--monitor-interval",
        "--monitor_interval",
        action=env,
        type=float,
        default=0.1,
        help="Interval, in seconds, to monitor the state of workers.",
    )
    parser.add_argument(
        "--start-method",
        "--start_method",
        action=env,
        type=str,
        default="spawn",
        choices=["spawn", "fork", "forkserver"],
        help="Multiprocessing start method to use when creating workers.",
    )
    parser.add_argument(
        "--role",
        action=env,
        type=str,
        default="default",
        help="User-defined role for the workers.",
    )
    parser.add_argument(
        "-m",
        "--module",
        action=check_env,
        help="Change each process to interpret the launch script as a Python module, executing "
        "with the same behavior as 'python -m'.",
    )
    parser.add_argument(
        "--no-python",
        "--no_python",
        action=check_env,
        help="Skip prepending the training script with 'python' - just execute it directly. Useful "
        "when the script is not a Python script.",
    )
    parser.add_argument(
        "--run-path",
        "--run_path",
        action=check_env,
        help="Run the training script with runpy.run_path in the same interpreter."
        " Script must be provided as an abs path (e.g. /abs/path/script.py)."
        " Takes precedence over --no-python.",
    )
    parser.add_argument(
        "--log-dir",
        "--log_dir",
        action=env,
        type=str,
        default=None,
        help="Base directory to use for log files (e.g. /var/log/torch/elastic). The same "
        "directory is re-used for multiple runs (a unique job-level sub-directory is created with "
        "rdzv_id as the prefix).",
    )
    parser.add_argument(
        "-r",
        "--redirects",
        action=env,
        type=str,
        default="0",
        help="Redirect std streams into a log file in the log directory (e.g. [-r 3] redirects "
        "both stdout+stderr for all workers, [-r 0:1,1:2] redirects stdout for local rank 0 and "
        "stderr for local rank 1).",
    )
    parser.add_argument(
        "-t",
        "--tee",
        action=env,
        type=str,
        default="0",
        help="Tee std streams into a log file and also to console (see --redirects for format).",
    )
    parser.add_argument(
        "--local-ranks-filter",
        "--local_ranks_filter",
        action=env,
        type=str,
        default="",
        help="Only show logs from specified ranks in console (e.g. [--local_ranks_filter=0,1,2] will "
        "only show logs from rank 0, 1 and 2). This will only apply to stdout and stderr, not to"
        "log files saved via --redirect or --tee",
    )

    #
    # Backwards compatible parameters with caffe2.distributed.launch.
    #

    parser.add_argument(
        "--node-rank",
        "--node_rank",
        type=int,
        action=env,
        default=0,
        help="Rank of the node for multi-node distributed training.",
    )
    parser.add_argument(
        "--master-addr",
        "--master_addr",
        default="127.0.0.1",
        type=str,
        action=env,
        help="Address of the master node (rank 0) that only used for static rendezvous. It should "
        "be either the IP address or the hostname of rank 0. For single node multi-proc training "
        "the --master-addr can simply be 127.0.0.1; IPv6 should have the pattern "
        "`[0:0:0:0:0:0:0:1]`.",
    )
    parser.add_argument(
        "--master-port",
        "--master_port",
        default=29500,
        type=int,
        action=env,
        help="Port on the master node (rank 0) to be used for communication during distributed "
        "training. It is only used for static rendezvous.",
    )
    parser.add_argument(
        "--local-addr",
        "--local_addr",
        default=None,
        type=str,
        action=env,
        help="Address of the local node. If specified, will use the given address for connection. "
        "Else, will look up the local node address instead. Else, it will be default to local "
        "machine's FQDN.",
    )
    parser.add_argument(
        "--logs-specs",
        "--logs_specs",
        default=None,
        type=str,
        help="torchrun.logs_specs group entrypoint name, value must be type of LogsSpecs. "
        "Can be used to override custom logging behavior.",
    )

    #
    # Positional arguments.
    #

    parser.add_argument(
        "training_script",
        type=str,
        help="Full path to the (single GPU) training program/script to be launched in parallel, "
        "followed by all the arguments for the training script.",
    )

    # Rest from the training program.
    parser.add_argument("training_script_args", nargs=REMAINDER)

    return parser
def parse_args(args):
    """Parse *args* (or ``sys.argv[1:]`` when None) with the torchrun parser."""
    return get_args_parser().parse_args(args)
def parse_min_max_nnodes(nnodes: str):
    """Split an ``--nnodes`` value into ``(min_nodes, max_nodes)``.

    Accepts either a single integer string (fixed size) or a
    ``"MIN:MAX"`` range; raises ``RuntimeError`` for any other shape.
    """
    parts = nnodes.split(":")
    if len(parts) not in (1, 2):
        raise RuntimeError(f'nnodes={nnodes} is not in "MIN:MAX" format')
    # A single value means a fixed-size job: min == max.
    min_nodes = int(parts[0])
    max_nodes = int(parts[-1])
    return min_nodes, max_nodes
def determine_local_world_size(nproc_per_node: str):
    """Resolve the ``--nproc-per-node`` value into a concrete worker count.

    Accepts an integer string, or one of the special values ``"cpu"``,
    ``"gpu"``, ``"auto"``, or the privateuse1 backend name, in which case
    the count is taken from the corresponding device count of this machine.

    Raises:
        ValueError: if the value is neither an int nor a recognized device
            specifier, or if the requested device type is unavailable.
    """
    try:
        logger.info("Using nproc_per_node=%s.", nproc_per_node)
        return int(nproc_per_node)
    except ValueError as e:
        if nproc_per_node == "cpu":
            num_proc = os.cpu_count()
            device_type = "cpu"
        elif nproc_per_node == "gpu":
            if not core.cuda.is_available():
                raise ValueError("Cuda is not available.") from e
            device_type = "gpu"
            num_proc = core.cuda.device_count()
        elif nproc_per_node == core._C._get_privateuse1_backend_name():
            if not _get_custom_mod_func("is_available")():
                raise ValueError(f"{nproc_per_node} is not available.") from e
            device_type = nproc_per_node
            num_proc = _get_custom_mod_func("device_count")()
        elif nproc_per_node == "auto":
            if core.cuda.is_available():
                num_proc = core.cuda.device_count()
                device_type = "gpu"
            elif (
                # BUG FIX: the original checked `hasattr(torch, ...)`, but this
                # module imports the backend as `core` and never imports
                # `torch`, so that branch raised NameError instead of falling
                # back to the privateuse1 backend.
                hasattr(core, core._C._get_privateuse1_backend_name())
                and _get_custom_mod_func("is_available")()
            ):
                num_proc = _get_custom_mod_func("device_count")()
                device_type = core._C._get_privateuse1_backend_name()
            else:
                num_proc = os.cpu_count()
                device_type = "cpu"
        else:
            raise ValueError(
                f"Unsupported nproc_per_node value: {nproc_per_node}"
            ) from e

        logger.info(
            "Using nproc_per_node=%s, setting nproc_per_node to %s since the instance has %s %s",
            nproc_per_node,
            num_proc,
            num_proc,
            device_type,
        )
        return num_proc
def get_rdzv_endpoint(args):
    """Return the rendezvous endpoint, synthesizing ``host:port`` from the
    master address/port when using the static backend without an explicit
    ``--rdzv-endpoint``."""
    static_without_endpoint = (
        args.rdzv_backend == "static" and not args.rdzv_endpoint
    )
    if static_without_endpoint:
        return f"{args.master_addr}:{args.master_port}"
    return args.rdzv_endpoint
def get_use_env(args) -> bool:
    """
    Retrieve the legacy ``use_env`` flag from the args.

    ``use_env`` is only set by ``core.distributed.launch`` and will be
    deprecated in future releases; when it is False the ``--node-rank``
    argument is transferred to all worker processes. Args objects without
    the attribute default to True.
    """
    return getattr(args, "use_env", True)
def _get_logs_specs_class(logs_specs_name: Optional[str]) -> Type[LogsSpecs]:
    """
    Attempts to load the `torchrun.logs_specs` entrypoint with key of `logs_specs_name` param.

    Provides a plugin mechanism to supply a custom implementation of LogsSpecs.

    Returns `DefaultLogsSpecs` when logs_specs_name is None.
    Raises ValueError when an entrypoint for `logs_specs_name` can't be found.
    """
    logs_specs_cls = None
    if logs_specs_name is not None:
        eps = metadata.entry_points()
        if hasattr(eps, "select"):  # >= 3.10: EntryPoints has the select() API
            group = eps.select(group="torchrun.logs_specs")
            if group.select(name=logs_specs_name):
                logs_specs_cls = group[logs_specs_name].load()
        elif specs := eps.get("torchrun.logs_specs"):  # < 3.10: dict-style access
            # Linear scan for a matching entrypoint name.
            if entrypoint_list := [ep for ep in specs if ep.name == logs_specs_name]:
                logs_specs_cls = entrypoint_list[0].load()

        if logs_specs_cls is None:
            raise ValueError(
                f"Could not find entrypoint under 'torchrun.logs_specs[{logs_specs_name}]' key"
            )

        logger.info(
            "Using logs_spec '%s' mapped to %s", logs_specs_name, str(logs_specs_cls)
        )
    else:
        # No plugin requested: fall back to the built-in spec.
        logs_specs_cls = DefaultLogsSpecs

    return logs_specs_cls
def config_from_args(args) -> Tuple[LaunchConfig, Union[Callable, str], List[str]]:
    """Translate parsed CLI args into ``(LaunchConfig, entrypoint, entrypoint_args)``.

    The entrypoint is either a callable (``run_script_path`` for --run-path),
    the python executable (default), or the raw training script (--no-python).
    """
    # If ``args`` not passed, defaults to ``sys.argv[:1]``
    min_nodes, max_nodes = parse_min_max_nnodes(args.nnodes)
    assert 0 < min_nodes <= max_nodes
    assert args.max_restarts >= 0

    if (
        hasattr(args, "master_addr")
        and args.rdzv_backend != "static"
        and not args.rdzv_endpoint
    ):
        logger.warning(
            "master_addr is only used for static rdzv_backend and when rdzv_endpoint "
            "is not specified."
        )

    nproc_per_node = determine_local_world_size(args.nproc_per_node)
    if "OMP_NUM_THREADS" not in os.environ and nproc_per_node > 1:
        # Default to one OMP thread per worker to avoid oversubscription.
        omp_num_threads = 1
        logger.warning(
            "\n*****************************************\n"
            "Setting OMP_NUM_THREADS environment variable for each process to be "
            "%s in default, to avoid your system being overloaded, "
            "please further tune the variable for optimal performance in "
            "your application as needed. \n"
            "*****************************************",
            omp_num_threads,
        )
        # This env variable will be passed down to the subprocesses
        os.environ["OMP_NUM_THREADS"] = str(omp_num_threads)

    log_line_prefix_template = os.getenv("TORCHELASTIC_LOG_LINE_PREFIX_TEMPLATE")

    rdzv_configs = _parse_rendezvous_config(args.rdzv_conf)

    if args.rdzv_backend == "static":
        # Static rendezvous needs an explicit per-node rank.
        rdzv_configs["rank"] = args.node_rank

    rdzv_endpoint = get_rdzv_endpoint(args)

    # Optional console-log filtering by local rank.
    ranks: Optional[Set[int]] = None
    if args.local_ranks_filter:
        try:
            ranks = set(map(int, args.local_ranks_filter.split(",")))
            assert ranks
        except Exception as e:
            raise ValueError(
                "--local_ranks_filter must be a comma-separated list of integers e.g. --local_ranks_filter=0,1,2"
            ) from e

    logs_specs_cls: Type[LogsSpecs] = _get_logs_specs_class(args.logs_specs)
    logs_specs = logs_specs_cls(
        log_dir=args.log_dir,
        redirects=Std.from_str(args.redirects),
        tee=Std.from_str(args.tee),
        local_ranks_filter=ranks,
    )

    config = LaunchConfig(
        min_nodes=min_nodes,
        max_nodes=max_nodes,
        nproc_per_node=nproc_per_node,
        run_id=args.rdzv_id,
        role=args.role,
        rdzv_endpoint=rdzv_endpoint,
        rdzv_backend=args.rdzv_backend,
        rdzv_configs=rdzv_configs,
        max_restarts=args.max_restarts,
        monitor_interval=args.monitor_interval,
        start_method=args.start_method,
        log_line_prefix_template=log_line_prefix_template,
        local_addr=args.local_addr,
        logs_specs=logs_specs,
    )

    with_python = not args.no_python
    cmd: Union[Callable, str]
    cmd_args = []
    use_env = get_use_env(args)
    if args.run_path:
        # --run-path: execute the script in-process via runpy.
        cmd = run_script_path
        cmd_args.append(args.training_script)
    else:
        if with_python:
            # Default: launch `python -u [-m] <script>`.
            cmd = os.getenv("PYTHON_EXEC", sys.executable)
            cmd_args.append("-u")
            if args.module:
                cmd_args.append("-m")
            cmd_args.append(args.training_script)
        else:
            if args.module:
                raise ValueError(
                    "Don't use both the '--no-python' flag"
                    " and the '--module' flag at the same time."
                )
            cmd = args.training_script
        if not use_env:
            # Legacy launch path: pass the local rank on the command line.
            cmd_args.append(f"--local-rank={macros.local_rank}")
    cmd_args.extend(args.training_script_args)
    return config, cmd, cmd_args
def run_script_path(training_script: str, *training_script_args: str):
    """
    Run the provided `training_script` from within this interpreter.

    Usage: `script_as_function("/abs/path/to/script.py", "--arg1", "val1")`
    """
    import runpy
    import sys

    # Make the script and its arguments visible as if invoked directly.
    sys.argv = [training_script, *training_script_args]
    runpy.run_path(training_script, run_name="__main__")
def run(args):
    """Launch the elastic training job described by the parsed *args*."""
    core.multiprocessing._set_thread_name("pt_elastic")

    if args.standalone:
        # Standalone mode: auto-provision a single-node c10d rendezvous on a
        # free local port with a fresh run id, overriding any explicit values.
        args.rdzv_backend = "c10d"
        args.rdzv_endpoint = "localhost:0"
        args.rdzv_id = str(uuid.uuid4())
        logger.info(
            "\n**************************************\n"
            "Rendezvous info:\n"
            "--rdzv-backend=%s "
            "--rdzv-endpoint=%s "
            "--rdzv-id=%s\n"
            "**************************************\n",
            args.rdzv_backend,
            args.rdzv_endpoint,
            args.rdzv_id,
        )

    launch_config, entrypoint, entrypoint_args = config_from_args(args)
    elastic_launch(
        config=launch_config,
        entrypoint=entrypoint,
    )(*entrypoint_args)
@record
def main(args=None):
    """Console entry point: parse the command line and launch the run."""
    run(parse_args(args))


if __name__ == "__main__":
    main()

+ 0
- 5
mindnlp/core/nn/attention/flex_attention.py View File

@@ -1,5 +0,0 @@
from mindnlp import core

# Minimal compatibility shim for ``nn.attention.flex_attention``: only the
# public names are provided so imports resolve; the actual flex-attention
# implementation is not available in this port.
BlockMask = core.Tensor  # alias; a dedicated BlockMask type is not provided here
flex_attention = None  # placeholder — not implemented
create_block_mask = None  # placeholder — not implemented
_DEFAULT_SPARSE_BLOCK_SIZE = None  # placeholder — no default defined

+ 0
- 90
mindnlp/core/npu/amp/autocast_mode.py View File

@@ -1,90 +0,0 @@
# mypy: allow-untyped-defs
import functools
from typing import Any
from typing_extensions import deprecated
from mindnlp import core
__all__ = ["autocast", "custom_fwd", "custom_bwd"]
class autocast(core.amp.autocast_mode.autocast):
    r"""See :class:`core.autocast`.

    ``core.cuda.amp.autocast(args...)`` is deprecated. Please use ``core.amp.autocast("cuda", args...)`` instead.
    """

    @deprecated(
        "`core.cuda.amp.autocast(args...)` is deprecated. "
        "Please use `core.amp.autocast('cuda', args...)` instead.",
        category=FutureWarning,
    )
    def __init__(
        self,
        enabled: bool = True,
        dtype: core.dtype = core.float16,
        cache_enabled: bool = True,
    ):
        # Under TorchScript scripting just record the state and skip the
        # parent __init__; the scripted path short-circuits in every dunder.
        if core._jit_internal.is_scripting():
            self._enabled = enabled
            self.device = "cuda"
            self.fast_dtype = dtype
            return
        super().__init__(
            "cuda", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled
        )

    def __enter__(self):
        if core._jit_internal.is_scripting():
            return self
        return super().__enter__()

    # TODO: discuss a unified TorchScript-friendly API for autocast
    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):  # type: ignore[override]
        if core._jit_internal.is_scripting():
            return
        return super().__exit__(exc_type, exc_val, exc_tb)

    def __call__(self, func):
        # When scripting, return the function unwrapped (no autocast context).
        if core._jit_internal.is_scripting():
            return func
        return super().__call__(func)
# Preserved only for BC reasons
@deprecated(
    "`core.cuda.amp.autocast_mode._cast(value, dtype)` is deprecated. "
    "Please use `core.amp.autocast_mode._cast(value, 'cuda', dtype)` instead.",
    category=FutureWarning,
)
def _cast(value, dtype):
    # Thin wrapper: delegate to the device-generic cast helper, pinned to CUDA.
    return core.amp.autocast_mode._cast(value, "cuda", dtype)
@deprecated(
    "`core.cuda.amp.custom_fwd(args...)` is deprecated. "
    "Please use `core.amp.custom_fwd(args..., device_type='cuda')` instead.",
    category=FutureWarning,
)
def custom_fwd(fwd=None, *, cast_inputs=None):
    """
    ``core.cuda.amp.custom_fwd(args...)`` is deprecated. Please use
    ``core.amp.custom_fwd(args..., device_type='cuda')`` instead.
    """
    # Equivalent to functools.partial(core.amp.custom_fwd, device_type="cuda")(...)
    return core.amp.custom_fwd(fwd=fwd, cast_inputs=cast_inputs, device_type="cuda")
@deprecated(
    "`core.cuda.amp.custom_bwd(args...)` is deprecated. "
    "Please use `core.amp.custom_bwd(args..., device_type='cuda')` instead.",
    category=FutureWarning,
)
def custom_bwd(bwd):
    """
    ``core.cuda.amp.custom_bwd(args...)`` is deprecated. Please use
    ``core.amp.custom_bwd(args..., device_type='cuda')`` instead.
    """
    # Equivalent to functools.partial(core.amp.custom_bwd, device_type="cuda")(bwd)
    return core.amp.custom_bwd(bwd, device_type="cuda")

+ 0
- 0
mindnlp/core/testing/_internal/__init__.py View File


+ 1
- 1
mindnlp/dataset/transforms/lookup.py View File

@@ -20,7 +20,7 @@ lookup transforms
"""
import mindspore._c_dataengine as cde # pylint: disable=no-name-in-module, import-error
from mindspore.dataset.text.transforms import TextTensorOperation
from mindspore.dataset.core.datatypes import mstype_to_detype
from mindspore.dataset.mindtorch.datatypes import mstype_to_detype
from mindspore.common import dtype as mstype
from mindspore.dataset.text import Vocab as msVocab
from mindnlp.vocab import Vocab as nlpVocab


+ 3
- 3
mindnlp/experimental/rwkv6/modeling_rwkv6.py View File

@@ -19,8 +19,8 @@
# ============================================================================
import mindspore
import mindnlp
import core.nn as nn
import core.ops as ops
import mindtorch.nn as nn
import mindtorch.ops as ops
from typing import Tuple


@@ -188,7 +188,7 @@ class RWKV_RNN(nn.Module):
self.set_train(False)

# 加载权重
w = core.serialization.load(args['MODEL_NAME'] + '.pth')
w = mindtorch.serialization.load(args['MODEL_NAME'] + '.pth')
# 将所有权重转换为float32
self.num_layer = 0


+ 1
- 1
mindnlp/experimental/rwkv6/sampler_rwkv6.py View File

@@ -18,7 +18,7 @@
# limitations under the License.
# ============================================================================
import mindspore
from mindnlp.core import ops
from mindtorch import ops


def sample_logits(out: mindspore.Tensor, temperature: float = 1.0, top_p: float = 0.8) -> mindspore.Tensor:


+ 16
- 16
mindnlp/integrations/safetensors.py View File

@@ -5,9 +5,9 @@ import numpy as np

import safetensors

from mindnlp import core
import mindtorch

from core.configs import SUPPORT_BF16
from mindtorch.configs import SUPPORT_BF16

if SUPPORT_BF16:
from mindspore.common.np_dtype import bfloat16 # pylint: disable=import-error
@@ -20,19 +20,19 @@ MAX_HEADER_SIZE = 100 * 1000 * 1000


_MS_TYPES = {
"F64": core.float64,
"F32": core.float32,
"F16": core.float16,
"BF16": core.bfloat16,
"I64": core.int64,
"U64": core.uint64,
"I32": core.int32,
"U32": core.uint32,
"I16": core.int16,
"U16": core.uint16,
"I8": core.int8,
"U8": core.uint8,
"BOOL": core.bool,
"F64": mindtorch.float64,
"F32": mindtorch.float32,
"F16": mindtorch.float16,
"BF16": mindtorch.bfloat16,
"I64": mindtorch.int64,
"U64": mindtorch.uint64,
"I32": mindtorch.int32,
"U32": mindtorch.uint32,
"I16": mindtorch.int16,
"U16": mindtorch.uint16,
"I8": mindtorch.int8,
"U8": mindtorch.uint8,
"BOOL": mindtorch.bool,
}

_NP_TYPES = {
@@ -93,7 +93,7 @@ class PySafeSlice:
tensor = tensor.reshape(self.shape)
if not SUPPORT_BF16 and self.info["dtype"] == 'BF16':
tensor = tensor.astype(np.float16)
tensor = core.from_numpy(tensor)
tensor = mindtorch.from_numpy(tensor)
return tensor

@property


+ 2
- 2
mindnlp/quant/mindbnb/bitsandbytes/nn/modules.py View File

@@ -32,8 +32,8 @@ from bitsandbytes.utils import (
LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING,
)

from mindnlp.core import nn
from core.nn import Parameter
from mindtorch import nn
from mindtorch.nn import Parameter


def empty(*size, dtype=None):


+ 1
- 1
mindnlp/quant/mindbnb/bitsandbytes/utils.py View File

@@ -23,7 +23,7 @@ from typing import Tuple
import mindspore
from mindspore import ops

from mindnlp.core import nn
from mindtorch import nn


def outlier_hook(module, input):


+ 2
- 2
mindnlp/quant/mindbnb/integrations/replace_modules.py View File

@@ -21,8 +21,8 @@ import logging
from bitsandbytes.nn.modules import Int8Params
import bitsandbytes as bnb

from mindnlp.core import nn
from core.nn import Parameter
from mindtorch import nn
from mindtorch.nn import Parameter

logger = logging.getLogger(__name__)



+ 1
- 1
mindnlp/quant/mindbnb/tests/test_mindbnb_linear.py View File

@@ -21,7 +21,7 @@ import mindspore.context
import numpy as np
import mindspore
from mindspore import Tensor
from mindnlp.core import nn
from mindtorch import nn
from bitsandbytes.nn import Linear8bitLt




+ 3
- 3
mindnlp/quant/smooth_quant/quant.py View File

@@ -3,9 +3,9 @@ from typing import Optional, Tuple
import mindspore
from mindspore import Tensor
from mindspore.ops.primitive import PrimitiveWithInfer, prim_attr_register
from mindnlp.core import nn, ops
from mindnlp.core.serialization import load
from mindnlp.core.configs import ON_ORANGE_PI
from mindtorch import nn, ops
from mindtorch.serialization import load
from mindtorch.configs import ON_ORANGE_PI

from .smooth import smooth_lm



+ 1
- 1
mindnlp/quant/smooth_quant/smooth.py View File

@@ -1,7 +1,7 @@
'''
code from https://github.com/mit-han-lab/smoothquant/
'''
from mindnlp.core import ops, nn, no_grad
from mindtorch import ops, nn, no_grad

from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm



+ 1
- 1
mindnlp/transformers/__init__.py View File

@@ -1,6 +1,6 @@
import sys
from packaging import version
from mindnlp.core.configs import ON_ORANGE_PI
from mindtorch.configs import ON_ORANGE_PI
from mindnlp.utils.import_utils import *
from mindnlp.utils.import_utils import _LazyModule



+ 4
- 4
mindnlp/transformers/generation/logits_process.py View File

@@ -1,11 +1,11 @@
from mindnlp import core
import mindtorch

def InfNanRemoveLogitsProcessor_call(self, input_ids, scores):
# set all +/-inf values to max/min possible value
scores_processed = scores
scores_processed = core.where(scores == float("inf"), core.finfo(scores.dtype).max, scores_processed)
scores_processed = core.where(scores == -float("inf"), core.finfo(scores.dtype).min, scores_processed)
scores_processed = mindtorch.where(scores == float("inf"), mindtorch.finfo(scores.dtype).max, scores_processed)
scores_processed = mindtorch.where(scores == -float("inf"), mindtorch.finfo(scores.dtype).min, scores_processed)
# set all nan values to 0.0
scores_processed = core.where(scores != scores, 0.0, scores)
scores_processed = mindtorch.where(scores != scores, 0.0, scores)

return scores_processed

+ 111
- 111
mindnlp/transformers/masking_utils.py View File

@@ -15,11 +15,11 @@
import itertools
from typing import Callable, Optional, Union

from mindnlp import core
from mindnlp.core.nn import functional as F
import mindtorch
from mindtorch.nn import functional as F

# Register a fake type to avoid crashing for annotations and `isinstance` checks
BlockMask = core.Tensor
BlockMask = mindtorch.Tensor

_is_torch_greater_or_equal_than_2_6 = True

@@ -29,7 +29,7 @@ def and_masks(*mask_functions: list[Callable]) -> Callable:
raise RuntimeError(f"All inputs should be callable mask_functions: {mask_functions}")

def and_mask(batch_idx, head_idx, q_idx, kv_idx):
result = q_idx.new_ones((), dtype=core.bool)
result = q_idx.new_ones((), dtype=mindtorch.bool)
for mask in mask_functions:
result = result & mask(batch_idx, head_idx, q_idx, kv_idx).to(result.device)
return result
@@ -42,7 +42,7 @@ def or_masks(*mask_functions: list[Callable]) -> Callable:
raise RuntimeError(f"All inputs should be callable mask_functions: {mask_functions}")

def or_mask(batch_idx, head_idx, q_idx, kv_idx):
result = q_idx.new_zeros((), dtype=core.bool)
result = q_idx.new_zeros((), dtype=mindtorch.bool)
for mask in mask_functions:
result = result | mask(batch_idx, head_idx, q_idx, kv_idx).to(result.device)
return result
@@ -94,7 +94,7 @@ def chunked_causal_mask_function(chunk_size: int) -> Callable:
return and_masks(chunked_overlay(chunk_size), causal_mask_function)


def padding_mask_function(padding_mask: core.Tensor) -> Callable:
def padding_mask_function(padding_mask: mindtorch.Tensor) -> Callable:
"""
This return the mask_function function corresponding to a 2D padding mask.
"""
@@ -107,7 +107,7 @@ def padding_mask_function(padding_mask: core.Tensor) -> Callable:

return inner_mask

def packed_sequence_mask_function(packed_sequence_mask: core.Tensor) -> Callable:
def packed_sequence_mask_function(packed_sequence_mask: mindtorch.Tensor) -> Callable:
"""
This return the mask_function function corresponding to a 2D packed sequence mask.
"""
@@ -153,13 +153,13 @@ def _vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callabl
dimensions.extend([(None, 0, None, None), (0, None, None, None)])

for dims in dimensions:
mask_function = core.vmap(mask_function, in_dims=dims, out_dims=0)
mask_function = mindtorch.vmap(mask_function, in_dims=dims, out_dims=0)
return mask_function


def prepare_padding_mask(
attention_mask: Optional[core.Tensor], kv_length: int, kv_offset: int, _slice: bool = True
) -> Optional[core.Tensor]:
attention_mask: Optional[mindtorch.Tensor], kv_length: int, kv_offset: int, _slice: bool = True
) -> Optional[mindtorch.Tensor]:
"""
From the 2D attention mask, prepare the correct padding mask to use by potentially padding it, and slicing
according to the `kv_offset` if `_slice` is `True`.
@@ -168,19 +168,19 @@ def prepare_padding_mask(
if attention_mask is not None:
# Pad it if necesary
if (padding_length := kv_length + kv_offset - attention_mask.shape[-1]) > 0:
local_padding_mask = core.nn.functional.pad(attention_mask, (0, padding_length))
local_padding_mask = mindtorch.nn.functional.pad(attention_mask, (0, padding_length))
# For flex, we should not slice them, only use an offset
if _slice:
# Equivalent to: `local_padding_mask = attention_mask[:, kv_offset : kv_offset + kv_length]`,
# but without data-dependent slicing (i.e. core.compile friendly)
mask_indices = core.arange(kv_length, device=local_padding_mask.device)
# but without data-dependent slicing (i.e. mindtorch.compile friendly)
mask_indices = mindtorch.arange(kv_length, device=local_padding_mask.device)
mask_indices += kv_offset
local_padding_mask = local_padding_mask[:, mask_indices]
return local_padding_mask


def _ignore_causal_mask_sdpa(
padding_mask: Optional[core.Tensor],
padding_mask: Optional[mindtorch.Tensor],
query_length: int,
kv_length: int,
kv_offset: int,
@@ -194,13 +194,13 @@ def _ignore_causal_mask_sdpa(
allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is
passed).
"""
is_tracing = core.jit.is_tracing() or isinstance(padding_mask, core.fx.Proxy)
is_tracing = mindtorch.jit.is_tracing() or isinstance(padding_mask, mindtorch.fx.Proxy)
if padding_mask is not None and padding_mask.shape[-1] > kv_length:
mask_indices = core.arange(kv_length, device=padding_mask.device)
mask_indices = mindtorch.arange(kv_length, device=padding_mask.device)
mask_indices += kv_offset
padding_mask = padding_mask[:, mask_indices]

# When using `core.export` or `core.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is
# When using `mindtorch.export` or `mindtorch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is
# hard-coded to the forward. If a user exports a model with query_length > 1, the exported model will hard-code `is_causal=True`
# which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108). Thus, we only set
# `ignore_causal_mask = True` if we are not tracing
@@ -219,15 +219,15 @@ def _ignore_causal_mask_sdpa(

def sdpa_mask_recent_torch(
batch_size: int,
cache_position: core.Tensor,
cache_position: mindtorch.Tensor,
kv_length: int,
kv_offset: int = 0,
mask_function: Callable = causal_mask_function,
attention_mask: Optional[core.Tensor] = None,
attention_mask: Optional[mindtorch.Tensor] = None,
local_size: Optional[int] = None,
allow_is_causal_skip: bool = True,
**kwargs,
) -> Optional[core.Tensor]:
) -> Optional[mindtorch.Tensor]:
"""
Create a 4D boolean mask of shape `(batch_size, 1, query_length, kv_length)` where a value of True indicates that
the element should take part in the attention computation, and False that it should not.
@@ -331,15 +331,15 @@ def sdpa_mask_recent_torch(

# Similar to `kv_arange = torch.arange(start=kv_offset, end=kv_offset + kv_length, device=cache_position.device)`
# but without data-dependent slicing (i.e. torch.compile friendly)
kv_arange = core.arange(kv_length, device=cache_position.device)
kv_arange = mindtorch.arange(kv_length, device=cache_position.device)
kv_arange += kv_offset

# Potentially add the padding 2D mask
if padding_mask is not None:
mask_function = and_masks(mask_function, padding_mask_function(padding_mask))

batch_arange = core.arange(batch_size, device=cache_position.device)
head_arange = core.arange(1, device=cache_position.device)
batch_arange = mindtorch.arange(batch_size, device=cache_position.device)
head_arange = mindtorch.arange(1, device=cache_position.device)
# This creates the 4D mask easily. Note that we need this context manager as vmap cannot handle slicing a tensor from
# scalar tensor (it internally calls `.item()` which vmap does not allow, but this context works around it
# We don't need to add an offset to the mask_function either, as we vmap directly the correct indices for k and kv indices
@@ -350,16 +350,16 @@ def sdpa_mask_recent_torch(

def sdpa_mask_older_torch(
batch_size: int,
cache_position: core.Tensor,
cache_position: mindtorch.Tensor,
kv_length: int,
kv_offset: int = 0,
mask_function: Callable = causal_mask_function,
attention_mask: Optional[core.Tensor] = None,
attention_mask: Optional[mindtorch.Tensor] = None,
local_size: Optional[int] = None,
allow_is_causal_skip: bool = True,
allow_torch_fix: bool = True,
**kwargs,
) -> Optional[core.Tensor]:
) -> Optional[mindtorch.Tensor]:
"""
NOTE: This function is only used when torch version is torch<2.5 - see `sdpa_mask_recent_torch` otherwise.

@@ -372,7 +372,7 @@ def sdpa_mask_older_torch(
Args:
batch_size (`int`):
The batch size of the input sequence.
cache_position (`core.Tensor`):
cache_position (`mindtorch.Tensor`):
A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
kv_length (`int`):
The size that the key and value states will have during the attention computation.
@@ -380,14 +380,14 @@ def sdpa_mask_older_torch(
An optional offset to indicate at which first position the key and values states will refer to.
mask_function (`Callable`):
The mask factory function describing the mask pattern.
attention_mask (`core.Tensor`, optional):
attention_mask (`mindtorch.Tensor`, optional):
The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
local_size (`int`, optional):
The size of the local attention, if we do not use full attention. This is used only if `allow_is_causal_skip=True`
to try to skip mask creation if possible.
allow_is_causal_skip (`bool`, optional):
Whether to allow to return `None` for the mask under conditions where we can use the `is_causal` argument in
`core.sdpa` instead. Default to `True`.
`mindtorch.sdpa` instead. Default to `True`.
allow_torch_fix (`bool`, optional):
Whether to update the mask in case a query is not attending to any tokens, to solve a bug in torch's older
versions. We need an arg to skip it when using eager. By default `True`.
@@ -400,9 +400,9 @@ def sdpa_mask_older_torch(
if allow_is_causal_skip and _ignore_causal_mask_sdpa(padding_mask, q_length, kv_length, kv_offset, local_size):
return None

# Similar to `kv_arange = core.arange(start=kv_offset, end=kv_offset + kv_length, device=cache_position.device)`
# but without data-dependent slicing (i.e. core.compile friendly)
kv_arange = core.arange(kv_length, device=cache_position.device)
# Similar to `kv_arange = mindtorch.arange(start=kv_offset, end=kv_offset + kv_length, device=cache_position.device)`
# but without data-dependent slicing (i.e. mindtorch.compile friendly)
kv_arange = mindtorch.arange(kv_length, device=cache_position.device)
kv_arange += kv_offset

# This creates the 4D mask easily. Note that we do not include vmap over the batch_idx dimension as well,
@@ -422,7 +422,7 @@ def sdpa_mask_older_torch(
# # Due to a bug in versions of torch<2.5, we need to update the mask in case a query is not attending to any
# # tokens (due to padding). See details in https://github.com/pytorch/pytorch/issues/110213
# if not _is_torch_greater_or_equal_than_2_5 and allow_torch_fix:
# causal_mask |= core.all(~causal_mask, dim=-1, keepdim=True)
# causal_mask |= mindtorch.all(~causal_mask, dim=-1, keepdim=True)
return causal_mask


@@ -432,14 +432,14 @@ sdpa_mask = sdpa_mask_older_torch

def eager_mask(
batch_size: int,
cache_position: core.Tensor,
cache_position: mindtorch.Tensor,
kv_length: int,
kv_offset: int = 0,
mask_function: Callable = causal_mask_function,
attention_mask: Optional[core.Tensor] = None,
dtype: core.dtype = core.float32,
attention_mask: Optional[mindtorch.Tensor] = None,
dtype: mindtorch.dtype = mindtorch.float32,
**kwargs,
) -> core.Tensor:
) -> mindtorch.Tensor:
"""
Create a 4D float mask of shape `(batch_size, 1, query_length, kv_length)` where a value of 0 indicates that
the element should take part in the attention computation, and -inf (minimum value for the given `dtype`) that
@@ -448,7 +448,7 @@ def eager_mask(
Args:
batch_size (`int`):
The batch size of the input sequence.
cache_position (`core.Tensor`):
cache_position (`mindtorch.Tensor`):
A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
kv_length (`int`):
The size that the key and value states will have during the attention computation.
@@ -456,10 +456,10 @@ def eager_mask(
An optional offset to indicate at which first position the key and values states will refer to.
mask_function (`Callable`):
The mask factory function describing the mask pattern.
attention_mask (`core.Tensor`, optional):
attention_mask (`mindtorch.Tensor`, optional):
The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
dtype (`core.dtype`, optional):
The dtype to use for the mask. By default, `core.float32`.
dtype (`mindtorch.dtype`, optional):
The dtype to use for the mask. By default, `mindtorch.float32`.
"""
# The masks for eager attention are simply boolean mask from sdpa, casted to 0 and -inf
_ = kwargs.pop("allow_is_causal_skip", None)
@@ -474,19 +474,19 @@ def eager_mask(
allow_torch_fix=False,
**kwargs,
)
min_dtype = core.finfo(dtype).min
min_dtype = mindtorch.finfo(dtype).min
# we need 0s where the tokens should be taken into account, and -inf otherwise (mask is already of boolean type)
mask = core.where(mask, core.tensor(0.0, device=mask.device, dtype=dtype), min_dtype)
mask = mindtorch.where(mask, mindtorch.tensor(0.0, device=mask.device, dtype=dtype), min_dtype)
return mask


def flash_attention_mask(
batch_size: int,
cache_position: core.Tensor,
cache_position: mindtorch.Tensor,
kv_length: int,
kv_offset: int = 0,
mask_function: Callable = causal_mask_function,
attention_mask: Optional[core.Tensor] = None,
attention_mask: Optional[mindtorch.Tensor] = None,
**kwargs,
):
"""
@@ -497,7 +497,7 @@ def flash_attention_mask(
Args:
batch_size (`int`):
The batch size of the input sequence.
cache_position (`core.Tensor`):
cache_position (`mindtorch.Tensor`):
A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
kv_length (`int`):
The size that the key and value states will have during the attention computation.
@@ -505,7 +505,7 @@ def flash_attention_mask(
An optional offset to indicate at which first position the key and values states will refer to.
mask_function (`Callable`):
The mask factory function describing the mask pattern.
attention_mask (`core.Tensor`, optional):
attention_mask (`mindtorch.Tensor`, optional):
The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
"""
if attention_mask is not None:
@@ -521,21 +521,21 @@ def flash_attention_mask(

def flex_attention_mask(
batch_size: int,
cache_position: core.Tensor,
cache_position: mindtorch.Tensor,
kv_length: int,
kv_offset: int = 0,
mask_function: Callable = causal_mask_function,
attention_mask: Optional[core.Tensor] = None,
attention_mask: Optional[mindtorch.Tensor] = None,
**kwargs,
) -> BlockMask:
"""
Create a 4D block mask which is a compressed representation of the full 4D block causal mask. BlockMask is essential
for performant computation of flex attention. See: https://pycore.org/blog/flexattention/
    for performant computation of flex attention. See: https://pytorch.org/blog/flexattention/

Args:
batch_size (`int`):
The batch size of the input sequence.
cache_position (`core.Tensor`):
cache_position (`mindtorch.Tensor`):
A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
kv_length (`int`):
The size that the key and value states will have during the attention computation.
@@ -543,7 +543,7 @@ def flex_attention_mask(
An optional offset to indicate at which first position the key and values states will refer to.
mask_function (`Callable`):
The mask factory function describing the mask pattern.
attention_mask (`core.Tensor`, optional):
attention_mask (`mindtorch.Tensor`, optional):
The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
"""
q_length, q_offset = cache_position.shape[0], cache_position[0]
@@ -555,7 +555,7 @@ def flex_attention_mask(
pad_len = ((attention_mask.shape[1] // flex_default_block_size) + 1) * flex_default_block_size
pad_len = pad_len - attention_mask.shape[1]
if not _is_torch_greater_or_equal_than_2_6 and pad_len > 0:
attention_mask = core.nn.functional.pad(attention_mask, value=0, pad=(0, pad_len))
attention_mask = mindtorch.nn.functional.pad(attention_mask, value=0, pad=(0, pad_len))

padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset, _slice=False)
mask_function = and_masks(mask_function, padding_mask_function(padding_mask))
@@ -585,13 +585,13 @@ ALL_MASK_ATTENTION_FUNCTIONS = {
"flex_attention": flex_attention_mask,
}

def find_packed_sequence_indices(position_ids: core.Tensor) -> core.Tensor:
def find_packed_sequence_indices(position_ids: mindtorch.Tensor) -> mindtorch.Tensor:
"""
Find the indices of the sequence to which each new query token in the sequence belongs when using packed
tensor format (i.e. several sequences packed in the same batch dimension).

Args:
position_ids (`core.Tensor`)
position_ids (`mindtorch.Tensor`)
A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.

Returns:
@@ -604,7 +604,7 @@ def find_packed_sequence_indices(position_ids: core.Tensor) -> core.Tensor:
# Note that we assume that a single sequence cannot span several batch dimensions, i.e. 1 single sequence
# cannot be part of the end of the first batch dim and the start of the 2nd one for example
first_dummy_value = position_ids[:, :1] - 1 # We just need the diff on this first value to be 1
position_diff = core.diff(position_ids, prepend=first_dummy_value, dim=-1)
position_diff = mindtorch.diff(position_ids, prepend=first_dummy_value, dim=-1)
packed_sequence_mask = (position_diff != 1).cumsum(-1)

# Here it would be nice to return None if we did not detect packed sequence format, i.e. if `packed_sequence_mask[:, -1] == 0`
@@ -614,13 +614,13 @@ def find_packed_sequence_indices(position_ids: core.Tensor) -> core.Tensor:

def _preprocess_mask_arguments(
config,
input_embeds: core.Tensor,
attention_mask: Optional[Union[core.Tensor, BlockMask]],
cache_position: core.Tensor,
input_embeds: mindtorch.Tensor,
attention_mask: Optional[Union[mindtorch.Tensor, BlockMask]],
cache_position: mindtorch.Tensor,
past_key_values,
position_ids: Optional[core.Tensor],
position_ids: Optional[mindtorch.Tensor],
layer_idx: Optional[int],
) -> tuple[bool, Optional[Union[core.Tensor, BlockMask]], int, int]:
) -> tuple[bool, Optional[Union[mindtorch.Tensor, BlockMask]], int, int]:
"""
Perform some common pre-processing of the mask arguments we get from the modeling code. Mostly determine the
key-value length and offsets, and if we should early exit or not.
@@ -628,17 +628,17 @@ def _preprocess_mask_arguments(
Args:
config (`PretrainedConfig`):
The model config.
input_embeds (`core.Tensor`):
input_embeds (`mindtorch.Tensor`):
The input embeddings of shape (batch_size, query_length, hidden_dim). This is used only to infer the
batch size, query length and dtype.
attention_mask (`core.Tensor`, optional):
attention_mask (`mindtorch.Tensor`, optional):
The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length).
It can also be an already prepared 4D mask, in which case it is returned as-is.
cache_position (`core.Tensor`):
cache_position (`mindtorch.Tensor`):
A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
past_key_values (`Cache`, optional):
The past key values, if we use a cache.
position_ids (`core.Tensor`, optional)
position_ids (`mindtorch.Tensor`, optional)
A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
layer_idx (`int`, optional):
If `past_key_values` is not None, this is the layer index of the cache from which to get the key-value
@@ -647,9 +647,9 @@ def _preprocess_mask_arguments(
Returns:
early_exit (`bool`):
Whether we should early exit mask creation, and return the mask as-is.
attention_mask (`core.Tensor` or `BlockMask` or `None`):
attention_mask (`mindtorch.Tensor` or `BlockMask` or `None`):
The attention mask to either return immediately, or to use in downstream mask creation.
packed_sequence_mask (`core.Tensor`, optional):
packed_sequence_mask (`mindtorch.Tensor`, optional):
In case we detected packed sequence format, this is a tensor where each similar integer indicates that
the tokens belong to the same sequence.
kv_length (`int`):
@@ -658,20 +658,20 @@ def _preprocess_mask_arguments(
An offset to indicate at which first position the key and values states will refer to.
"""
# If the mask is already 4D, simply return as-is (it was already prepared, or it is custom)
if isinstance(attention_mask, (core.Tensor, BlockMask)) and len(attention_mask.shape) == 4:
if isinstance(attention_mask, (mindtorch.Tensor, BlockMask)) and len(attention_mask.shape) == 4:
return True, attention_mask, None, None, None

# For TGI/vLLM backends, or other custom attention without equivalent mask creation: we don't need a mask!
# Note: it's not ideal to check the `_global_mapping` attribute instead of the object itself, however otherwise
# full graph dynamo tracing (i.e. core.export or compile with `fullgraph=True`) will fail on Python<3.11
# with `core._dynamo.exc.Unsupported: 'inline in skipfiles:Mapping.__contains__ | __contains__, skipped
# full graph dynamo tracing (i.e. mindtorch.export or compile with `fullgraph=True`) will fail on Python<3.11
# with `mindtorch._dynamo.exc.Unsupported: 'inline in skipfiles:Mapping.__contains__ | __contains__, skipped
# according trace_rules.lookup SKIP_DIRS'` -- can be removed when we require Python>=3.11
if config._attn_implementation not in ALL_MASK_ATTENTION_FUNCTIONS:
return True, None, None, None, None

# Move the mask to correct device, and potentially switch dtype for efficiency
if attention_mask is not None and attention_mask.ndim == 2:
attention_mask = attention_mask.to(device=cache_position.device, dtype=core.bool)
attention_mask = attention_mask.to(device=cache_position.device, dtype=mindtorch.bool)

# If using a cache, it can give all informations about mask sizes based on seen tokens
if past_key_values is not None:
@@ -695,14 +695,14 @@ def _preprocess_mask_arguments(

def create_causal_mask(
config,
input_embeds: core.Tensor,
attention_mask: Optional[core.Tensor],
cache_position: core.Tensor,
input_embeds: mindtorch.Tensor,
attention_mask: Optional[mindtorch.Tensor],
cache_position: mindtorch.Tensor,
past_key_values,
position_ids: Optional[core.Tensor] = None,
position_ids: Optional[mindtorch.Tensor] = None,
or_mask_function: Optional[Callable] = None,
and_mask_function: Optional[Callable] = None,
) -> Optional[Union[core.Tensor, BlockMask]]:
) -> Optional[Union[mindtorch.Tensor, BlockMask]]:
"""
Create a standard causal mask based on the attention implementation used (stored in the config). If `past_key_values`
has an HybridCache structure, this function will return the mask corresponding to one of the "full_attention" layers (to align
@@ -711,17 +711,17 @@ def create_causal_mask(
Args:
config (`PretrainedConfig`):
The model config.
input_embeds (`core.Tensor`):
input_embeds (`mindtorch.Tensor`):
The input embeddings of shape (batch_size, query_length, hidden_dim). This is used only to infer the
batch size, query length and dtype.
attention_mask (`core.Tensor`, optional):
attention_mask (`mindtorch.Tensor`, optional):
The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length).
It can also be an already prepared 4D mask, in which case it is returned as-is.
cache_position (`core.Tensor`):
cache_position (`mindtorch.Tensor`):
A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
past_key_values (`Cache`, optional):
The past key values, if we use a cache.
position_ids (`core.Tensor`, optional)
position_ids (`mindtorch.Tensor`, optional)
A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
or_mask_function (`Callable`, optional):
An optional mask function to combine with the causal mask function (by doing the union of both). This is
@@ -784,14 +784,14 @@ def create_causal_mask(

def create_sliding_window_causal_mask(
config,
input_embeds: core.Tensor,
attention_mask: Optional[core.Tensor],
cache_position: core.Tensor,
input_embeds: mindtorch.Tensor,
attention_mask: Optional[mindtorch.Tensor],
cache_position: mindtorch.Tensor,
past_key_values,
position_ids: Optional[core.Tensor] = None,
position_ids: Optional[mindtorch.Tensor] = None,
or_mask_function: Optional[Callable] = None,
and_mask_function: Optional[Callable] = None,
) -> Optional[Union[core.Tensor, BlockMask]]:
) -> Optional[Union[mindtorch.Tensor, BlockMask]]:
"""
Create a sliding window causal mask based on the attention implementation used (stored in the config). This type
of attention pattern was mostly democratized by Mistral. If `past_key_values` has an HybridCache structure, this
@@ -801,17 +801,17 @@ def create_sliding_window_causal_mask(
Args:
config (`PretrainedConfig`):
The model config.
input_embeds (`core.Tensor`):
input_embeds (`mindtorch.Tensor`):
The input embeddings of shape (batch_size, query_length, hidden_dim). This is used only to infer the
batch size, query length and dtype.
attention_mask (`core.Tensor`, optional):
attention_mask (`mindtorch.Tensor`, optional):
The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length).
It can also be an already prepared 4D mask, in which case it is returned as-is.
cache_position (`core.Tensor`):
cache_position (`mindtorch.Tensor`):
A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
past_key_values (`Cache`, optional):
The past key values, if we use a cache.
position_ids (`core.Tensor`, optional)
position_ids (`mindtorch.Tensor`, optional)
A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
or_mask_function (`Callable`, optional):
An optional mask function to combine with the sliding causal mask function (by doing the union of both). This is
@@ -879,14 +879,14 @@ def create_sliding_window_causal_mask(

def create_chunked_causal_mask(
config,
input_embeds: core.Tensor,
attention_mask: Optional[core.Tensor],
cache_position: core.Tensor,
input_embeds: mindtorch.Tensor,
attention_mask: Optional[mindtorch.Tensor],
cache_position: mindtorch.Tensor,
past_key_values,
position_ids: Optional[core.Tensor] = None,
position_ids: Optional[mindtorch.Tensor] = None,
or_mask_function: Optional[Callable] = None,
and_mask_function: Optional[Callable] = None,
) -> Optional[Union[core.Tensor, BlockMask]]:
) -> Optional[Union[mindtorch.Tensor, BlockMask]]:
"""
Create a chunked attention causal mask based on the attention implementation used (stored in the config). This type
of attention pattern was mostly democratized by Llama4. If `past_key_values` has an HybridCache structure, this
@@ -896,17 +896,17 @@ def create_chunked_causal_mask(
Args:
config (`PretrainedConfig`):
The model config.
input_embeds (`core.Tensor`):
input_embeds (`mindtorch.Tensor`):
The input embeddings of shape (batch_size, query_length, hidden_dim). This is used only to infer the
batch size, query length and dtype.
attention_mask (`core.Tensor`, optional):
attention_mask (`mindtorch.Tensor`, optional):
The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length).
It can also be an already prepared 4D mask, in which case it is returned as-is.
cache_position (`core.Tensor`):
cache_position (`mindtorch.Tensor`):
A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
past_key_values (`Cache`, optional):
The past key values, if we use a cache.
position_ids (`core.Tensor`, optional)
position_ids (`mindtorch.Tensor`, optional)
A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
or_mask_function (`Callable`, optional):
An optional mask function to combine with the chunked causal mask function (by doing the union of both). This is
@@ -988,11 +988,11 @@ LAYER_PATTERN_TO_MASK_FUNCTION_MAPPING = {

def create_masks_for_generate(
config,
input_embeds: core.Tensor,
attention_mask: Optional[core.Tensor],
cache_position: core.Tensor,
input_embeds: mindtorch.Tensor,
attention_mask: Optional[mindtorch.Tensor],
cache_position: mindtorch.Tensor,
past_key_values,
position_ids: Optional[core.Tensor] = None,
position_ids: Optional[mindtorch.Tensor] = None,
or_mask_function: Optional[Callable] = None,
and_mask_function: Optional[Callable] = None,
**kwargs,
@@ -1004,17 +1004,17 @@ def create_masks_for_generate(
Args:
config (`PretrainedConfig`):
The model config.
input_embeds (`core.Tensor`):
input_embeds (`mindtorch.Tensor`):
The input embeddings of shape (batch_size, query_length, hidden_dim). This is used only to infer the
batch size, query length and dtype.
attention_mask (`core.Tensor`, optional):
attention_mask (`mindtorch.Tensor`, optional):
The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length).
It can also be an already prepared 4D mask, in which case it is returned as-is.
cache_position (`core.Tensor`):
cache_position (`mindtorch.Tensor`):
A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
past_key_values (`Cache`, optional):
The past key values, if we use a cache.
position_ids (`core.Tensor`, optional)
position_ids (`mindtorch.Tensor`, optional)
A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences.
or_mask_function (`Callable`, optional):
An optional mask function to combine with the other mask function (by doing the union of both). This is
@@ -1087,7 +1087,7 @@ YELLOW_SQUARE = f"{YELLOW}{BLACK_SQUARE}{RESET}"
GREEN_SQUARE = f"{GREEN}{BLACK_SQUARE}{RESET}"


def tensor_to_mask_visual(original_tensor: core.Tensor, grid_size=(20, 40), style="majong") -> str:
def tensor_to_mask_visual(original_tensor: mindtorch.Tensor, grid_size=(20, 40), style="majong") -> str:
BLACK_SQUARE, WHITE_SQUARE, LOW_TRIANGLE, UPPER_TRIANGLE = get_style(style)
h, w = original_tensor.shape
max_h, max_w = grid_size
@@ -1139,11 +1139,11 @@ def tensor_to_mask_visual(original_tensor: core.Tensor, grid_size=(20, 40), styl
return "\n".join(result)


class AttentionMask(core.Tensor):
class AttentionMask(mindtorch.Tensor):
def __new__(cls, data, style=None):
# Create a new instance of AttentionMask as a Tensor
cls.style = style
return core.Tensor._make_subclass(cls, data, require_grad=False)
return mindtorch.Tensor._make_subclass(cls, data, require_grad=False)

def __init__(self, data):
# You can initialize any additional metadata here if needed
@@ -1164,7 +1164,7 @@ class AttentionMask(core.Tensor):
block_vis = tensor_to_mask_visual(dense_mask[batch_idx], grid_size=grid_size, style=self.style)
total_vis.append(block_vis)

total_vis.append(f"core.Tensor(shape={tuple(self.shape)}, dtype={self.dtype})")
total_vis.append(f"mindtorch.Tensor(shape={tuple(self.shape)}, dtype={self.dtype})")
return "\n".join(total_vis)

def __repr__(self):
@@ -1174,7 +1174,7 @@ class AttentionMask(core.Tensor):
return self.to_string()

@classmethod
def from_tensor(cls, tensor: core.Tensor, style: Optional[str] = None) -> "AttentionMask":
def from_tensor(cls, tensor: mindtorch.Tensor, style: Optional[str] = None) -> "AttentionMask":
res = cls(tensor)
res.style = style
return res

+ 1
- 2
mindnlp/transformers/modeling_utils.py View File

@@ -2,8 +2,7 @@
import types

from mindspore.communication import GlobalComm
from mindspore import runtime
from ..core import nn, ops, distributed as dist
from mindtorch import nn, ops, distributed as dist
from ..utils import logging

logger = logging.get_logger(__name__)


+ 2
- 2
mindnlp/transformers/ms_utils.py View File

@@ -20,8 +20,8 @@ from typing import Union, Optional, List, Tuple
import mindspore
from mindspore.common.initializer import initializer, Normal

from mindnlp.core import nn, ops
from mindnlp.core.nn import Parameter
from mindtorch import nn, ops
from mindtorch.nn import Parameter

ALL_LAYERNORM_LAYERS = [nn.LayerNorm]



+ 7
- 7
mindnlp/transformers/trainer.py View File

@@ -1,6 +1,6 @@
from typing import Union, Any, Optional
from mindnlp import core
from mindnlp.core import nn, autograd
import mindtorch
from mindtorch import nn, autograd

from transformers.training_args import OptimizerNames
from accelerate.utils import DistributedType
@@ -8,9 +8,9 @@ from accelerate.utils import DistributedType
def training_step(
self,
model: nn.Module,
inputs: dict[str, Union[core.Tensor, Any]],
num_items_in_batch: Optional[core.Tensor] = None,
) -> core.Tensor:
inputs: dict[str, Union[mindtorch.Tensor, Any]],
num_items_in_batch: Optional[mindtorch.Tensor] = None,
) -> mindtorch.Tensor:
"""
Perform a training step on a batch of inputs.

@@ -19,14 +19,14 @@ def training_step(
Args:
model (`nn.Module`):
The model to train.
inputs (`dict[str, Union[core.Tensor, Any]]`):
inputs (`dict[str, Union[mindtorch.Tensor, Any]]`):
The inputs and targets of the model.

The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
argument `labels`. Check your model's documentation for all accepted arguments.

Return:
`core.Tensor`: The tensor with training loss on this batch.
`mindtorch.Tensor`: The tensor with training loss on this batch.
"""
model.train()
if hasattr(self.optimizer, "train") and callable(self.optimizer.train):


+ 1
- 1
mindnlp/triton/__init__.py View File

@@ -1,7 +1,7 @@
"""triton adapter for mindspore"""
from functools import lru_cache
import mindspore
from mindnlp.core import ops
from mindtorch import ops
from mindnlp.utils import is_triton_available

if is_triton_available():


+ 1
- 1
mindnlp/utils/decorators.py View File

@@ -1,7 +1,7 @@
import warnings
import functools
import mindspore
from ..core.configs import ON_A1
from mindtorch.configs import ON_A1

def dtype_wrapper(fn):
def wrapper(*args, **kwargs):


+ 3
- 3
mindnlp/utils/safetensors_patch.py View File

@@ -4,8 +4,8 @@ from typing import OrderedDict
import numpy as np
import mindspore

from mindnlp import core
from mindnlp.core.configs import SUPPORT_BF16
import mindtorch
from mindtorch.configs import SUPPORT_BF16
import safetensors
from safetensors import SafetensorError

@@ -95,7 +95,7 @@ class PySafeSlice:
array = array[slice]
if not SUPPORT_BF16 and self.info["dtype"] == 'BF16':
array = array.astype(np.float16)
tensor = core.from_numpy(array)
tensor = mindtorch.from_numpy(array)
tensor._ptr = array.ctypes.data
return tensor



+ 1
- 1
mindnlp/utils/testing_utils.py View File

@@ -43,7 +43,7 @@ import urllib3
import numpy as np

import mindspore
from mindnlp.core.configs import SUPPORT_BF16
from mindtorch.configs import SUPPORT_BF16

from transformers.utils.import_utils import (
is_pytest_available,


mindnlp/core/_C/_ConvBackend.py → mindtorch/_C/_ConvBackend.py View File


mindnlp/core/_C/__init__.py → mindtorch/_C/__init__.py View File

@@ -2,7 +2,7 @@ from typing import Any
import mindspore
from mindspore.ops.operations._inner_ops import Generator as GeneratorOp
from mindnlp import core
import mindtorch
from . import _nn
from ..configs import DEVICE_TARGET
@@ -55,7 +55,7 @@ class device():
_id = None if _target == 'cpu' else 0
elif isinstance(type, device):
if index is not None:
raise ValueError("core.device(): When input is core.device, `index` can not be set.")
raise ValueError("mindtorch.device(): When input is mindtorch.device, `index` can not be set.")
_target = type.type
_id = type.index
elif isinstance(type, int):
@@ -67,9 +67,9 @@ class device():
_target = DEVICE_MAP[device_target]
else:
print(type)
raise TypeError("core.device(): `type` must be type of 'str' or 'core.device'.")
raise TypeError("mindtorch.device(): `type` must be type of 'str' or 'mindtorch.device'.")
else:
raise ValueError("core.device(): `type` can not be None")
raise ValueError("mindtorch.device(): `type` can not be None")
self.type = _target
self.index = _id
@@ -96,11 +96,11 @@ class device():
def __enter__(self):
# self.prev_idx = torch.cuda._exchange_device(self.idx)
core._bind.set_device_in_context(self)
mindtorch._bind.set_device_in_context(self)
def __exit__(self, type: Any, value: Any, traceback: Any):
# self.idx = torch.cuda._maybe_exchange_device(self.prev_idx)
core._bind.set_device_in_context(None)
mindtorch._bind.set_device_in_context(None)
return False
device_ = device
@@ -211,4 +211,10 @@ def _log_api_usage_once(*args):
pass
ScriptDict = dict
ScriptList = list
ScriptList = list
class _DistStoreError(RuntimeError): pass
def _get_accelerator():
device_target = mindspore.get_context("device_target")
return device_(DEVICE_MAP[device_target])

+ 117
- 0
mindtorch/_C/_distributed_c10d.py View File

@@ -0,0 +1,117 @@
import pickle
from typing import List, Any
from datetime import timedelta

import mindtorch
from mindtorch import Tensor
from mindtorch.distributed import Store, TCPStore
from mindtorch.distributed.c10d import Backend, ReduceOp


class ProcessGroup:
pass

class ProcessGroupGloo(Backend):
def __init__(
self,
store: Store,
rank: int,
size: int,
timeout: timedelta
) -> None:
super().__init__(rank, size)
self.store = store
self.ranks = []
self.pg = None

def name(self) -> str:
return 'gloo'

def allreduce(self, tensors: List[Tensor], opts: Any) -> Any:
if mindtorch.distributed.is_initialized():
self._allreduce_new_pg(tensors[0], opts)
else:
self._allreduce_use_store(tensors, opts)

def _allreduce_new_pg(self, tensor, opts):
# Get all global ranks
if len(self.ranks) == 0:
rank_bytes = pickle.dumps(mindtorch.distributed.get_rank())
self.store.set(f'__ar_rank_local_to_global_{self.rank_}', rank_bytes)
for local_rank in range(self.size_):
global_rank = pickle.loads(self.store.get(f'__ar_rank_local_to_global_{local_rank}'))
self.ranks.append(global_rank)

if self.pg is None:
self.pg = mindtorch.distributed.new_group(self.ranks, backend='gloo')

mindtorch.distributed.all_reduce(tensor, op=opts.reduceOp, group=self.pg, async_op=False)

def _allreduce_use_store(self, tensors: List[Tensor], opts: Any) -> Any:
tensor = tensors[0]
tensor_bytes = pickle.dumps(tensor)
self.store.set(f'__ar_data_{self.rank_}', tensor_bytes)

# Gather all tensors
gathered = []
for i in range(self.size_):
data = self.store.get(f'__ar_data_{i}')
gathered.append(pickle.loads(data))
stacked = mindtorch.stack(gathered)

reduce_op = opts.reduceOp
if reduce_op == ReduceOp.SUM:
result = stacked.sum(dim=0)
elif reduce_op == ReduceOp.MAX:
if stacked.dtype == mindtorch.int32:
result = stacked.to(mindtorch.int64).max(dim=0).values.to(mindtorch.int32)
else:
result = stacked.max(dim=0).values
elif reduce_op == ReduceOp.MIN:
if stacked.dtype == mindtorch.int32:
result = stacked.to(mindtorch.int64).min(dim=0)[0].to(mindtorch.int32)
else:
result = stacked.min(dim=0)[0]
elif reduce_op == ReduceOp.PRODUCT:
result = stacked.prod(dim=0)
else:
raise ValueError(f'Unsupported reduce operation: {reduce_op}')

tensors[0].copy_(result)
self._synchronize_and_cleanup()

def _synchronize_and_cleanup(self):
if self.rank_ == 0:
# Wait for the completion of allreduce() execution for other ranks and remove the tensor_i key
# to prevent subsequent allreduce() exceptions.
for i in range(1, self.size_):
self.store.get(f'__ar_finish_1_{i}')
for i in range(self.size_):
self.store.delete_key(f'__ar_data_{i}')
self.store.delete_key(f'__ar_finish_1_{i}')

# Ensure that other ranks wait for the deletion of tensor_i key to complete.
self.store.set('__ar_finish_all', '')

# Ensure that rank 0 exits last to prevent errors in other ranks.
for i in range(1, self.size_):
self.store.get(f'__ar_finish_2_{i}')
self.store.delete_key(f'__ar_finish_2_{i}')
self.store.delete_key('__ar_finish_all')
else:
self.store.set(f'__ar_finish_1_{self.rank_}', '')
self.store.get('__ar_finish_all')
self.store.set(f'__ar_finish_2_{self.rank_}', '')

def _set_sequence_number_for_group(self):
pass


class ProcessGroupHCCL:
def __init__(self, group_name):
self.group_name = group_name

def get_hccl_comm_name(self, global_rank):
return self.group_name

class Options: ...

mindnlp/core/_C/_nn.py → mindtorch/_C/_nn.py View File

@@ -1,30 +1,30 @@
from mindnlp import core
import mindtorch

def _parse_to(*args, **kwargs):
"""
Mimic core._C._nn._parse_to functionality in Python.
Mimic mindtorch._C._nn._parse_to functionality in Python.
Args:
tensor (core.Tensor): The tensor to parse.
tensor (mindtorch.Tensor): The tensor to parse.
*args: Positional arguments for `to`.
**kwargs: Keyword arguments for `to`.

Returns:
core.Tensor: The tensor with the desired properties.
mindtorch.Tensor: The tensor with the desired properties.
"""
if len(args) == 1:
# Handle `device` or `dtype`
if isinstance(args[0], core.dtype): # dtype only
if isinstance(args[0], mindtorch.dtype): # dtype only
dtype = args[0]
device = None
elif isinstance(args[0], core.device): # device only
elif isinstance(args[0], mindtorch.device): # device only
device = args[0]
dtype = None
elif isinstance(args[0], (str, int)):
device = core.device(args[0])
device = mindtorch.device(args[0])
dtype = None
else:
raise TypeError(f"Expected core.dtype or core.device, but got {type(args[0])}")
raise TypeError(f"Expected mindtorch.dtype or mindtorch.device, but got {type(args[0])}")
elif len(args) == 2:
# Handle `device` and `dtype`
dtype = args[1]

mindnlp/core/_C/size.py → mindtorch/_C/size.py View File

@@ -19,7 +19,7 @@ class Size(tuple):
return _get_tuple_numel(self)
def __repr__(self):
return "core.Size(" + str(list(self)) + ")"
return "mindtorch.Size(" + str(list(self)) + ")"
def __getitem__(self, slice):
out = super().__getitem__(slice)

mindnlp/core/__future__.py → mindtorch/__future__.py View File


mindnlp/core/__init__.py → mindtorch/__init__.py View File

@@ -14,6 +14,7 @@
# ============================================================================
"""core module"""
import os
import platform
import math
from typing import (
Any as _Any,
@@ -29,6 +30,32 @@ from typing import (
import mindspore
from mindspore.runtime import Stream
from mindspore.common.api import _pynative_executor
from mindspore._c_expression import MSContext # pylint: disable=no-name-in-module, import-error
# for huawei cloud modelarts
if 'RANK_TABLE_FILE' in os.environ:
del os.environ['RANK_TABLE_FILE']
try:
from mindspore._c_expression import disable_multi_thread
except:
disable_multi_thread = None
if os.environ.get('DEVICE_TARGET', None) is not None:
mindspore.set_device(os.environ.get('DEVICE_TARGET'))
# for different ascend devices
if platform.system().lower() == 'linux' and mindspore.get_context('device_target') == 'Ascend':
SOC = MSContext.get_instance().get_ascend_soc_version()
# enable vmm since only vmm can release device memory when del tensor.
if SOC != 'ascend310b':
os.environ["MS_ALLOC_CONF"] = 'enable_vmm:True,vmm_align_size:2MB'
if SOC in ('ascend910', 'ascend310b'):
# context.set_context(ascend_config={"precision_mode": "allow_mix_precision"})
mindspore.device_context.ascend.op_precision.precision_mode('allow_mix_precision')
if SOC == 'ascend310b' and disable_multi_thread is not None:
disable_multi_thread()
pi = math.pi
strided = None
@@ -43,7 +70,7 @@ inf = float("inf")
nan = float("nan")
from . import _C
from ._dtype import *
from ._tensor import Tensor, tensor, scalar_tensor, is_tensor, \
LongTensor, FloatTensor, BoolTensor, HalfTensor, BFloat16Tensor, IntTensor
@@ -148,6 +175,9 @@ from . import profiler, cuda, amp, compiler, jit, version, __future__, overrides
from ._lowrank import svd_lowrank
from .random import get_rng_state, initial_seed, manual_seed, seed, set_rng_state
from .torch_proxy import initialize_torch_proxy, setup_metadata_patch
initialize_torch_proxy()
setup_metadata_patch()
__version__ = 'test_version_no_value'

mindnlp/core/_apis/__init__.py → mindtorch/_apis/__init__.py View File


mindnlp/core/_apis/cpu.py → mindtorch/_apis/cpu.py View File

@@ -4,14 +4,14 @@ import math
import numpy as np
import mindspore
from mindspore._c_expression import _empty_instance
from mindnlp import core
import mindtorch
from .._op_prim.cpu import legacy

def empty(*args, **kwargs):
return _empty_instance(*args, **kwargs, device='CPU')

def inplace_normal(input, mean, std, generator_):
out = np.random.normal(mean, std, input.shape).astype(core.dtype2np[input.dtype])
out = np.random.normal(mean, std, input.shape).astype(mindtorch.dtype2np[input.dtype])
numpy_to_tensor_overwrite(out, input)

return input
@@ -53,7 +53,7 @@ def tensor_shape(input):
return legacy.tensor_shape(input)

def arange(start, end, step, dtype):
return core.Tensor.from_numpy(np.arange(start, end, step, core.dtype2np[dtype]))
return mindtorch.Tensor.from_numpy(np.arange(start, end, step, mindtorch.dtype2np[dtype]))

def broadcast_to(input, shape):
return legacy.broadcast_to(input, shape)
@@ -64,7 +64,7 @@ def zeros(shape, dtype):
def inplace_uniform(input, from_, to_, generator_):
seed, _ = generator_._step(12)
np.random.seed(seed.item())
out = np.random.uniform(from_, to_, input.shape).astype(core.dtype2np[input.dtype])
out = np.random.uniform(from_, to_, input.shape).astype(mindtorch.dtype2np[input.dtype])
numpy_to_tensor_overwrite(out, input)
return input

@@ -178,7 +178,7 @@ def pad_v3(input, new_pad, mode, value=None, contiguous=True):

def cumsum(self, dim, dtype):
if self.shape[dim] == 0:
return core.tensor([], dtype=self.dtype, device=self.device)
return mindtorch.tensor([], dtype=self.dtype, device=self.device)
return legacy.cum_sum(self, dim, False, False)

def reduce_any(input, axis, keepdims):
@@ -956,12 +956,12 @@ def conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
return legacy.conv3_d(input, weight, bias, stride, padding, dilation, groups)

def normal_float_float(mean, std, size, dtype, generator):
out = np.random.normal(mean, std, size).astype(core.dtype2np[dtype])
out = np.random.normal(mean, std, size).astype(mindtorch.dtype2np[dtype])
out = mindspore.Tensor(out)
return out

def normal_tensor_tensor(mean, std, size, dtype, generator):
out = np.random.normal(mean.item(), std.item(), size).astype(core.dtype2np[dtype])
out = np.random.normal(mean.item(), std.item(), size).astype(mindtorch.dtype2np[dtype])
out = mindspore.Tensor(out)
return out

@@ -1159,7 +1159,7 @@ def randperm(n, generator, dtype):

def gamma(shape, alpha, beta):
out = np.random.gamma(alpha, 1/beta, shape)
return core.Tensor.from_numpy(out)
return mindtorch.Tensor.from_numpy(out)

def logical_or(input_x, input_y):
return legacy.logical_or(input_x, input_y)
@@ -1202,7 +1202,7 @@ def linalg_qr(input_x, mode):

def diag(input, diagonal):
out = np.diag(input.numpy(), diagonal)
return core.Tensor.from_numpy(out)
return mindtorch.Tensor.from_numpy(out)

def logit(input, eps=1e-5):
return legacy.logit(input, eps)

mindnlp/core/_apis/gpu.py → mindtorch/_apis/gpu.py View File

@@ -3,7 +3,7 @@ import numbers
import math
import mindspore
from mindspore._c_expression import _empty_instance
from mindnlp import core
import mindtorch
from .._op_prim.gpu import legacy

try:
@@ -125,8 +125,8 @@ def div(input, other):
return legacy.div(input, other)

def mul(input, other):
if input.dtype == core.bool:
if isinstance(other, bool) or (not isinstance(other, numbers.Number) and other.dtype == core.bool):
if input.dtype == mindtorch.bool:
if isinstance(other, bool) or (not isinstance(other, numbers.Number) and other.dtype == mindtorch.bool):
return bitwise_and_scalar(input, other)
return legacy.mul(input, other)

@@ -165,7 +165,7 @@ def pad_v3(input, new_pad, mode, value=None, contiguous=True):

def cumsum(self, dim, dtype):
if self.shape[dim] == 0:
return core.tensor([], dtype=self.dtype, device=self.device)
return mindtorch.tensor([], dtype=self.dtype, device=self.device)
return legacy.cum_sum(self, dim, False, False)

def reduce_any(input, axis, keepdims):

mindnlp/core/_apis/meta.py → mindtorch/_apis/meta.py View File

@@ -5,13 +5,13 @@ except:

import math
import numpy as np
from mindnlp import core
import mindtorch

__all__ = []

def arange(start, end, step, dtype):
out = Tensor_(shape=(math.ceil((end - start) / step), ), dtype=dtype)
return core.Tensor(out)
return mindtorch.Tensor(out)

__all__.append('arange')

@@ -26,19 +26,19 @@ def broadcast_to(input, shape):
out_shape += (s,)

out = Tensor_(shape=out_shape, dtype=input.dtype)
return core.Tensor(out)
return mindtorch.Tensor(out)

__all__.append('broadcast_to')

def zeros(size, dtype):
out = Tensor_(shape=size, dtype=dtype)
return core.Tensor(out)
return mindtorch.Tensor(out)

__all__.append('zeros')

def ones(size, dtype):
out = Tensor_(shape=size, dtype=dtype)
return core.Tensor(out)
return mindtorch.Tensor(out)

__all__.append('ones')

@@ -60,12 +60,12 @@ __all__.append('inplace_normal')
def getitem(input, slice):
out = input.asnumpy()[slice]
out = Tensor_(shape=out.shape, dtype=input.dtype)
return core.Tensor(out)
return mindtorch.Tensor(out)

__all__.append('getitem')

def sub(input, other, alpha):
if isinstance(input, core.Tensor):
if isinstance(input, mindtorch.Tensor):
return input
return other

@@ -74,7 +74,7 @@ __all__.append('sub')
def pad_v3(input, pad, mode, value):
out = np.pad(input.asnumpy(), pad, mode, constant_values=value)
out = Tensor_(shape=out.shape, dtype=input.dtype)
return core.Tensor(out)
return mindtorch.Tensor(out)

__all__.append('pad_v3')

@@ -85,20 +85,20 @@ __all__.append('abs')

def cast(input, dtype):
out = Tensor_(shape=input.shape, dtype=dtype)
return core.Tensor(out)
return mindtorch.Tensor(out)

__all__.append('cast')

def index_select(input, dim, index):
out = np.take(input.asnumpy(), index.asnumpy(), dim)
out = Tensor_(shape=out.shape, dtype=input.dtype)
return core.Tensor(out)
return mindtorch.Tensor(out)

__all__.append('index_select')

def identity(input):
out = Tensor_(shape=input.shape, dtype=input.dtype)
return core.Tensor(out)
return mindtorch.Tensor(out)

__all__.append('identity')

@@ -113,20 +113,20 @@ def inplace_copy(input, other):
__all__.append('inplace_copy')

def div(input, other):
if isinstance(input, core.Tensor):
if isinstance(input, mindtorch.Tensor):
shape = input.shape
dtype = input.dtype
else:
shape = other.shape
dtype = other.dtype
out = Tensor_(shape=shape, dtype=dtype)
return core.Tensor(out)
return mindtorch.Tensor(out)

__all__.append('div')

def pow_scalar_tensor(input, other):
out = Tensor_(shape=other.shape, dtype=other.dtype)
return core.Tensor(out)
return mindtorch.Tensor(out)

__all__.append('pow_scalar_tensor')

@@ -134,7 +134,7 @@ def concat(tensors, dim):
shape = list(tensors[0].shape)
shape[dim] = sum([t.shape[dim] for t in tensors])
out = Tensor_(shape=tuple(shape), dtype=tensors[0].dtype)
return core.Tensor(out)
return mindtorch.Tensor(out)

__all__.append('concat')

@@ -145,7 +145,7 @@ __all__.append('tril')

def reshape(input, shape):
out = Tensor_(shape=tuple(shape), dtype=input.dtype)
return core.Tensor(out)
return mindtorch.Tensor(out)

__all__.append('reshape')

@@ -163,7 +163,7 @@ def linalg_vector_norm(input, p, dim, keepdim, dtype):
if dtype is None:
dtype = input.dtype
out = Tensor_(shape=tuple(new_shape), dtype=dtype)
return core.Tensor(out)
return mindtorch.Tensor(out)

__all__.append('linalg_vector_norm')

@@ -174,7 +174,7 @@ __all__.append('erfinv')

def stop_gradient(input):
out = Tensor_(shape=input.shape, dtype=input.dtype)
return core.Tensor(out)
return mindtorch.Tensor(out)

__all__.append('stop_gradient')

@@ -184,18 +184,18 @@ __all__.append('log')

def mul(input, other):
out = Tensor_(shape=input.shape, dtype=input.dtype)
return core.Tensor(out)
return mindtorch.Tensor(out)
__all__.append('mul')

def randn(size, generator, dtype):
out = Tensor_(shape=size, dtype=dtype)
return core.Tensor(out)
return mindtorch.Tensor(out)

__all__.append('randn')

def zeros_like(input, *args, **kwargs):
out = Tensor_(shape=input.shape, dtype=input.dtype)
return core.Tensor(out)
return mindtorch.Tensor(out)
__all__.append('zeros_like')

def inplace_add(input, other, alpha):
@@ -211,7 +211,7 @@ def expand_dims(input, dim):
input_shape.insert(dim, 1)

out = Tensor_(shape=tuple(input_shape), dtype=input.dtype)
return core.Tensor(out)
return mindtorch.Tensor(out)


def floor_div(input, other):
@@ -235,9 +235,9 @@ __all__.append('triu')

def fill_scalar(size, fill_value, dtype):
if dtype is None:
dtype = core.get_default_dtype()
dtype = mindtorch.get_default_dtype()
out = Tensor_(shape=size, dtype=dtype)
return core.Tensor(out)
return mindtorch.Tensor(out)

__all__.append('fill_scalar')

@@ -247,8 +247,8 @@ def sqrt(input):
__all__.append('sqrt')

def normal_float_float(mean, std, size, geneartor):
out = Tensor_(shape=size, dtype=core.float32)
return core.Tensor(out)
out = Tensor_(shape=size, dtype=mindtorch.float32)
return mindtorch.Tensor(out)


__all__.append('normal_float_float')
@@ -257,7 +257,7 @@ def stack(tensors, dim):
x_shape = list(tensors[0].shape)
x_shape.insert(dim, len(tensors))
out = Tensor_(shape=tuple(x_shape), dtype=tensors[0].dtype)
return core.Tensor(out)
return mindtorch.Tensor(out)

__all__.append('stack')

@@ -268,10 +268,10 @@ def argmax_with_value(input, dim, keepdim):
else:
out_shape.pop(dim)

indices = Tensor_(shape=out_shape, dtype=core.int64)
indices = Tensor_(shape=out_shape, dtype=mindtorch.int64)
values = Tensor_(shape=out_shape, dtype=input.dtype)

return core.Tensor(indices), core.Tensor(values)
return mindtorch.Tensor(indices), mindtorch.Tensor(values)

__all__.append('argmax_with_value')

@@ -279,7 +279,7 @@ def tile(input, dims):
input_shape = input.shape
out_shape = [input_shape[i] * dims[i] for i in range(input.ndim)]
out = Tensor_(shape=tuple(out_shape), dtype=input.dtype)
return core.Tensor(out)
return mindtorch.Tensor(out)

__all__.append('tile')

@@ -292,7 +292,7 @@ def flatten(input, start_dim, end_dim):

flatten_shape = input_shape[:start_dim] + input_shape[start_dim:end_dim+1] + input_shape[end_dim+1:]
out = Tensor_(shape=tuple(flatten_shape), dtype=input.dtype)
return core.Tensor(out)
return mindtorch.Tensor(out)

__all__.append('flatten')

@@ -312,7 +312,7 @@ def squeeze(input, dim):
new_shape += (s,)

out = Tensor_(shape=tuple(new_shape), dtype=input.dtype)
return core.Tensor(out)
return mindtorch.Tensor(out)

__all__.append('squeeze')

@@ -323,7 +323,7 @@ __all__.append('exp')

def rand(size, generator, dtype):
out = Tensor_(shape=size, dtype=dtype)
return core.Tensor(out)
return mindtorch.Tensor(out)

__all__.append('rand')

@@ -358,14 +358,14 @@ def bitwise_xor_tensor(input, other):
__all__.append('bitwise_xor_tensor')

def divmod(input, other, rounding_mode):
if isinstance(input, core.Tensor):
if isinstance(input, mindtorch.Tensor):
return input
return other

__all__.append('divmod')

def greater_equal(input, other):
if isinstance(input, core.Tensor):
if isinstance(input, mindtorch.Tensor):
return input
return other


mindnlp/core/_apis/npu.py → mindtorch/_apis/npu.py View File


mindnlp/core/_bind.py → mindtorch/_bind.py View File

@@ -163,17 +163,17 @@ class finfo:
return str(self._dtype)
def asarray(obj: Any, *, dtype, device=None, copy = None, requires_grad = False):
data = obj.data.view(core.dtype2np[dtype])
out = core.Tensor(data)
core._utils.set_device_address(out)
data = obj.data.view(mindtorch.dtype2np[dtype])
out = mindtorch.Tensor(data)
mindtorch._utils.set_device_address(out)
return out
def view(self, dtype):
data_ptr = self.data_ptr()
nbytes = self.nbytes
data = np.ctypeslib.as_array((ctypes.c_byte * nbytes).from_address(data_ptr), shape=(nbytes,))
data = data.view(core.dtype2np[dtype])
data = data.view(mindtorch.dtype2np[dtype])
assert data_ptr == data.ctypes.data
out = core.Tensor(data)
core._utils.set_device_address(out)
out = mindtorch.Tensor(data)
mindtorch._utils.set_device_address(out)
return out

mindnlp/core/_custom_ops.py → mindtorch/_custom_ops.py View File


mindnlp/core/_dtype.py → mindtorch/_dtype.py View File


mindnlp/core/_dynamo/__init__.py → mindtorch/_dynamo/__init__.py View File


mindnlp/core/_dynamo/_trace_wrapped_higher_order_op.py → mindtorch/_dynamo/_trace_wrapped_higher_order_op.py View File


mindnlp/core/_dynamo/config.py → mindtorch/_dynamo/config.py View File


mindnlp/core/_dynamo/decorators.py → mindtorch/_dynamo/decorators.py View File


mindnlp/core/_dynamo/eval_frame.py → mindtorch/_dynamo/eval_frame.py View File


Some files were not shown because too many files changed in this diff

Loading…
Cancel
Save
Baidu
map