55 Commits

Author SHA1 Message Date
Zhanghao Wu fa2a89c5dc update values. 6 days ago
Christopher Cooper af0b6d0963 Merge branch 'master' of github.com:skypilot-org/skypilot into consolidation-docs 6 days ago
Christopher Cooper b5ff2fdb54 [deps] pin pycares<5 to work around aiodns issue (#8259) 6 days ago
Christopher Cooper 62868dc428 Merge branch 'master' of github.com:skypilot-org/skypilot into consolidation-docs 6 days ago
Christopher Cooper 2a16743f31 address comments 6 days ago
lloyd-brown 9f7bdf7d28 [Tests] Add Cloud Selection to CLI Smoke Tests (#8256) 6 days ago
DanielZhangQD 6f12d52a0f [Helm] Support updating ssh node pool config with helm chart (#8249) 1 week ago
Aylei 3bdad29488 Fixed API server mem bench on k8s (#8254) 1 week ago
zpoint c28d94abd9 Fix `test_nemorl` failure (#8253) 1 week ago
zpoint 6259aa05fd Fix smoke test failure `test_container_logs_two_jobs_kubernetes` (#8250) 1 week ago
Kevin Mingtarja 8cc5193918 [Slurm] Unify ssh proxycommand config for run and rsync in SlurmCommandRunner (#8248) 1 week ago
Seung Jin 343123b289 [Slurm] configurable Slurm provision timeout, set default to 10s (#8244) 1 week ago
Kevin Mingtarja c13dfb9c04 [Slurm] Fix UV_CACHE_DIR permission issues with multiple users (#8245) 1 week ago
Seung Jin ff928a2757 [Slurm] show Slurm infra at cluster level (as opposed to partition level) (#8246) 1 week ago
Kevin Mingtarja 6152d193a8 [Slurm] Remove unnecessary setup commands (#8247) 1 week ago
Kevin Mingtarja fb234495ff [Test] Skip test_job_queue if k8s cluster has no GPUs (#8242) 1 week ago
Seung Jin f723ff528f [k8s] Disable ray memory monitor on k8s (#8231) 1 week ago
Seung Jin 81892e7f86 Display default Slurm partitions first (#8239) 1 week ago
Kevin Mingtarja 2ef047a7a4 [Test] Fix capitalization in test_kubernetes_slurm_show_gpus (#8238) 1 week ago
mk0walsk 4deb1be77f Update torchtune documentation links (#8237) 1 week ago
DanielZhangQD 87dcd33a96 [Kubernetes] Do not schedule on not ready nodes (#8172) 1 week ago
Aylei 06adc6ba37 Add memory benchmark in smoke test (#8161) 1 week ago
lloyd-brown b9d575f46a [Pools] Fix Non-Existent Bucket Error on Scale Down (#8214) 1 week ago
Kevin Mingtarja fced388758 [Slurm] Multi-node clusters (#8219) 1 week ago
DanielZhangQD ab1fddebd2 [Dashboard] Align the top bar elements (#8235) 1 week ago
Christopher Cooper 4bfae458f1 [core] restore cluster_name_on_cloud from cluster yaml (#8233) 1 week ago
Daniel Shin f0062485bd [SSH Node Pools] Refactor Deployments (#8226) 1 week ago
Seung Jin 8700f616d1 [SSH node pools] Ban heterogeneous nodes when setting up SSH node pool (#8230) 1 week ago
Daniel Shin 31955a4da7 [Docs] Add SSH Node Pools Resource Comment (#8223) 1 week ago
lloyd-brown d86dd037af [Core] Redact Docker Password from Provision Logs (#8080) 1 week ago
Caleb Whitaker 3b36b99500 Fix Mistral documentation links (#8229) 1 week ago
DanielZhangQD e11e4d6845 [CLI] Hide `volume` from `sky -h` (#8228) 1 week ago
Daniel Shin 6b2b5cbe75 [Vast] Provide options to only provision on secure instances (#8212) 1 week ago
Kevin Mingtarja 5a288c0360 Fix COPY_SKYPILOT_TEMPLATES_COMMANDS on consecutive launches (#8218) 1 week ago
Kevin Mingtarja fd635367eb [Slurm] Support specifying partitions as zones (#8198) 1 week ago
David Young 898862a2aa [feat] Add helm chart support for DigitalOcean credentials (#7931) 1 week ago
Seung Jin bb73286bc3 [Slurm] robust slurm check against nonexistent config file (#8207) 1 week ago
DanielZhangQD 9c9ea1a755 [Dashboard] Update topbar style (#8213) 1 week ago
Aylei fb8cc1fea1 API server plugin support (#7993) 1 week ago
Aylei f801fb3a3a Fixed incorrect user info in handlers (#8209) 1 week ago
DanielZhangQD ba32ce2395 [Chart] Support coreweave credentials in helm chart (#8200) 1 week ago
lloyd-brown 9174f75d38 [Core] Put Daemonize Call Back to Make Sky Cancel Reliable (#8203) 1 week ago
江家瑋 f7eb867568 [slurm] Slurm support (#5491) 1 week ago
Seung Jin 846c47a021 separate out optimizer dryruns to its separate group (#7950) 1 week ago
Daniel Shin 3b7a2ac2c8 [SSH Node Pools] Move Deployment to `sky/ssh_node_pools` folder and other refactor (#8173) 1 week ago
vincent d warmerdam 6b38fadcab Docs: marimo notebooks as a skypilot job (#8123) 1 week ago
Brian Strauch 305bc803b7 Change log level from info to error on request failure (#7686) 1 week ago
Aylei 0d4f1bd611 Fixed incorrect user in /enabled_clouds API (#8199) 1 week ago
Aylei e64c9081e8 Fixed AWS session cache memory leak in status refresh daemon (#8098) 1 week ago
Aylei e002c44923 Fix ephemeral volume creation (#8179) 1 week ago
Christopher Cooper d8494f1775 [gh action] support publishing rc versions to pypi (#8188) 1 week ago
Seung Jin 9a9abb8198 [Examples] Improve GitOps Example (#8030) 1 week ago
Seung Jin 042648986f [Examples] Reintroduce CLI based approaches for select examples (#8182) 1 week ago
Seung Jin d907688dd7 demonstrate use of secrets in marimo example (#8162) 1 week ago
Seung Jin a005696479 [Docs] Bring GCP archive download instructions for macOS inline (#8181) 1 week ago
100 changed files with 7053 additions and 449 deletions
1. .buildkite/generate_pipeline.py (+16, -3)
2. .github/actions/run-python-tests/action.yaml (+73, -0)
3. .github/workflows/publish-and-validate.yml (+5, -2)
4. .github/workflows/pytest-optimizer.yml (+41, -0)
5. .github/workflows/pytest.yml (+5, -59)
6. charts/skypilot/templates/api-deployment.yaml (+84, -8)
7. charts/skypilot/tests/deployment_test.yaml (+97, -0)
8. charts/skypilot/values.schema.json (+22, -0)
9. charts/skypilot/values.yaml (+11, -0)
10. docs/source/examples/interactive-development.rst (+200, -31)
11. docs/source/examples/managed-jobs.rst (+14, -0)
12. docs/source/getting-started/installation.rst (+14, -2)
13. docs/source/images/jobs-consolidation-mode.svg (+1, -0)
14. docs/source/images/marimo-auth.png (BIN)
15. docs/source/images/marimo-example.png (BIN)
16. docs/source/images/marimo-job.png (BIN)
17. docs/source/images/marimo-nvidea.png (BIN)
18. docs/source/images/marimo-use.png (BIN)
19. docs/source/reference/api-server/api-server-admin-deploy.rst (+527, -0)
20. docs/source/reference/api-server/helm-values-spec.rst (+79, -0)
21. docs/source/reference/config.rst (+21, -0)
22. docs/source/reservations/existing-machines.rst (+9, -0)
23. examples/github_actions/README.md (+24, -1)
24. examples/marimo/marimo.yaml (+4, -1)
25. examples/plugin/README.md (+9, -0)
26. examples/plugin/example_plugin/__init__.py (+6, -0)
27. examples/plugin/example_plugin/plugin.py (+123, -0)
28. examples/plugin/plugins.yaml (+8, -0)
29. examples/plugin/pyproject.toml (+7, -0)
30. examples/resnet_distributed_torch.yaml (+8, -1)
31. examples/vector_database/README.md (+30, -2)
32. llm/llama-3_1-finetuning/readme.md (+2, -2)
33. llm/localgpt/README.md (+22, -1)
34. llm/mixtral/README.md (+2, -2)
35. sky/__init__.py (+2, -0)
36. sky/adaptors/aws.py (+1, -61)
37. sky/adaptors/slurm.py (+478, -0)
38. sky/backends/backend_utils.py (+45, -4)
39. sky/backends/cloud_vm_ray_backend.py (+32, -33)
40. sky/backends/task_codegen.py (+340, -2)
41. sky/catalog/__init__.py (+0, -3)
42. sky/catalog/kubernetes_catalog.py (+12, -4)
43. sky/catalog/slurm_catalog.py (+243, -0)
44. sky/check.py (+14, -3)
45. sky/client/cli/command.py (+329, -22)
46. sky/client/sdk.py (+56, -2)
47. sky/clouds/__init__.py (+2, -0)
48. sky/clouds/cloud.py (+7, -0)
49. sky/clouds/slurm.py (+578, -0)
50. sky/clouds/ssh.py (+2, -1)
51. sky/clouds/vast.py (+10, -0)
52. sky/core.py (+128, -36)
53. sky/dashboard/src/components/elements/icons.jsx (+19, -0)
54. sky/dashboard/src/components/elements/sidebar.jsx (+240, -5)
55. sky/dashboard/src/components/infra.jsx (+180, -68)
56. sky/dashboard/src/data/connectors/infra.jsx (+214, -0)
57. sky/dashboard/src/pages/_app.js (+15, -5)
58. sky/dashboard/src/pages/plugins/[...slug].js (+139, -0)
59. sky/dashboard/src/plugins/PluginProvider.jsx (+345, -0)
60. sky/data/mounting_utils.py (+16, -2)
61. sky/global_user_state.py (+3, -3)
62. sky/models.py (+2, -0)
63. sky/optimizer.py (+6, -5)
64. sky/provision/__init__.py (+1, -0)
65. sky/provision/common.py (+20, -0)
66. sky/provision/docker_utils.py (+15, -2)
67. sky/provision/kubernetes/utils.py (+42, -6)
68. sky/provision/provisioner.py (+15, -6)
69. sky/provision/slurm/__init__.py (+12, -0)
70. sky/provision/slurm/config.py (+13, -0)
71. sky/provision/slurm/instance.py (+572, -0)
72. sky/provision/slurm/utils.py (+583, -0)
73. sky/provision/vast/instance.py (+4, -1)
74. sky/provision/vast/utils.py (+10, -6)
75. sky/serve/server/impl.py (+1, -1)
76. sky/server/constants.py (+1, -1)
77. sky/server/plugins.py (+222, -0)
78. sky/server/requests/executor.py (+5, -2)
79. sky/server/requests/payloads.py (+12, -1)
80. sky/server/requests/request_names.py (+2, -0)
81. sky/server/requests/requests.py (+5, -1)
82. sky/server/requests/serializers/encoders.py (+17, -0)
83. sky/server/requests/serializers/return_value_serializers.py (+60, -0)
84. sky/server/server.py (+78, -8)
85. sky/server/server_utils.py (+30, -0)
86. sky/setup_files/dependencies.py (+17, -6)
87. sky/skylet/attempt_skylet.py (+13, -3)
88. sky/skylet/constants.py (+34, -9)
89. sky/skylet/events.py (+10, -4)
90. sky/skylet/executor/README.md (+52, -0)
91. sky/skylet/executor/__init__.py (+1, -0)
92. sky/skylet/executor/slurm.py (+189, -0)
93. sky/skylet/job_lib.py (+2, -1)
94. sky/skylet/log_lib.py (+22, -6)
95. sky/skylet/log_lib.pyi (+8, -6)
96. sky/skylet/skylet.py (+5, -1)
97. sky/skylet/subprocess_daemon.py (+2, -1)
98. sky/ssh_node_pools/constants.py (+12, -0)
99. sky/ssh_node_pools/core.py (+40, -3)
100. sky/ssh_node_pools/deploy/__init__.py (+4, -0)

.buildkite/generate_pipeline.py (+16, -3)

@@ -43,6 +43,7 @@ QUEUE_GENERIC_CLOUD = 'generic_cloud'
QUEUE_EKS = 'eks'
QUEUE_GKE = 'gke'
QUEUE_KIND = 'kind'
QUEUE_BENCHMARK = 'single_container'
# We use a separate queue for generic cloud tests on remote servers because:
# - generic_cloud queue has high concurrency on a single VM
# - remote-server requires launching a docker container per test
@@ -63,6 +64,7 @@ CLOUD_QUEUE_MAP = {
'nebius': QUEUE_GENERIC_CLOUD,
'lambda': QUEUE_GENERIC_CLOUD,
'runpod': QUEUE_GENERIC_CLOUD,
'slurm': QUEUE_GENERIC_CLOUD,
'kubernetes': QUEUE_KIND
}

@@ -71,8 +73,11 @@ GENERATED_FILE_HEAD = ('# This is an auto-generated Buildkite pipeline by '
'edit directly.\n')


def _get_buildkite_queue(cloud: str, remote_server: bool,
run_on_cloud_kube_backend: bool, args: str) -> str:
def _get_buildkite_queue(cloud: str,
remote_server: bool,
run_on_cloud_kube_backend: bool,
args: str,
benchmark_test: bool = False) -> str:
"""Get the Buildkite queue for a given cloud.

We use a separate queue for generic cloud tests on remote servers because:
@@ -82,11 +87,17 @@ def _get_buildkite_queue(cloud: str, remote_server: bool,

Kubernetes has low concurrency on a single VM originally,
so remote-server won't drain VM resources, we can reuse the same queue.

For benchmark tests, we use a dedicated benchmark queue with guaranteed
resources to get reliable performance results.
"""
env_queue = os.environ.get('BUILDKITE_QUEUE', None)
if env_queue:
return env_queue

if benchmark_test:
return QUEUE_BENCHMARK

if '--env-file' in args:
# TODO(zeping): Remove this when test requirements become more varied.
# Currently, tests specifying --env-file and a custom API server endpoint are assigned to
@@ -246,6 +257,7 @@ def _extract_marked_tests(
clouds_to_include = []
run_on_cloud_kube_backend = ('resource_heavy' in marks and
'kubernetes' in default_clouds_to_run)
benchmark_test = 'benchmark' in marks

for mark in marks:
if mark not in PYTEST_TO_CLOUD_KEYWORD:
@@ -284,7 +296,8 @@ def _extract_marked_tests(
] * (len(final_clouds_to_include) - len(param_list))
function_cloud_map[function_name] = (final_clouds_to_include, [
_get_buildkite_queue(cloud, remote_server,
run_on_cloud_kube_backend, args)
run_on_cloud_kube_backend, args,
benchmark_test)
for cloud in final_clouds_to_include
], param_list, [
extra_args for _ in range(len(final_clouds_to_include))
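As the docstring above notes, an explicitly set BUILDKITE_QUEUE environment variable still takes precedence over the per-test selection, including the new benchmark routing; a minimal sketch (illustrative only):

    # Force every generated step onto the dedicated benchmark queue
    # ('single_container' is the value of QUEUE_BENCHMARK above).
    export BUILDKITE_QUEUE=single_container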


.github/actions/run-python-tests/action.yaml (+73, -0)

@@ -0,0 +1,73 @@
name: "Run Python Tests"
description: "Setup and run Python tests"

inputs:
python-version:
description: "Python version to use"
required: true
test-path:
description: "Path to the test file or directory"
required: true
test-name:
description: "Name of the test"
required: true
buildkite-analytics-token:
description: "Buildkite analytics token"
required: true

runs:
using: "composite"
steps:
- name: Install the latest version of uv
uses: astral-sh/setup-uv@v4
with:
version: "latest"
python-version: ${{ inputs.python-version }}
- name: Install dependencies
shell: bash
run: |
uv venv --seed ~/test-env
source ~/test-env/bin/activate
uv pip install --prerelease=allow "azure-cli>=2.65.0"
# Use -e to include examples and tests folder in the path for unit
# tests to access them.
uv pip install -e ".[all]"
uv pip install pytest pytest-xdist pytest-env>=0.6 pytest-asyncio memory-profiler==0.61.0 buildkite-test-collector
- name: Set branch name and commit message
id: set_env
shell: bash
run: |
if [ "${{ github.event_name }}" == "pull_request" ]; then
echo "branch=${{ github.head_ref }}" >> $GITHUB_OUTPUT
{
echo "message<<EOF"
echo "${{ github.event.pull_request.title }}"
echo "EOF"
} >> $GITHUB_OUTPUT
else
echo "branch=${{ github.ref_name }}" >> $GITHUB_OUTPUT
# For push events after merge, head_commit.message contains the merge commit message
# which includes the PR title. For merge_group events, use merge_group.head_commit.message
{
echo "message<<EOF"
if [ "${{ github.event_name }}" == "merge_group" ]; then
echo "${{ github.event.merge_group.head_commit.message }}"
else
echo "${{ github.event.head_commit.message }}"
fi
echo "EOF"
} >> $GITHUB_OUTPUT
fi
- name: Run tests with pytest
shell: bash
env:
TEST_PATH: ${{ inputs.test-path }}
TEST_NAME: ${{ inputs.test-name }}
BUILDKITE_ANALYTICS_TOKEN: ${{ inputs.buildkite-analytics-token }}
run: |
source ~/test-env/bin/activate
if [[ "$TEST_NAME" == "No Parallel Tests" || "$TEST_NAME" == "Event Loop Lag Tests" ]]; then
SKYPILOT_DISABLE_USAGE_COLLECTION=1 SKYPILOT_SKIP_CLOUD_IDENTITY_CHECK=1 eval "pytest -n 0 --dist no $TEST_PATH"
else
SKYPILOT_DISABLE_USAGE_COLLECTION=1 SKYPILOT_SKIP_CLOUD_IDENTITY_CHECK=1 eval "pytest -n 4 --dist worksteal $TEST_PATH"
fi

.github/workflows/publish-and-validate.yml (+5, -2)

@@ -44,6 +44,8 @@ jobs:
- name: Validate published package
run: |
export SKYPILOT_DISABLE_USAGE_COLLECTION=1

# fastapi has some broken package info on test PyPI, so manually install it from real PyPI.
pip install fastapi

# Set up variables for package check
@@ -66,12 +68,13 @@ jobs:
pip uninstall -y ${{ inputs.package_name }} || true

# Install the package with no cache
# Use --pre so that pre-release versions (e.g. rcs) will be selected
if [ "${{ inputs.repository_type }}" == "test-pypi" ]; then
echo "Installing from Test PyPI..."
pip install --no-cache-dir --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple ${{ inputs.package_name }}
pip install --no-cache-dir --pre --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple ${{ inputs.package_name }}[server]
else
echo "Installing from PyPI..."
pip install --no-cache-dir ${{ inputs.package_name }}
pip install --no-cache-dir --pre ${{ inputs.package_name }}[server]
fi

# Check the version
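For context, a hedged sketch of what the added --pre flag changes when validating a release candidate (package name and version are illustrative):

    # Without --pre, pip skips pre-releases such as 1.2.0rc1; with it, the rc can be selected.
    pip install --no-cache-dir --pre "skypilot-nightly[server]"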


.github/workflows/pytest-optimizer.yml (+41, -0)

@@ -0,0 +1,41 @@
name: Python Tests
on:
# Trigger the workflow on push or pull request,
# but only for the main branch
push:
branches:
- master
- 'releases/**'
pull_request:
branches:
- master
- 'releases/**'
merge_group:

jobs:
python-test-optimizer:
strategy:
matrix:
python-version: ["3.8"]
test-path:
- "tests/test_optimizer_dryruns.py -k \"partial\""
- "tests/test_optimizer_dryruns.py -k \"not partial\""
- tests/test_optimizer_random_dag.py
include:
- test-path: "tests/test_optimizer_dryruns.py -k \"partial\""
test-name: "Optimizer Dryruns Part 1"
- test-path: "tests/test_optimizer_dryruns.py -k \"not partial\""
test-name: "Optimizer Dryruns Part 2"
- test-path: tests/test_optimizer_random_dag.py
test-name: "Optimizer Random DAG Tests"
runs-on: ubuntu-latest
name: "Python Tests - ${{ matrix.test-name }}"
steps:
- name: Checkout repository
uses: actions/checkout@v3
- name: Run Python Tests
uses: ./.github/actions/run-python-tests
with:
python-version: ${{ matrix.python-version }}
test-path: ${{ matrix.test-path }}
test-name: ${{ matrix.test-name }}
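For reference, a rough local equivalent of the matrix above, assuming dependencies are already installed as in the composite action:

    pytest -n 4 --dist worksteal tests/test_optimizer_dryruns.py -k "partial"
    pytest -n 4 --dist worksteal tests/test_optimizer_dryruns.py -k "not partial"
    pytest -n 4 --dist worksteal tests/test_optimizer_random_dag.py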

.github/workflows/pytest.yml (+5, -59)

@@ -21,11 +21,8 @@ jobs:
# Group them based on running time to save CI time and resources
- tests/unit_tests
- tests/test_cli.py
- "tests/test_optimizer_dryruns.py -k \"partial\""
- "tests/test_optimizer_dryruns.py -k \"not partial\""
- tests/test_jobs_and_serve.py tests/test_yaml_parser.py tests/test_global_user_state.py tests/test_config.py tests/test_jobs.py tests/test_list_accelerators.py tests/test_wheels.py tests/test_api.py tests/test_storage.py tests/test_api_compatibility.py
- tests/test_no_parellel.py
- tests/test_optimizer_random_dag.py
- tests/test_ssh_proxy_lag.py
include:
# We separate out the random DAG tests because they're flaky due to catalog updates.
@@ -34,16 +31,10 @@ jobs:
test-name: "Unit Tests"
- test-path: tests/test_cli.py
test-name: "CLI Tests"
- test-path: "tests/test_optimizer_dryruns.py -k \"partial\""
test-name: "Optimizer Dryruns Part 1"
- test-path: "tests/test_optimizer_dryruns.py -k \"not partial\""
test-name: "Optimizer Dryruns Part 2"
- test-path: tests/test_jobs_and_serve.py tests/test_yaml_parser.py tests/test_global_user_state.py tests/test_config.py tests/test_jobs.py tests/test_list_accelerators.py tests/test_wheels.py tests/test_api.py tests/test_storage.py tests/test_api_compatibility.py tests/test_infra_k8s_alias.py
test-name: "Jobs, Serve, Wheels, API, Config, Optimizer & Storage Tests"
- test-path: tests/test_no_parellel.py
test-name: "No Parallel Tests"
- test-path: tests/test_optimizer_random_dag.py
test-name: "Optimizer Random DAG Tests"
- test-path: tests/test_ssh_proxy_lag.py
test-name: "Event Loop Lag Tests"
runs-on: ubuntu-latest
@@ -51,58 +42,13 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v3
- name: Install the latest version of uv
uses: astral-sh/setup-uv@v4
- name: Run Python Tests
uses: ./.github/actions/run-python-tests
with:
version: "latest"
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
uv venv --seed ~/test-env
source ~/test-env/bin/activate
uv pip install --prerelease=allow "azure-cli>=2.65.0"
# Use -e to include examples and tests folder in the path for unit
# tests to access them.
uv pip install -e ".[all]"
uv pip install pytest pytest-xdist pytest-env>=0.6 pytest-asyncio memory-profiler==0.61.0 buildkite-test-collector
- name: Set branch name and commit message
id: set_env
run: |
if [ "${{ github.event_name }}" == "pull_request" ]; then
echo "branch=${{ github.head_ref }}" >> $GITHUB_OUTPUT
{
echo "message<<EOF"
echo "${{ github.event.pull_request.title }}"
echo "EOF"
} >> $GITHUB_OUTPUT
else
echo "branch=${{ github.ref_name }}" >> $GITHUB_OUTPUT
# For push events after merge, head_commit.message contains the merge commit message
# which includes the PR title. For merge_group events, use merge_group.head_commit.message
{
echo "message<<EOF"
if [ "${{ github.event_name }}" == "merge_group" ]; then
echo "${{ github.event.merge_group.head_commit.message }}"
else
echo "${{ github.event.head_commit.message }}"
fi
echo "EOF"
} >> $GITHUB_OUTPUT
fi
- name: Run tests with pytest
env:
TEST_PATH: ${{ matrix.test-path }}
TEST_NAME: ${{ matrix.test-name }}
BUILDKITE_ANALYTICS_TOKEN: ${{ secrets.BUILDKITE_ANALYTICS_TOKEN }}
run: |
source ~/test-env/bin/activate
export GITHUB_REF="${{ steps.set_env.outputs.branch }}"
export TEST_ANALYTICS_COMMIT_MESSAGE="${{ steps.set_env.outputs.message }}"
if [[ "$TEST_NAME" == "No Parallel Tests" || "$TEST_NAME" == "Event Loop Lag Tests" ]]; then
SKYPILOT_DISABLE_USAGE_COLLECTION=1 SKYPILOT_SKIP_CLOUD_IDENTITY_CHECK=1 eval "pytest -n 0 --dist no $TEST_PATH"
else
SKYPILOT_DISABLE_USAGE_COLLECTION=1 SKYPILOT_SKIP_CLOUD_IDENTITY_CHECK=1 eval "pytest -n 4 --dist worksteal $TEST_PATH"
fi
test-path: ${{ matrix.test-path }}
test-name: ${{ matrix.test-name }}
buildkite-analytics-token: ${{ secrets.BUILDKITE_ANALYTICS_TOKEN }}

limited-deps-test:
# Test with limited dependencies to ensure cloud module imports don't break


charts/skypilot/templates/api-deployment.yaml (+84, -8)

@@ -188,13 +188,9 @@ spec:
fi
{{- if .Values.apiService.sshNodePools }}
mkdir -p /root/.sky
# The PVC serves as the ground truth for the ssh_node_pools.yaml file, if it already exists we don't overwrite it
if [ ! -s /root/.sky/ssh_node_pools.yaml ]; then
echo "ssh_node_pools.yaml not found in /root/.sky, copying from ConfigMap \`skypilot-ssh-node-pools\`"
cp /var/skypilot/ssh_node_pool/ssh_node_pools.yaml /root/.sky/ssh_node_pools.yaml
else
echo "ssh_node_pools.yaml already exists in /root/.sky, skipping copy"
fi
echo "Linking ssh_node_pools.yaml from secret to /root/.sky/ssh_node_pools.yaml"
# The secret serves as the ground truth for the ssh_node_pools.yaml file, read-only
ln -sf /var/skypilot/ssh_node_pool/ssh_node_pools.yaml /root/.sky/ssh_node_pools.yaml
# ~/.kube/config is required to be persistent when sshNodePools is enabled, init it if it is empty to avoid parsing error.
if [ ! -s /root/.kube/config ]; then
echo "{}" > /root/.kube/config
@@ -294,6 +290,11 @@ spec:
mountPath: /root/.cloudflare
readOnly: true
{{- end }}
{{- if .Values.coreweaveCredentials.enabled }}
- name: coreweave-config
mountPath: /root/.coreweave
readOnly: true
{{- end }}
{{- if .Values.gcpCredentials.enabled }}
- name: gcp-config
mountPath: /root/.config/gcloud
@@ -320,6 +321,11 @@ spec:
mountPath: /root/.runpod
readOnly: true
{{- end }}
{{- if .Values.digitaloceanCredentials.enabled }}
- name: digitalocean-config
mountPath: /root/.config/doctl
readOnly: true
{{- end }}
{{- if .Values.lambdaCredentials.enabled }}
- name: lambda-config
mountPath: /root/.lambda_cloud
@@ -388,6 +394,7 @@ spec:
echo "Credentials file created successfully."
else
echo "AWS credentials not found in environment variables. Skipping credentials setup."
echo "Sleeping for 10 minutes before exiting for debugging purposes."
sleep 600
fi
env:
@@ -483,6 +490,29 @@ spec:
- name: gcp-config
mountPath: /root/.config/gcloud
{{- end }}
{{- if .Values.coreweaveCredentials.enabled }}
- name: setup-coreweave-credentials
{{- with .Values.securityContext }}
securityContext:
{{- toYaml . | nindent 10 }}
{{- end }}
image: {{ include "common.image" (dict "root" . "image" .Values.apiService.image) }}
command: ["/bin/sh", "-c"]
{{- if $.Values.global.extraEnvs }}
env:
{{- with $.Values.global.extraEnvs }}
{{- toYaml . | nindent 8 }}
{{- end }}
{{- end }}
args:
- |
cp /root/.coreweave_credentials/* /root/.coreweave/
volumeMounts:
- name: coreweave-credentials
mountPath: /root/.coreweave_credentials
- name: coreweave-config
mountPath: /root/.coreweave
{{- end }}
{{- if .Values.runpodCredentials.enabled }}
- name: create-runpod-credentials
{{- with .Values.securityContext }}
@@ -501,6 +531,7 @@ spec:
echo "api_key = \"$RUNPOD_API_KEY\"" >> /root/.runpod/config.toml
else
echo "RunPod credentials not found in environment variables. Skipping credentials setup."
echo "Sleeping for 10 minutes before exiting for debugging purposes."
sleep 600
fi
env:
@@ -516,6 +547,38 @@ spec:
- name: runpod-config
mountPath: /root/.runpod
{{- end }}
{{- if .Values.digitaloceanCredentials.enabled }}
- name: create-digitalocean-credentials
{{- with .Values.securityContext }}
securityContext:
{{- toYaml . | nindent 10 }}
{{- end }}
image: {{ .Values.apiService.image }}
command: ["/bin/sh", "-c"]
args:
- |
echo "Setting up Digital Ocean credentials..."
if [ -n "$DIGITALOCEAN_CREDENTIALS" ]; then
echo "Digital Ocean credentials found in environment variable."
mkdir -p /.config/doctl
cat > /.config/doctl/config.yaml <<EOF
access-token: $DIGITALOCEAN_CREDENTIALS
EOF
else
echo "Digital Ocean credentials not found in environment variables. Skipping credentials setup."
echo "Sleeping for 10 minutes before exiting for debugging purposes."
sleep 600
fi
env:
- name: DIGITALOCEAN_CREDENTIALS
valueFrom:
secretKeyRef:
name: {{ .Values.digitaloceanCredentials.digitaloceanSecretName }}
key: api_key
volumeMounts:
- name: digitalocean-config
mountPath: /.config/doctl
{{- end }}
{{- if .Values.lambdaCredentials.enabled }}
- name: create-lambda-credentials
{{- with .Values.securityContext }}
@@ -533,6 +596,7 @@ spec:
echo "api_key = $LAMBDA_API_KEY" > /root/.lambda_cloud/lambda_keys
else
echo "Lambda credentials not found in environment variables. Skipping credentials setup."
echo "Sleeping for 10 minutes before exiting for debugging purposes."
sleep 600
fi
env:
@@ -565,6 +629,7 @@ spec:
echo "$VAST_API_KEY" > /root/.config/vastai/vast_api_key
else
echo "Vast credentials not found in environment variables. Skipping credentials setup."
echo "Sleeping for 10 minutes before exiting for debugging purposes."
sleep 600
fi
env:
@@ -612,10 +677,21 @@ spec:
- name: gcp-config
emptyDir: {}
{{- end }}
{{- if .Values.coreweaveCredentials.enabled }}
- name: coreweave-credentials
secret:
secretName: {{ .Values.coreweaveCredentials.coreweaveSecretName }}
- name: coreweave-config
emptyDir: {}
{{- end }}
{{- if .Values.runpodCredentials.enabled }}
- name: runpod-config
emptyDir: {}
{{- end }}
{{- if .Values.digitaloceanCredentials.enabled }}
- name: digitalocean-config
emptyDir: {}
{{- end }}
{{- if .Values.lambdaCredentials.enabled }}
- name: lambda-config
emptyDir: {}
@@ -666,4 +742,4 @@ spec:
{{- with .Values.apiService.tolerations }}
tolerations:
{{- toYaml . | nindent 8 }}
{{- end }}
{{- end }}
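A quick way to sanity-check the new credential init containers is to render the chart locally; a sketch (release name, value overrides, and the grep pattern are illustrative):

    helm template skypilot ./charts/skypilot \
      --set coreweaveCredentials.enabled=true \
      --set digitaloceanCredentials.enabled=true \
      | grep -E -A3 'setup-coreweave-credentials|create-digitalocean-credentials'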

charts/skypilot/tests/deployment_test.yaml (+97, -0)

@@ -490,3 +490,100 @@ tests:
content:
name: aws-config
emptyDir: {}

# Test cases for coreweaveCredentials configuration
- it: should mount coreweave credentials correctly when coreweaveCredentials is enabled
set:
coreweaveCredentials.enabled: true
coreweaveCredentials.coreweaveSecretName: my-coreweave-secret
asserts:
# Should have coreweave-credentials secret volume
- contains:
path: spec.template.spec.volumes
content:
name: coreweave-credentials
secret:
secretName: my-coreweave-secret
# Should have coreweave-config emptyDir volume
- contains:
path: spec.template.spec.volumes
content:
name: coreweave-config
emptyDir: {}
# Should mount coreweave-config in container
- contains:
path: spec.template.spec.containers[0].volumeMounts
content:
name: coreweave-config
mountPath: /root/.coreweave
readOnly: true

# Should have init container
- contains:
path: spec.template.spec.initContainers
any: true
content:
name: setup-coreweave-credentials
# Init container should have correct volume mounts
- contains:
path: spec.template.spec.initContainers
any: true
content:
volumeMounts:
- name: coreweave-credentials
mountPath: /root/.coreweave_credentials
- name: coreweave-config
mountPath: /root/.coreweave

- it: should inject global extra envs into coreweave init container
set:
global.extraEnvs:
- name: GLOBAL_ENV
value: global
coreweaveCredentials.enabled: true
coreweaveCredentials.coreweaveSecretName: my-coreweave-secret
asserts:
- contains:
path: spec.template.spec.initContainers
any: true
content:
name: setup-coreweave-credentials
env:
- name: GLOBAL_ENV
value: global

- it: should not mount coreweave credentials when coreweaveCredentials is disabled
set:
coreweaveCredentials.enabled: false
asserts:
# Should NOT have coreweave-credentials volume
- notContains:
path: spec.template.spec.volumes
content:
name: coreweave-credentials
# Should NOT have coreweave-config volume
- notContains:
path: spec.template.spec.volumes
content:
name: coreweave-config
emptyDir: {}
# Should NOT have volumeMount
- notMatchRegexRaw:
pattern: "mountPath: /root/.coreweave"
# Should NOT have init container
- notMatchRegexRaw:
pattern: "setup-coreweave-credentials"

- it: should prefix coreweave credentials init container image with the global registry override
set:
global.imageRegistry: registry.example.com/custom
coreweaveCredentials.enabled: true
coreweaveCredentials.coreweaveSecretName: my-coreweave-secret
apiService.image: berkeleyskypilot/skypilot-nightly:latest
asserts:
- contains:
path: spec.template.spec.initContainers
any: true
content:
name: setup-coreweave-credentials
image: registry.example.com/custom/berkeleyskypilot/skypilot-nightly:latest
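These cases follow the helm-unittest format; a minimal sketch of running them locally, assuming the helm-unittest plugin is installed:

    helm plugin install https://github.com/helm-unittest/helm-unittest
    helm unittest charts/skypilot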

charts/skypilot/values.schema.json (+22, -0)

@@ -266,6 +266,28 @@
}
}
},
"coreweaveCredentials": {
"type": "object",
"properties": {
"coreweaveSecretName": {
"type": "string"
},
"enabled": {
"type": "boolean"
}
}
},
"digitaloceanCredentials": {
"type": "object",
"properties": {
"digitaloceanSecretName": {
"type": "string"
},
"enabled": {
"type": "boolean"
}
}
},
"extraInitContainers": {
"type": [
"array",


charts/skypilot/values.yaml (+11, -0)

@@ -487,6 +487,12 @@ runpodCredentials:
# Name of the secret containing the RunPod credentials. Only used if enabled is true.
runpodSecretName: runpod-credentials

# Populate Digital Ocean credentials from the secret with key `api_key`
digitaloceanCredentials:
enabled: false
# Name of the secret containing the Digital Ocean credentials. Only used if enabled is true.
digitaloceanSecretName: digitalocean-credentials

# Populate Lambda credentials from the secret with key `api_key`
lambdaCredentials:
enabled: false
@@ -512,6 +518,11 @@ r2Credentials:
# Name of the secret containing the r2 credentials. Only used if enabled is true.
r2SecretName: r2-credentials

coreweaveCredentials:
enabled: false
# Name of the secret containing the coreweave credentials. Only used if enabled is true.
coreweaveSecretName: coreweave-credentials

# Extra init containers to run before the api container
# @schema type: [array, null]; item: object
extraInitContainers: null
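A minimal sketch of wiring up one of the new credential blocks, assuming the default secret name above and the `api_key` key expected by the deployment template:

    kubectl create secret generic digitalocean-credentials \
      --namespace $NAMESPACE \
      --from-literal api_key=YOUR_DO_API_TOKEN

    helm upgrade --install $RELEASE_NAME skypilot/skypilot-nightly --devel \
      --namespace $NAMESPACE \
      --reuse-values \
      --set digitaloceanCredentials.enabled=true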


docs/source/examples/interactive-development.rst (+200, -31)

@@ -13,6 +13,7 @@ SkyPilot makes interactive development easy on Kubernetes or cloud VMs. It helps
- :ref:`SSH <dev-ssh>`
- :ref:`VSCode <dev-vscode>`
- :ref:`Jupyter Notebooks <dev-notebooks>`
- :ref:`marimo Notebooks <marimo-notebooks>`

.. _dev-launch:

@@ -137,21 +138,27 @@ Inside the cluster, you can run the following commands to start a Jupyter sessio

In your local browser, you should now be able to access :code:`localhost:8888` and see the following screen:

.. image:: ../images/jupyter-auth.png
:width: 100%
:alt: Jupyter authentication window
.. dropdown:: Jupyter authentication page

.. image:: ../images/jupyter-auth.png
:width: 100%
:alt: Jupyter authentication page

Enter the password or token and you will be directed to a page where you can create a new notebook.

.. image:: ../images/jupyter-create.png
:width: 100%
:alt: Create a new Jupyter notebook
.. dropdown:: Jupyter home page

.. image:: ../images/jupyter-create.png
:width: 100%
:alt: Create a new Jupyter notebook

You can verify that this notebook is running on the GPU-backed instance using :code:`nvidia-smi`.

.. image:: ../images/jupyter-gpu.png
:width: 100%
:alt: nvidia-smi in notebook
.. dropdown:: nvidia-smi in Jupyter notebook

.. image:: ../images/jupyter-gpu.png
:width: 100%
:alt: nvidia-smi in notebook

The GPU node is a normal SkyPilot cluster, so you can use the usual CLI commands on it. For example, run ``sky down/stop`` to terminate or stop it, and ``sky exec`` to execute a task.

@@ -163,31 +170,33 @@ range of SkyPilot's features including :ref:`mounted storage <sky-storage>` and

The following :code:`jupyter.yaml` is an example of a task specification that can launch notebooks with SkyPilot.

.. code:: yaml
.. dropdown:: jupyter.yaml

# jupyter.yaml
.. code:: yaml

name: jupyter
# jupyter.yaml

resources:
accelerators: L4:1
name: jupyter

file_mounts:
/covid:
source: s3://fah-public-data-covid19-cryptic-pockets
mode: MOUNT
resources:
accelerators: L4:1

setup: |
pip install --upgrade pip
conda init bash
conda create -n jupyter python=3.9 -y
conda activate jupyter
pip install jupyter
file_mounts:
/covid:
source: s3://fah-public-data-covid19-cryptic-pockets
mode: MOUNT

run: |
cd ~/sky_workdir
conda activate jupyter
jupyter notebook --port 8888 &
setup: |
pip install --upgrade pip
conda init bash
conda create -n jupyter python=3.9 -y
conda activate jupyter
pip install jupyter

run: |
cd ~/sky_workdir
conda activate jupyter
jupyter notebook --port 8888 &

Launch the GPU-backed Jupyter notebook:

@@ -203,12 +212,172 @@ To access the notebook locally, use SSH port forwarding.

You can verify that this notebook has access to the mounted storage bucket.

.. image:: ../images/jupyter-covid.png
:width: 100%
:alt: accessing covid data from notebook
.. dropdown:: Jupyter notebook page

.. image:: ../images/jupyter-covid.png
:width: 100%
:alt: accessing covid data from notebook

.. _marimo-notebooks:

marimo notebooks
~~~~~~~~~~~~~~~~~

To start a marimo notebook interactively via ``sky``, you can connect to the machine and forward the
port that you want marimo to use:

.. code-block:: bash

ssh -L 8080:localhost:8080 dev

Inside the cluster, you can run the following commands to start marimo.

.. note::
By starting the notebook this way, it runs in a completely sandboxed environment. The ``uvx`` command lets
us use ``marimo`` without installing it into a pre-existing environment, and the ``--sandbox`` flag
ensures that any dependencies of the notebook are installed in a separate environment as well.

.. code-block:: bash

pip install uv
uvx marimo edit --sandbox demo.py --port 8080 --token-password=supersecret

In your local browser, you should now be able to access :code:`localhost:8080` and see the following screen:

.. dropdown:: marimo authentication page

.. image:: ../images/marimo-auth.png
:width: 100%
:alt: marimo authentication page

Enter the password or token and you will be directed to your notebook.

.. dropdown:: marimo notebook page

.. image:: ../images/marimo-use.png
:width: 100%
:alt: What a newly created marimo notebook looks like

You can verify that this notebook is running on the GPU-backed instance using :code:`nvidia-smi` in
the terminal that marimo provides in the browser.

.. dropdown:: nvidia-smi in marimo notebook

.. image:: ../images/marimo-nvidea.png
:width: 100%
:alt: nvidia-smi in marimo notebook

marimo as SkyPilot jobs
^^^^^^^^^^^^^^^^^^^^^^^

Because marimo notebooks are stored as Python scripts on disk, you can immediately use them as SkyPilot jobs too.

To demonstrate this, let's consider the following marimo notebook:

.. dropdown:: marimo notebook example

.. image:: ../images/marimo-example.png
:width: 100%
:alt: marimo notebook example

This is the underlying code:

.. dropdown:: marimo notebook example code

.. code-block:: python

# /// script
# requires-python = ">=3.12"
# dependencies = [
# "marimo",
# ]
# ///

import marimo

__generated_with = "0.18.1"
app = marimo.App(sql_output="polars")


@app.cell
def _():
import marimo as mo
return (mo,)


@app.cell
def _(mo):
print(mo.cli_args())
return


if __name__ == "__main__":
app.run()

This notebook uses :code:`mo.cli_args()` to parse any command-line arguments passed to the notebook.
A more realistic use case would take such arguments to train a PyTorch model, but this
tutorial omits the details of model training for the sake of simplicity.

You can confirm this locally by running the notebook with the following command:

.. code-block:: bash

uv run demo.py --hello world --demo works --lr 0.01

This will print the command-line arguments passed to the notebook.

.. code-block:: bash

{'hello': 'world', 'demo': 'works', 'lr': '0.01'}

To use a notebook like this as a job, you'll want to configure a job YAML file like this:

.. dropdown:: marimo-demo.yaml

.. code-block:: yaml

# marimo-demo.yaml
name: marimo-demo

# Specify the resources for this job here
resources:

# This needs to point to the folder that has the marimo notebook
workdir: scripts

# Fill in any env keys, like wandb
envs:
WANDB_API_KEY: "key"

# We only need to install uv
setup: pip install uv

# If the notebook is sandboxed via --sandbox, uv takes care of the dependencies
run: uv run demo.py --hello world --demo works --lr 0.01


You can now submit this job to ``sky`` using the following command:

.. code-block:: bash

sky jobs launch -n marimo-demo marimo-demo.yaml

This command will provision cloud resources and then launch the job. You can monitor
the job status by checking logs in the terminal, but you can also check the dashboard
by running :code:`sky dashboard`.
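For example, you can also check on the job from the terminal (the job ID comes from the queue output):

.. code-block:: bash

sky jobs queue          # list managed jobs and their statuses
sky jobs logs <job_id>  # stream the job's logs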

This is what the dashboard of the job looks like after it is done.

.. dropdown:: SkyPilot jobs dashboard

.. image:: ../images/marimo-job.png
:align: center
:alt: marimo job completed

The resources used during the job will also shut down automatically after a configurable
amount of inactivity is detected. You can learn more about how to configure this behavior
in the :ref:`managed-jobs` guide.

Working with clusters
---------------------


docs/source/examples/managed-jobs.rst (+14, -0)

@@ -729,6 +729,18 @@ If you have deployed a :ref:`remote API server <sky-api-server>`, you can avoid
.. warning::
Because the jobs controller must stay alive to manage running jobs, it's required to use an external API server to enable consolidation mode.

.. image:: ../images/jobs-consolidation-mode.svg
:width: 800
:alt: Architecture diagram of SkyPilot remote API server with and without consolidation mode
:align: center

Consolidating the API server and the jobs controller has a few advantages:

- 6x faster job submission.
- Consistent cloud/Kubernetes credentials across the API server and jobs controller.
- Persistent managed job state using the same database as the API server, e.g., PostgreSQL.
- No extra VM/pod is needed for the jobs controller, saving cost.

To enable the consolidated deployment, set :ref:`consolidation_mode <config-yaml-jobs-controller-consolidation-mode>` in the API server config.

.. code-block:: yaml
@@ -749,4 +761,6 @@ To enable the consolidated deployment, set :ref:`consolidation_mode <config-yaml
# Restart the API server to pick up the config change
kubectl -n $NAMESPACE rollout restart deployment $RELEASE_NAME-api-server

See :ref:`more about the Kubernetes upgrade strategy of the API server <sky-api-server-graceful-upgrade>`.

The jobs controller adds some overhead: it reserves an extra 2GB of memory for itself, which may reduce the number of requests your API server can handle. To counteract this, you can increase the amount of CPU and memory allocated to the API server; see :ref:`sky-api-server-resources-tuning`.

docs/source/getting-started/installation.rst (+14, -2)

@@ -373,10 +373,20 @@ GCP
# This will generate ~/.config/gcloud/application_default_credentials.json.
gcloud auth application-default login

.. tab-item:: Archive Download
.. tab-item:: Manual Install
:sync: gcp-archive-download-tab

Follow the `Google Cloud SDK installation instructions <https://cloud.google.com/sdk/docs/install#installation_instructions>`_ for your OS.
For macOS on Apple Silicon:

.. code-block:: shell

curl -o gcloud.tar.gz https://dl.google.com/dl/cloudsdk/channels/rapid/downloads/google-cloud-cli-darwin-arm.tar.gz
tar -xf gcloud.tar.gz
./google-cloud-sdk/install.sh
# Update your path with the newly installed gcloud

If you are using another architecture or OS,
follow the `Google Cloud SDK installation instructions <https://cloud.google.com/sdk/docs/install#installation_instructions>`_ to download the appropriate package.

Be sure to complete the optional step that adds ``gcloud`` to your ``PATH``.
This step is required for SkyPilot to recognize that your ``gcloud`` installation is configured correctly.
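For example, you can confirm the setup with:

.. code-block:: shell

gcloud --version
sky check gcp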
@@ -423,6 +433,8 @@ CoreWeave

CoreWeave also offers InfiniBand networking for high-performance distributed training. You can enable InfiniBand support by adding ``network_tier: best`` to your SkyPilot task configuration.

.. _coreweave-caios-installation:

CoreWeave Object Storage (CAIOS)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^



docs/source/images/jobs-consolidation-mode.svg (+1, -0; diff suppressed because the file is too large)


docs/source/images/marimo-auth.png (BIN, 1092×1119, 124 KiB)

docs/source/images/marimo-example.png (BIN, 902×637, 112 KiB)

docs/source/images/marimo-job.png (BIN, 1162×1300, 308 KiB)

docs/source/images/marimo-nvidea.png (BIN, 1092×1119, 225 KiB)

docs/source/images/marimo-use.png (BIN, 1092×1119, 174 KiB)

docs/source/reference/api-server/api-server-admin-deploy.rst (+527, -0)

@@ -224,6 +224,30 @@ Following tabs describe how to configure credentials for different clouds on the

The specific cloud's credential for the exec-based authentication also needs to be configured. For example, to enable exec-based authentication for GKE, you also need to setup GCP credentials (see the GCP tab above).

.. dropdown:: Update Kubernetes credentials

After Kubernetes credentials are enabled, you can update the kubeconfig file in ``kube-credentials`` as follows:

1. Replace the existing secret in place:

.. code-block:: bash

kubectl delete secret kube-credentials
kubectl create secret generic kube-credentials \
--namespace $NAMESPACE \
--from-file=config=$HOME/.kube/config

2. The change will take effect on the API server within tens of seconds. You can verify the updated credentials in the API server pod:

.. code-block:: bash

# The NAMESPACE and RELEASE_NAME should be consistent with the API server deployment
API_SERVER_POD_NAME=$(kubectl get pods -n $NAMESPACE -l app=${RELEASE_NAME}-api -o jsonpath='{.items[0].metadata.name}')
# If `SSH Node Pools` is not enabled
kubectl exec $API_SERVER_POD_NAME -n $NAMESPACE -- cat /root/.kube/config
# If `SSH Node Pools` is enabled
#kubectl exec $API_SERVER_POD_NAME -n $NAMESPACE -- cat /var/skypilot/kubeconfig/config

To use multiple Kubernetes clusters, you will need to add the context names to ``allowed_contexts`` in the SkyPilot config. An example config file that allows using the hosting Kubernetes cluster and two additional Kubernetes clusters is shown below:

.. code-block:: yaml
@@ -267,6 +291,52 @@ Following tabs describe how to configure credentials for different clouds on the
--reuse-values \
--set awsCredentials.enabled=true

.. dropdown:: Update AWS credentials (single profile)

After AWS credentials are enabled, update the access or secret key in ``aws-credentials`` using either approach:

1. Create a new secret with a new name:

.. code-block:: bash

kubectl create secret generic aws-credentials-new \
--namespace $NAMESPACE \
--from-literal=aws_access_key_id=YOUR_ACCESS_KEY_ID \
--from-literal=aws_secret_access_key=YOUR_SECRET_ACCESS_KEY

Then point Helm to the new secret name:

.. code-block:: bash

helm upgrade --install skypilot skypilot/skypilot-nightly --devel \
--namespace $NAMESPACE \
--reuse-values \
--set awsCredentials.awsSecretName=aws-credentials-new

2. Replace the existing secret in place, then restart the API server:

.. code-block:: bash

kubectl delete secret aws-credentials
kubectl create secret generic aws-credentials \
--namespace $NAMESPACE \
--from-literal=aws_access_key_id=YOUR_ACCESS_KEY_ID \
--from-literal=aws_secret_access_key=YOUR_SECRET_ACCESS_KEY

Restart the API server:

.. code-block:: bash

kubectl rollout restart deployment/$RELEASE_NAME-api-server -n $NAMESPACE

Verify the updated credentials in the API server pod:

.. code-block:: bash

# The NAMESPACE and RELEASE_NAME should be consistent with the API server deployment
API_SERVER_POD_NAME=$(kubectl get pods -n $NAMESPACE -l app=${RELEASE_NAME}-api -o jsonpath='{.items[0].metadata.name}')
kubectl exec $API_SERVER_POD_NAME -n $NAMESPACE -- cat /root/.aws/credentials

**Option 2: Multiple profiles (for multiple workspaces)**

Use this if you need different AWS profiles for different workspaces. Create a Kubernetes secret from your AWS credentials file:
@@ -288,6 +358,27 @@ Following tabs describe how to configure credentials for different clouds on the
--set awsCredentials.enabled=true \
--set awsCredentials.useCredentialsFile=true

.. dropdown:: Update AWS credentials (multiple profiles)

After AWS credentials are enabled, you can update the credentials file in ``aws-credentials`` as follows:

1. Replace the existing secret in place:

.. code-block:: bash

kubectl delete secret aws-credentials
kubectl create secret generic aws-credentials \
--namespace $NAMESPACE \
--from-file=credentials=$HOME/.aws/credentials

2. The change will take effect on the API server within tens of seconds. You can verify the updated credentials in the API server pod:

.. code-block:: bash

# The NAMESPACE and RELEASE_NAME should be consistent with the API server deployment
API_SERVER_POD_NAME=$(kubectl get pods -n $NAMESPACE -l app=${RELEASE_NAME}-api -o jsonpath='{.items[0].metadata.name}')
kubectl exec $API_SERVER_POD_NAME -n $NAMESPACE -- cat /root/.aws/credentials

.. dropdown:: Use existing AWS credentials

You can also set the following values to use a secret that already contains your AWS credentials:
@@ -352,6 +443,50 @@ Following tabs describe how to configure credentials for different clouds on the
--set gcpCredentials.enabled=true \
--set gcpCredentials.gcpSecretName=your_secret_name

.. dropdown:: Update GCP credentials

After GCP credentials are enabled, you can update the credentials file in ``gcp-credentials`` using either approach:

1. Create a new secret with a new name:

.. code-block:: bash

kubectl create secret generic gcp-credentials-new \
--namespace $NAMESPACE \
--from-file=gcp-cred.json=YOUR_SERVICE_ACCOUNT_JSON_KEY_NEW.json

Then point Helm to the new secret name:

.. code-block:: bash

helm upgrade --install skypilot skypilot/skypilot-nightly --devel \
--namespace $NAMESPACE \
--reuse-values \
--set gcpCredentials.gcpSecretName=gcp-credentials-new

2. Replace the existing secret in place, then restart the API server:

.. code-block:: bash

kubectl delete secret gcp-credentials
kubectl create secret generic gcp-credentials \
--namespace $NAMESPACE \
--from-file=gcp-cred.json=YOUR_SERVICE_ACCOUNT_JSON_KEY.json

Restart the API server:

.. code-block:: bash

kubectl rollout restart deployment/$RELEASE_NAME-api-server -n $NAMESPACE

Verify the updated credentials in the API server pod:

.. code-block:: bash

# The NAMESPACE and RELEASE_NAME should be consistent with the API server deployment
API_SERVER_POD_NAME=$(kubectl get pods -n $NAMESPACE -l app=${RELEASE_NAME}-api -o jsonpath='{.items[0].metadata.name}')
kubectl exec $API_SERVER_POD_NAME -n $NAMESPACE -- ls -lart /root/.config/gcloud

.. tab-item:: RunPod
:sync: runpod-creds-tab

@@ -380,6 +515,50 @@ Following tabs describe how to configure credentials for different clouds on the
--set runpodCredentials.enabled=true \
--set runpodCredentials.runpodSecretName=your_secret_name

.. dropdown:: Update RunPod credentials

After RunPod credentials are enabled, you can update the API key in ``runpod-credentials`` using either approach:

1. Create a new secret with a new name:

.. code-block:: bash

kubectl create secret generic runpod-credentials-new \
--namespace $NAMESPACE \
--from-literal api_key=YOUR_API_KEY_NEW

Then point Helm to the new secret name:

.. code-block:: bash

helm upgrade --install skypilot skypilot/skypilot-nightly --devel \
--namespace $NAMESPACE \
--reuse-values \
--set runpodCredentials.runpodSecretName=runpod-credentials-new

2. Replace the existing secret in place, then restart the API server:

.. code-block:: bash

kubectl delete secret runpod-credentials
kubectl create secret generic runpod-credentials \
--namespace $NAMESPACE \
--from-literal api_key=YOUR_API_KEY

Restart the API server:

.. code-block:: bash

kubectl rollout restart deployment/$RELEASE_NAME-api-server -n $NAMESPACE

Verify the updated credentials in the API server pod:

.. code-block:: bash

# The NAMESPACE and RELEASE_NAME should be consistent with the API server deployment
API_SERVER_POD_NAME=$(kubectl get pods -n $NAMESPACE -l app=${RELEASE_NAME}-api -o jsonpath='{.items[0].metadata.name}')
kubectl exec $API_SERVER_POD_NAME -n $NAMESPACE -- cat /root/.runpod/config.toml

.. tab-item:: Lambda
:sync: lambda-creds-tab

@@ -416,6 +595,50 @@ Following tabs describe how to configure credentials for different clouds on the
--set lambdaCredentials.enabled=true \
--set lambdaCredentials.lambdaSecretName=your_secret_name

.. dropdown:: Update Lambda credentials

After Lambda credentials are enabled, you can update the API key in ``lambda-credentials`` using either approach:

1. Create a new secret with a new name:

.. code-block:: bash

kubectl create secret generic lambda-credentials-new \
--namespace $NAMESPACE \
--from-literal api_key=YOUR_API_KEY_NEW

Then point Helm to the new secret name:

.. code-block:: bash

helm upgrade --install skypilot skypilot/skypilot-nightly --devel \
--namespace $NAMESPACE \
--reuse-values \
--set lambdaCredentials.lambdaSecretName=lambda-credentials-new

2. Replace the existing secret in place, then restart the API server:

.. code-block:: bash

kubectl delete secret lambda-credentials
kubectl create secret generic lambda-credentials \
--namespace $NAMESPACE \
--from-literal api_key=YOUR_API_KEY

Restart the API server:

.. code-block:: bash

kubectl rollout restart deployment/$RELEASE_NAME-api-server -n $NAMESPACE

Verify the updated credentials in the API server pod:

.. code-block:: bash

# The NAMESPACE and RELEASE_NAME should be consistent with the API server deployment
API_SERVER_POD_NAME=$(kubectl get pods -n $NAMESPACE -l app=${RELEASE_NAME}-api -o jsonpath='{.items[0].metadata.name}')
kubectl exec $API_SERVER_POD_NAME -n $NAMESPACE -- cat /root/.lambda_cloud/lambda_keys

.. tab-item:: Nebius
:sync: nebius-creds-tab

@@ -480,6 +703,27 @@ Following tabs describe how to configure credentials for different clouds on the
--set nebiusCredentials.enabled=true \
--set nebiusCredentials.nebiusSecretName=your_secret_name

.. dropdown:: Update Nebius credentials

After Nebius credentials are enabled, you can update the credentials file in ``nebius-credentials`` as follows:

1. Replace the existing secret in place:

.. code-block:: bash

kubectl delete secret nebius-credentials
kubectl create secret generic nebius-credentials \
--namespace $NAMESPACE \
--from-file=credentials=$HOME/.nebius/credentials.json

2. The change will take effect on the API server within tens of seconds. You can verify the updated credentials in the API server pod:

.. code-block:: bash

# The NAMESPACE and RELEASE_NAME should be consistent with the API server deployment
API_SERVER_POD_NAME=$(kubectl get pods -n $NAMESPACE -l app=${RELEASE_NAME}-api -o jsonpath='{.items[0].metadata.name}')
kubectl exec $API_SERVER_POD_NAME -n $NAMESPACE -- cat /root/.nebius/credentials.json

.. tab-item:: Vast
:sync: vast-creds-tab

@@ -516,6 +760,49 @@ Following tabs describe how to configure credentials for different clouds on the
--set vastCredentials.enabled=true \
--set vastCredentials.vastSecretName=your_secret_name

.. dropdown:: Update Vast credentials

After Vast credentials are enabled, you can update the API key in ``vast-credentials`` using either approach:

1. Create a new secret with a new name:

.. code-block:: bash

kubectl create secret generic vast-credentials-new \
--namespace $NAMESPACE \
--from-literal api_key=YOUR_API_KEY_NEW

Then point Helm to the new secret name:

.. code-block:: bash

helm upgrade --install skypilot skypilot/skypilot-nightly --devel \
--namespace $NAMESPACE \
--reuse-values \
--set vastCredentials.vastSecretName=vast-credentials-new

2. Replace the existing secret in place, then restart the API server:

.. code-block:: bash

kubectl delete secret vast-credentials
kubectl create secret generic vast-credentials \
--namespace $NAMESPACE \
--from-literal api_key=YOUR_API_KEY

Restart the API server:

.. code-block:: bash

kubectl rollout restart deployment/$RELEASE_NAME-api-server -n $NAMESPACE

Verify the updated credentials in the API server pod:

.. code-block:: bash

# The NAMESPACE and RELEASE_NAME should be consistent with the API server deployment
API_SERVER_POD_NAME=$(kubectl get pods -n $NAMESPACE -l app=${RELEASE_NAME}-api -o jsonpath='{.items[0].metadata.name}')
kubectl exec $API_SERVER_POD_NAME -n $NAMESPACE -- cat /root/.config/vastai/vast_api_key

.. tab-item:: SSH Node Pools
:sync: ssh-node-pools-tab
@@ -532,6 +819,17 @@ Following tabs describe how to configure credentials for different clouds on the
--reuse-values \
--set-file apiService.sshNodePools=/your/path/to/ssh_node_pools.yaml

.. note::

Updating the value of ``apiService.sshNodePools`` will not restart the API server, but the change will take effect within tens of seconds.
You can verify the config updates on the API server by running the following command:

.. code-block:: bash

# The NAMESPACE and RELEASE_NAME should be consistent with the API server deployment
API_SERVER_POD_NAME=$(kubectl get pods -n $NAMESPACE -l app=${RELEASE_NAME}-api -o jsonpath='{.items[0].metadata.name}')
kubectl exec $API_SERVER_POD_NAME -n $NAMESPACE -- cat /root/.sky/ssh_node_pools.yaml

If your ``ssh_node_pools.yaml`` requires SSH keys, create a secret that contains the keys and set the :ref:`apiService.sshKeySecret <helm-values-apiService-sshKeySecret>` to the secret name:

.. code-block:: bash
@@ -551,6 +849,28 @@ Following tabs describe how to configure credentials for different clouds on the
--reuse-values \
--set apiService.sshKeySecret=$SECRET_NAME

.. dropdown:: Update SSH key credentials

After SSH key credentials are enabled, you can update the credentials file in ``$SECRET_NAME`` as follows:

1. Replace the existing secret in place:

.. code-block:: bash

kubectl delete secret $SECRET_NAME
kubectl create secret generic $SECRET_NAME \
--namespace $NAMESPACE \
--from-file=id_rsa=/path/to/id_rsa \
--from-file=other_id_rsa=/path/to/other_id_rsa

2. The change will take effect on the API server within tens of seconds. You can verify the updated credentials in the API server pod:

.. code-block:: bash

# The NAMESPACE and RELEASE_NAME should be consistent with the API server deployment
API_SERVER_POD_NAME=$(kubectl get pods -n $NAMESPACE -l app=${RELEASE_NAME}-api -o jsonpath='{.items[0].metadata.name}')
kubectl exec $API_SERVER_POD_NAME -n $NAMESPACE -- ls -lart /root/.ssh/

After the API server is deployed, use the ``sky ssh up`` command to set up the SSH Node Pools. Refer to :ref:`existing-machines` for more details.
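For example, from a client pointed at this API server (the endpoint value is illustrative):

.. code-block:: bash

sky api login -e http://<your-api-server-endpoint>
sky ssh up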

.. note::
@@ -582,6 +902,213 @@ Following tabs describe how to configure credentials for different clouds on the
--set r2Credentials.enabled=true \
--set r2Credentials.r2SecretName=r2-credentials

.. dropdown:: Update Cloudflare R2 credentials

After Cloudflare R2 credentials are enabled, you can update the credentials file in ``r2-credentials`` using either approach:

1. Create a new secret with a new name:

.. code-block:: bash

kubectl create secret generic r2-credentials-new \
--namespace $NAMESPACE \
--from-file=r2.credentials=$HOME/.cloudflare/r2.credentials \
--from-file=accountid=$HOME/.cloudflare/accountid

Then point Helm to the new secret name:

.. code-block:: bash

helm upgrade --install skypilot skypilot/skypilot-nightly --devel \
--namespace $NAMESPACE \
--reuse-values \
--set r2Credentials.r2SecretName=r2-credentials-new

2. Replace the existing secret in place, then restart the API server:

.. code-block:: bash

kubectl delete secret r2-credentials
kubectl create secret generic r2-credentials \
--namespace $NAMESPACE \
--from-file=r2.credentials=$HOME/.cloudflare/r2.credentials \
--from-file=accountid=$HOME/.cloudflare/accountid

Restart the API server:

.. code-block:: bash

kubectl rollout restart deployment/$RELEASE_NAME-api-server -n $NAMESPACE

Verify the updated credentials in the API server pod:

.. code-block:: bash

# The NAMESPACE and RELEASE_NAME should be consistent with the API server deployment
API_SERVER_POD_NAME=$(kubectl get pods -n $NAMESPACE -l app=${RELEASE_NAME}-api -o jsonpath='{.items[0].metadata.name}')
kubectl exec $API_SERVER_POD_NAME -n $NAMESPACE -- cat /root/.cloudflare/r2.credentials
kubectl exec $API_SERVER_POD_NAME -n $NAMESPACE -- cat /root/.cloudflare/accountid

.. tab-item:: CoreWeave
:sync: coreweave-creds-tab

SkyPilot API server uses the same credentials as the :ref:`CoreWeave CAIOS installation <coreweave-caios-installation>` to authenticate with CoreWeave Object Storage.

Once you have the credentials configured locally, you can store them in a Kubernetes secret:

.. code-block:: bash

kubectl create secret generic coreweave-credentials \
--namespace $NAMESPACE \
--from-file=cw.config=$HOME/.coreweave/cw.config \
--from-file=cw.credentials=$HOME/.coreweave/cw.credentials

When installing or upgrading the Helm chart, enable CoreWeave CAIOS credentials by setting ``coreweaveCredentials.enabled=true``:

.. code-block:: bash

# --reuse-values keeps the Helm chart values set in the previous step
helm upgrade --install $RELEASE_NAME skypilot/skypilot-nightly --devel \
--namespace $NAMESPACE \
--reuse-values \
--set coreweaveCredentials.enabled=true

.. dropdown:: Use existing CoreWeave CAIOS credentials

You can also set the following values to use a secret that already contains your CoreWeave CAIOS credentials:

.. code-block:: bash

# TODO: replace with your secret name
helm upgrade --install $RELEASE_NAME skypilot/skypilot-nightly --devel \
--namespace $NAMESPACE \
--reuse-values \
--set coreweaveCredentials.enabled=true \
--set coreweaveCredentials.coreweaveSecretName=your_secret_name

.. dropdown:: Update CoreWeave CAIOS credentials

After CoreWeave CAIOS credentials are enabled, you can update the credentials stored in the ``coreweave-credentials`` secret in one of two ways:

1. Create a new secret with a new name:

.. code-block:: bash

kubectl create secret generic coreweave-credentials-new \
--namespace $NAMESPACE \
--from-file=cw.config=$HOME/.coreweave/cw.config \
--from-file=cw.credentials=$HOME/.coreweave/cw.credentials

Then point Helm to the new secret name:

.. code-block:: bash

helm upgrade --install skypilot skypilot/skypilot-nightly --devel \
--namespace $NAMESPACE \
--reuse-values \
--set coreweaveCredentials.coreweaveSecretName=coreweave-credentials-new

2. Replace the existing secret in place, then restart the API server:

.. code-block:: bash

kubectl delete secret coreweave-credentials
kubectl create secret generic coreweave-credentials \
--namespace $NAMESPACE \
--from-file=cw.config=$HOME/.coreweave/cw.config \
--from-file=cw.credentials=$HOME/.coreweave/cw.credentials

Restart the API server:

.. code-block:: bash

kubectl rollout restart deployment/$RELEASE_NAME-api-server -n $NAMESPACE

Verify the updated credentials in the API server pod:

.. code-block:: bash

# The NAMESPACE and RELEASE_NAME should be consistent with the API server deployment
API_SERVER_POD_NAME=$(kubectl get pods -n $NAMESPACE -l app=${RELEASE_NAME}-api -o jsonpath='{.items[0].metadata.name}')
kubectl exec $API_SERVER_POD_NAME -n $NAMESPACE -- cat /root/.coreweave/cw.config
kubectl exec $API_SERVER_POD_NAME -n $NAMESPACE -- cat /root/.coreweave/cw.credentials

.. tab-item:: DigitalOcean
:sync: digitalocean-creds-tab

The SkyPilot API server uses an **API key** to authenticate with DigitalOcean. To configure DigitalOcean access,
follow the `instructions <https://docs.digitalocean.com/reference/api/create-personal-access-token/#creating-a-token>`_
provided by DigitalOcean.

Once the key is generated, create a Kubernetes secret to store it:

.. code-block:: bash

kubectl create secret generic digitalocean-credentials \
--namespace $NAMESPACE \
--from-literal api_key=YOUR_API_KEY

When installing or upgrading the Helm chart, enable DigitalOcean credentials by
setting ``digitaloceanCredentials.enabled=true``, for example:
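
.. code-block:: bash

# --reuse-values keeps the Helm chart values set in the previous step
helm upgrade --install $RELEASE_NAME skypilot/skypilot-nightly --devel \
--namespace $NAMESPACE \
--reuse-values \
--set digitaloceanCredentials.enabled=true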

.. dropdown:: Use existing DigitalOcean credentials

You can also set the following values to use a secret that already contains your DigitalOcean API key:

.. code-block:: bash

# TODO: replace with your secret name
# If the secret name is not provided, it defaults to `digitalocean-credentials`
helm upgrade --install skypilot skypilot/skypilot-nightly --devel \
--namespace $NAMESPACE \
--reuse-values \
--set digitaloceanCredentials.enabled=true \
--set digitaloceanCredentials.digitaloceanSecretName=your_secret_name

.. dropdown:: Update DigitalOcean credentials

After DigitalOcean credentials are enabled, you can update the API key stored in the ``digitalocean-credentials`` secret in one of two ways:

1. Create a new secret with a new name:

.. code-block:: bash

kubectl create secret generic digitalocean-credentials-new \
--namespace $NAMESPACE \
--from-literal api_key=YOUR_API_KEY_NEW

Then point Helm to the new secret name:

.. code-block:: bash

helm upgrade --install skypilot skypilot/skypilot-nightly --devel \
--namespace $NAMESPACE \
--reuse-values \
--set digitaloceanCredentials.digitaloceanSecretName=digitalocean-credentials-new

2. Replace the existing secret in place, then restart the API server:

.. code-block:: bash

kubectl delete secret digitalocean-credentials
kubectl create secret generic digitalocean-credentials \
--namespace $NAMESPACE \
--from-literal api_key=YOUR_API_KEY

Restart the API server:

.. code-block:: bash

kubectl rollout restart deployment/$RELEASE_NAME-api-server -n $NAMESPACE

Verify the updated credentials in the API server pod:

.. code-block:: bash

# The NAMESPACE and RELEASE_NAME should be consistent with the API server deployment
API_SERVER_POD_NAME=$(kubectl get pods -n $NAMESPACE -l app=${RELEASE_NAME}-api -o jsonpath='{.items[0].metadata.name}')
kubectl exec $API_SERVER_POD_NAME -n $NAMESPACE -- cat /root/.config/doctl/config.yaml

.. tab-item:: Other clouds
:sync: other-clouds-tab



+ 79
- 0
docs/source/reference/api-server/helm-values-spec.rst View File

@@ -233,6 +233,14 @@ Below is the available helm value keys and the default value of each key:
:ref:`tenantId <helm-values-nebiusCredentials-tenantId>`: null
:ref:`nebiusSecretName <helm-values-nebiusCredentials-nebiusSecretName>`: nebius-credentials

:ref:`coreweaveCredentials <helm-values-coreweaveCredentials>`:
:ref:`enabled <helm-values-coreweaveCredentials-enabled>`: false
:ref:`coreweaveSecretName <helm-values-coreweaveCredentials-coreweaveSecretName>`: coreweave-credentials

:ref:`digitaloceanCredentials <helm-values-digitaloceanCredentials>`:
:ref:`enabled <helm-values-digitaloceanCredentials-enabled>`: false
:ref:`digitaloceanSecretName <helm-values-digitaloceanCredentials-digitaloceanSecretName>`: digitalocean-credentials

:ref:`extraInitContainers <helm-values-extraInitContainers>`: null

:ref:`podSecurityContext <helm-values-podSecurityContext>`: {}
@@ -2106,6 +2114,77 @@ Default: ``nebius-credentials``
nebiusCredentials:
nebiusSecretName: nebius-credentials

.. _helm-values-coreweaveCredentials:

``coreweaveCredentials``
~~~~~~~~~~~~~~~~~~~~~~~~

.. _helm-values-coreweaveCredentials-enabled:

``coreweaveCredentials.enabled``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Enable CoreWeave CAIOS credentials for the API server.

Default: ``false``

.. code-block:: yaml

coreweaveCredentials:
enabled: false

.. _helm-values-coreweaveCredentials-coreweaveSecretName:

``coreweaveCredentials.coreweaveSecretName``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Name of the secret containing the CoreWeave CAIOS credentials. Only used if enabled is true. The secret should contain the following keys:

- ``cw.config``: CoreWeave CAIOS config file
- ``cw.credentials``: CoreWeave CAIOS credentials file

Refer to :ref:`CoreWeave CAIOS installation <coreweave-caios-installation>` for more details.

Default: ``coreweave-credentials``

.. code-block:: yaml

coreweaveCredentials:
coreweaveSecretName: coreweave-credentials

.. _helm-values-digitaloceanCredentials:

``digitaloceanCredentials``
~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. _helm-values-digitaloceanCredentials-enabled:

``digitaloceanCredentials.enabled``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Enable DigitalOcean credentials for the API server.

Default: ``false``

.. code-block:: yaml

digitaloceanCredentials:
enabled: false

.. _helm-values-digitaloceanCredentials-digitaloceanSecretName:

``digitaloceanCredentials.digitaloceanSecretName``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Name of the secret containing the DigitalOcean API key. Only used if ``enabled`` is true.

Default: ``digitalocean-credentials``

.. code-block:: yaml

digitaloceanCredentials:
digitaloceanSecretName: digitalocean-credentials

.. _helm-values-extraInitContainers:

``extraInitContainers``


+ 21
- 0
docs/source/reference/config.rst View File

@@ -193,6 +193,9 @@ Below is the configuration syntax and some example values. See detailed explanat
:ref:`tenant_id <config-yaml-nebius-tenant-id>`: tenant-1234567890
:ref:`domain <config-yaml-nebius-domain>`: api.nebius.cloud:443

:ref:`vast <config-yaml-vast>`:
:ref:`secure_only <config-yaml-vast-secure-only>`: true

:ref:`rbac <config-yaml-rbac>`:
:ref:`default_role <config-yaml-rbac-default-role>`: admin

@@ -1675,6 +1678,24 @@ Example:
nebius:
domain: api.nebius.cloud:443

.. _config-yaml-vast:

``vast``
~~~~~~~~

Advanced Vast configuration (optional).

.. _config-yaml-vast-secure-only:

``vast.secure_only``
~~~~~~~~~~~~~~~~~~~~

Configure SkyPilot to only consider offers in Vast's verified datacenters (optional).
Internally, this queries Vast with the ``datacenters=true`` parameter. Note that
some GPUs may only be available on non-secure offers. This config can be
overridden per task via the :ref:`config flag <config-client-cli-flag>`.

Default: ``false``
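
Example:

.. code-block:: yaml

vast:
  secure_only: true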

.. _config-yaml-rbac:



+ 9
- 0
docs/source/reservations/existing-machines.rst View File

@@ -344,3 +344,12 @@ Details: Prerequisites
* All nodes within a SSH Node Pool must have access to port 6443 to its peers (e.g., same VPC). Port 6443 doesn't have to be open to machines outside of the network.
* Nodes should not be part of an existing Kubernetes cluster (use :ref:`Kubernetes Support <kubernetes-overview>` instead).
* When working with GPU instances, GPU drivers must be installed on the host. Verify by running ``nvidia-smi``.


FAQs
----

* **I cannot provision a SkyPilot cluster with the exact amount of resources on my SSH Node Pools machine.**

The SSH Node Pools runtime consumes some resources. Therefore, if you set up SSH Node Pools on a server with 4 CPUs and 16 GB of memory, for instance, SkyPilot cannot provision jobs that require the full 4 CPUs and 16 GB of memory.
The actual resources SkyPilot reports as available will be slightly less than the machine's specifications.
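
For example, you can request slightly less than the machine's full capacity when launching (a minimal sketch; ``task.yaml`` and the exact values are illustrative):

.. code-block:: bash

# On a 4 CPU / 16 GB node, leave some headroom for the runtime.
sky launch --cpus 3+ --memory 12+ task.yaml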

+ 24
- 1
examples/github_actions/README.md View File

@@ -8,6 +8,19 @@ This example provides a GitHub CI pipeline that automatically starts a SkyPilot

> **_NOTE:_** This example is adapted from Metta AI's GitHub actions pipeline: https://github.com/Metta-AI/metta/tree/main

## Why use SkyPilot with GitHub Actions?

Pairing SkyPilot with GitHub Actions lets routine experiments run automatically without manual input, improving iteration speed. With SkyPilot and GitHub Actions, you can:

**Customize Workflow Triggers**: GitHub Actions provides a breadth of triggers to automate workflows, including:
- Code pushes to a branch
- Changes to specific files
- Runs on a schedule

**Orchestrate and Monitor Jobs across Clouds**: SkyPilot allows the CI task to run across regions, clouds, and Kubernetes clusters, and provides a single pane of glass to monitor the CI jobs.

**Enable Custom Notifications**: Send a notification whenever a CI job runs, with a link to monitor the job status and logs.


## Prerequisites

@@ -27,6 +40,7 @@ To create a service account key:

- **Navigate to the Users page**: On the main page of the SkyPilot dashboard, click "Users".
- **Access Service Account Settings**: Click "Service Accounts" located at the top of the page.
> **Note:** If "Service Accounts" section does not exist on the dashboard, the API server is not using SSO. This section can be skipped.
- **Create New Service Account**: Click "+ Create Service Account" button located at the top of the page and follow the instructions to create a service account token.

### GitHub: Define repository secrets
@@ -86,7 +100,12 @@ The workflow checks out the GitHub repo to a specified commit, generates a uniqu

The ``Launch SkyPilot Job`` action in turn uses a custom action located at ``.github/actions/setup-environment/action.yaml`` to install necessary dependencies (including ``skypilot``), and launches a SkyPilot job.

Once the job is successfully launched, ``sky-job.yaml`` then parses out the job ID of the submitted job. A slack message is then sent to the configured slack channel. An example message is provided below:
Once the job is successfully launched, ``sky-job.yaml`` then parses out the job ID of the submitted job.

The submitted job can be queried either by using `sky jobs queue` or by visiting the Jobs page of the SkyPilot API dashboard.
![dashboard page](https://i.imgur.com/JjDk30Z.png "Dashboard page")

A Slack message is then sent to the configured Slack channel. An example message is provided below:
![slack message](https://i.imgur.com/p50yoD5.png "Slack message")

## Frequently Asked Questions
@@ -101,3 +120,7 @@ on:
- branches: [main]
+ branches: [master]
```

### How do I limit / isolate the resources available to the workflow?

You can specify a specific cloud, region, or Kubernetes cluster for the workflow to use in the task YAML, as sketched below. Alternatively, you can define a separate [workspace](https://docs.skypilot.co/en/latest/admin/workspaces.html) for the workflow to use, isolating the infrastructure the workflow has access to.
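
For example, a minimal (illustrative) snippet in the task YAML that pins the workflow to one cloud and region:

```yaml
# Illustrative values; adjust to the infra the workflow should use.
resources:
  cloud: aws
  region: us-east-1
```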

+ 4
- 1
examples/marimo/marimo.yaml View File

@@ -13,7 +13,10 @@ resources:

workdir: .

secrets:
MARIMO_TOKEN_PASSWORD: secretpassword

setup: pip install uv

# Check the docs for more options: https://docs.marimo.io/cli/#marimo-edit
run: uvx marimo edit --port 29324 --headless --host=0.0.0.0 --token-password='secretpassword'
run: uvx marimo edit --port 29324 --headless --host=0.0.0.0 --token-password=$MARIMO_TOKEN_PASSWORD

+ 9
- 0
examples/plugin/README.md View File

@@ -0,0 +1,9 @@
# Example Plugins for SkyPilot API Server

Usage:

```bash
$ pip install .
$ cp plugins.yaml ~/.sky/plugins.yaml
$ sky api stop; sky api start
```

+ 6
- 0
examples/plugin/example_plugin/__init__.py View File

@@ -0,0 +1,6 @@
"""Example plugin for SkyPilot API server."""
from example_plugin.plugin import ExampleBackgroundTaskPlugin
from example_plugin.plugin import ExampleMiddlewarePlugin
from example_plugin.plugin import ExampleParameterizedPlugin
from example_plugin.plugin import ExamplePatchPlugin
from example_plugin.plugin import ExamplePlugin

+ 123
- 0
examples/plugin/example_plugin/plugin.py View File

@@ -0,0 +1,123 @@
"""Example plugin for SkyPilot API server."""
import functools
import logging
import os
import threading
import time

import fastapi
import filelock
import starlette.middleware.base

from sky import check as sky_check
from sky.provision import common
from sky.server import plugins

logger = logging.getLogger(__name__)


class ExamplePlugin(plugins.BasePlugin):
"""Example plugin for SkyPilot API server."""

def build_router(self) -> fastapi.APIRouter:
router = fastapi.APIRouter(prefix='/plugins/example',
tags=['example plugin'])

@router.get('/')
async def get_example():
return {'message': 'Hello from example_plugin'}

return router

def install(self, extension_context: plugins.ExtensionContext):
if extension_context.app:
extension_context.app.include_router(self.build_router())
extension_context.register_rbac_rule(path='/plugins/example/*',
method='GET',
description='Example plugin',
role='user')


class ExampleParameterizedPlugin(plugins.BasePlugin):
"""Example plugin with parameters for SkyPilot API server."""

def __init__(self, message: str):
self._message = message

def build_router(self) -> fastapi.APIRouter:
router = fastapi.APIRouter(prefix='/plugins/example_parameterized',
tags=['example_parameterized plugin'])

@router.get('/')
async def get_example():
return {'message': self._message}

return router

def install(self, extension_context: plugins.ExtensionContext):
if extension_context.app:
extension_context.app.include_router(self.build_router())


class ExampleMiddleware(starlette.middleware.base.BaseHTTPMiddleware):
"""Example middleware for SkyPilot API server."""

async def dispatch(self, request: fastapi.Request, call_next):
client = request.client
logger.info(f'Audit request {request.method} {request.url.path} '
f'from {client.host if client else "unknown"}')
return await call_next(request)


class ExampleMiddlewarePlugin(plugins.BasePlugin):
"""Example plugin that adds a middleware for SkyPilot API server."""

def install(self, extension_context: plugins.ExtensionContext):
if extension_context.app:
extension_context.app.add_middleware(ExampleMiddleware)


class ExampleBackgroundTaskPlugin(plugins.BasePlugin):
"""Example plugin that runs a background task on the API server."""

def install(self, extension_context: plugins.ExtensionContext):

lock_path = os.path.expanduser('~/.sky/plugins/check_context_task.lock')
os.makedirs(os.path.dirname(lock_path), exist_ok=True)

def check_contexts():
lock = filelock.FileLock(lock_path)
try:
with lock.acquire(blocking=False):
while True:
try:
sky_check.check(clouds=['kubernetes'])
except Exception as e:
logger.error('Error checking contexts: %s', e)
time.sleep(60)
except filelock.Timeout:
# Other process is already running the task, skip it.
pass

threading.Thread(target=check_contexts, daemon=True).start()


class ExamplePatchPlugin(plugins.BasePlugin):
"""Example plugin that patches the SkyPilot API server."""

def install(self, extension_context: plugins.ExtensionContext):
# pylint: disable=import-outside-toplevel
from sky.provision.kubernetes import instance

original_run_instances = instance.run_instances

@functools.wraps(original_run_instances)
def patched_run_instances(
region: str, cluster_name: str, cluster_name_on_cloud: str,
config: common.ProvisionConfig) -> common.ProvisionRecord:
result = original_run_instances(region, cluster_name,
cluster_name_on_cloud, config)
logger.info('Post action after running instances')
return result

instance.run_instances = patched_run_instances

+ 8
- 0
examples/plugin/plugins.yaml View File

@@ -0,0 +1,8 @@
plugins:
- class: example_plugin.ExamplePlugin
- class: example_plugin.ExampleParameterizedPlugin
parameters:
message: "Message from parameters"
- class: example_plugin.ExampleMiddlewarePlugin
- class: example_plugin.ExampleBackgroundTaskPlugin
- class: example_plugin.ExamplePatchPlugin

+ 7
- 0
examples/plugin/pyproject.toml View File

@@ -0,0 +1,7 @@
[build-system]
requires = ["setuptools>=61.0", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "example_plugin"
version = "0.0.1"

+ 8
- 1
examples/resnet_distributed_torch.yaml View File

@@ -13,7 +13,14 @@ setup: |
cd pytorch-distributed-resnet
# SkyPilot's default image on AWS/GCP has CUDA 11.6 (Azure 11.5).
uv pip install -r requirements.txt numpy==1.26.4 torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
mkdir -p data && mkdir -p saved_models && cd data && \
mkdir -p data && mkdir -p saved_models
# Check if working directory is on a shared filesystem (NFS, Lustre, etc.)
# If shared, only download on rank 0 to avoid race conditions
fstype=$(stat -fc%T .)
if [[ "$fstype" =~ ^(nfs|lustre|fuse\.lustre|cifs|smb)$ ]] && [ "${SKYPILOT_SETUP_NODE_RANK:-0}" -ne 0 ]; then
exit 0
fi
cd data && \
wget -c --quiet https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
tar -xvzf cifar-10-python.tar.gz



+ 30
- 2
examples/vector_database/README.md View File

@@ -72,6 +72,35 @@ Processing files: 100%|██████████| 12/12 [00:05<00:00, 2.04

To serve the constructed database, you expose an API endpoint that other applications (or your local client) can call to perform semantic search. Querying allows you to confirm that your database is working and retrieve semantic matches for a given text query. You can integrate this endpoint into larger applications (like an image search engine or recommendation system).

### Option 1: Launch with CLI

To serve the constructed database:
```
sky launch -c vecdb_serve serve_vectordb.yaml
```
which runs the hosted vector database service. Alternatively, you can run
```
sky serve up serve_vectordb.yaml -n vectordb
```
This deploys your vector database as a service on a cloud instance and allows you to interact with it via a public endpoint. Sky Serve facilitates automatic health checks and scaling of the service.

To query the constructed database:

If you run through `sky launch`, use
```
sky status --ip vecdb_serve
```
to get the IP address of the deployed cluster.

If you run through `sky serve`, you may run
```
sky serve status vectordb --endpoint
```

to get the endpoint address of the service.

### Option 2: Launch with SDK

To serve the constructed database:
```
python3 serve_vectordb.py
@@ -80,8 +109,7 @@ which runs the hosted vector database service. Alternatively, you can run
```
python3 serve_vectordb.py --serve
```
This will deploy your vector database as a service on a cloud instance and allow you to interact with it via a public endpoint. Sky Serve facilitates automatic health checks and scaling of the service.

This deploys your vector database as a service on a cloud instance and allows you to interact with it via a public endpoint. Sky Serve facilitates automatic health checks and scaling of the service.

Use the endpoint address of the service printed by the above script to query the constructed database.



+ 2
- 2
llm/llama-3_1-finetuning/readme.md View File

@@ -7,7 +7,7 @@

On July 23, 2024, Meta released the [Llama 3.1 model family](https://ai.meta.com/blog/meta-llama-3-1/), including a 405B parameter model in both base model and instruction-tuned forms. Llama 3.1 405B became _the first open LLM that closely rivals top proprietary models_ like GPT-4o and Claude 3.5 Sonnet.

This guide shows how to use [SkyPilot](https://github.com/skypilot-org/skypilot) and [torchtune](https://pytorch.org/torchtune/stable/index.html) to **finetune Llama 3.1 on your own data and infra**. Everything is packaged in a simple [SkyPilot YAML](https://docs.skypilot.co/en/latest/getting-started/quickstart.html), that can be launched with one command on your infra:
This guide shows how to use [SkyPilot](https://github.com/skypilot-org/skypilot) and [torchtune](https://meta-pytorch.org/torchtune/stable/index.html) to **finetune Llama 3.1 on your own data and infra**. Everything is packaged in a simple [SkyPilot YAML](https://docs.skypilot.co/en/latest/getting-started/quickstart.html), that can be launched with one command on your infra:
- Local GPU workstation
- Kubernetes cluster
- Cloud accounts ([12 clouds supported](https://docs.skypilot.co/en/latest/getting-started/installation.html))
@@ -20,7 +20,7 @@ This guide shows how to use [SkyPilot](https://github.com/skypilot-org/skypilot)


## Let's finetune Llama 3.1
We will use [torchtune](https://pytorch.org/torchtune/stable/index.html) to finetune Llama 3.1. The example below uses the [`yahma/alpaca-cleaned`](https://huggingface.co/datasets/yahma/alpaca-cleaned) dataset, which you can replace with your own dataset later.
We will use [torchtune](https://meta-pytorch.org/torchtune/stable/index.html) to finetune Llama 3.1. The example below uses the [`yahma/alpaca-cleaned`](https://huggingface.co/datasets/yahma/alpaca-cleaned) dataset, which you can replace with your own dataset later.

To set up the environment for launching the finetuning job, finish the [Appendix: Preparation](#appendix-preparation) section first.



+ 22
- 1
llm/localgpt/README.md View File

@@ -19,11 +19,32 @@ Once you are done, we will use [SkyPilot YAML for localGPT](https://github.com/s


## Launching localGPT on your cloud with SkyPilot
1. Use `launch_localgpt.py` to run localGPT on your cloud. SkyPilot will show the estimated cost and chosen cloud before provisioning. For reference, running on T4 instances on AWS would cost about $0.53 per hour.
1. Run localGPT on your cloud.

---

**Option 1: Run with CLI**

Use `sky launch` to run localGPT on your cloud.

```bash
sky launch -c localgpt localgpt.yaml
```

Once you see `INFO:werkzeug:Press CTRL+C to quit`, you can safely Ctrl+C from the `sky launch` command.

**Option 2: Run with SDK**

Use `launch_localgpt.py` to run localGPT on your cloud.

```bash
python3 launch_localgpt.py
```

---

SkyPilot will show the estimated cost and chosen cloud before provisioning. For reference, running on T4 instances on AWS would cost about $0.53 per hour.

2. Run `ssh -L 5111:localhost:5111 localgpt` in a new terminal window to forward the port 5111 to your local machine. Keep this terminal running.

3. Open http://localhost:5111 in your browser. Click on upload file to upload a document. Once the document has been ingested, you can chat with it, ask questions, and summarize it. For example, in the gif below, we use the SkyPilot [NSDI 2023 paper](https://www.usenix.org/system/files/nsdi23-yang-zongheng.pdf) to ask questions about how SkyPilot works.


+ 2
- 2
llm/mixtral/README.md View File

@@ -3,7 +3,7 @@
<!-- $END_REMOVE -->
<!-- $UNCOMMENT# Mixtral: MOE LLM from Mistral AI -->

Mistral AI released Mixtral 8x7B, a high-quality sparse mixture of experts model (SMoE) with open weights. Mixtral outperforms Llama 2 70B on most benchmarks with 6x faster inference. Mistral AI uses SkyPilot as [the default way](https://docs.mistral.ai/self-deployment/skypilot) to distribute their new model. This folder contains the code to serve Mixtral on any cloud with SkyPilot.
Mistral AI released Mixtral 8x7B, a high-quality sparse mixture of experts model (SMoE) with open weights. Mixtral outperforms Llama 2 70B on most benchmarks with 6x faster inference. Mistral AI uses SkyPilot as [the default way](https://docs.mistral.ai/deployment/self-deployment/skypilot) to distribute their new model. This folder contains the code to serve Mixtral on any cloud with SkyPilot.

There are three ways to serve the model:

@@ -148,6 +148,6 @@ curl http://$ENDPOINT/v1/chat/completions \

## 3. Official guide from Mistral AI

Mistral AI also includes a guide for launching the Mixtral 8x7B model with SkyPilot in their official doc. Please refer to [this link](https://docs.mistral.ai/self-deployment/skypilot) for more details.
Mistral AI also includes a guide for launching the Mixtral 8x7B model with SkyPilot in their official doc. Please refer to [this link](https://docs.mistral.ai/deployment/self-deployment/skypilot) for more details.

> Note: the docker image of the official doc may not be updated yet, which can cause a failure where vLLM is complaining about the missing support for the model. Please feel free to create a new docker image with the setup commands in our [serve.yaml](https://github.com/skypilot-org/skypilot/tree/master/llm/mixtral/serve.yaml) file instead.

+ 2
- 0
sky/__init__.py View File

@@ -140,6 +140,7 @@ Cudo = clouds.Cudo
GCP = clouds.GCP
Lambda = clouds.Lambda
SCP = clouds.SCP
Slurm = clouds.Slurm
Kubernetes = clouds.Kubernetes
K8s = Kubernetes
OCI = clouds.OCI
@@ -170,6 +171,7 @@ __all__ = [
'RunPod',
'Vast',
'SCP',
'Slurm',
'Vsphere',
'Fluidstack',
'Nebius',


+ 1
- 61
sky/adaptors/aws.py View File

@@ -28,7 +28,6 @@ This is informed by the following boto3 docs:

# pylint: disable=import-outside-toplevel

import functools
import logging
import threading
import time
@@ -69,65 +68,6 @@ version = 1
_MAX_ATTEMPT_FOR_CREATION = 5


class _ThreadLocalTTLCache(threading.local):
"""Thread-local storage for _thread_local_lru_cache decorator."""

def __init__(self, func, maxsize: int, ttl: int):
super().__init__()
self.func = func
self.maxsize = maxsize
self.ttl = ttl

def get_cache(self):
if not hasattr(self, 'cache'):
self.cache = annotations.ttl_cache(scope='request',
maxsize=self.maxsize,
ttl=self.ttl,
timer=time.time)(self.func)
return self.cache


def _thread_local_ttl_cache(maxsize=32, ttl=60 * 55):
"""Thread-local TTL cache decorator.

Args:
maxsize: Maximum size of the cache.
ttl: Time to live for the cache in seconds.
Default is 55 minutes, a bit less than 1 hour
default lifetime of an STS token.
"""

def decorator(func):
# Create thread-local storage for the LRU cache
local_cache = _ThreadLocalTTLCache(func, maxsize, ttl)

# We can't apply the lru_cache here, because this runs at import time
# so we will always have the main thread's cache.

@functools.wraps(func)
def wrapper(*args, **kwargs):
# We are within the actual function call, which may be on a thread,
# so local_cache.cache will return the correct thread-local cache,
# which we can now apply and immediately call.
return local_cache.get_cache()(*args, **kwargs)

def cache_info():
# Note that this will only give the cache info for the current
# thread's cache.
return local_cache.get_cache().cache_info()

def cache_clear():
# Note that this will only clear the cache for the current thread.
local_cache.get_cache().cache_clear()

wrapper.cache_info = cache_info # type: ignore[attr-defined]
wrapper.cache_clear = cache_clear # type: ignore[attr-defined]

return wrapper

return decorator


def _assert_kwargs_builtin_type(kwargs):
assert all(isinstance(v, (int, float, str)) for v in kwargs.values()), (
f'kwargs should not contain none built-in types: {kwargs}')
@@ -174,7 +114,7 @@ def get_workspace_profile() -> Optional[str]:

# The TTL cache needs to be thread-local to avoid multiple threads sharing the
# same session object, which is not guaranteed to be thread-safe.
@_thread_local_ttl_cache()
@annotations.thread_local_ttl_cache()
def session(check_credentials: bool = True, profile: Optional[str] = None):
"""Create an AWS session.



+ 478
- 0
sky/adaptors/slurm.py View File

@@ -0,0 +1,478 @@
"""Slurm adaptor for SkyPilot."""

import logging
import re
import time
from typing import Dict, List, NamedTuple, Optional, Tuple

from sky.utils import command_runner
from sky.utils import subprocess_utils
from sky.utils import timeline

logger = logging.getLogger(__name__)

# ASCII Unit Separator (\x1f) to handle values with spaces
# and other special characters.
SEP = r'\x1f'

# Regex pattern to extract partition names from scontrol output
# Matches PartitionName=<name> and captures until the next field
_PARTITION_NAME_REGEX = re.compile(r'PartitionName=(.+?)(?:\s+\w+=|$)')

# Default timeout for waiting for job nodes to be allocated, in seconds.
_SLURM_DEFAULT_PROVISION_TIMEOUT = 10


class SlurmPartition(NamedTuple):
"""Information about the Slurm partitions."""
name: str
is_default: bool


# TODO(kevin): Add more API types for other client functions.
class NodeInfo(NamedTuple):
"""Information about a Slurm node from sinfo."""
node: str
state: str
gres: str
cpus: int
memory_gb: float
# The default partition contains a '*' at the end of the name.
# It is the caller's responsibility to strip the '*' if needed.
partition: str


class SlurmClient:
"""Client for Slurm control plane operations."""

def __init__(
self,
ssh_host: str,
ssh_port: int,
ssh_user: str,
ssh_key: Optional[str],
ssh_proxy_command: Optional[str] = None,
):
"""Initialize SlurmClient.

Args:
ssh_host: Hostname of the Slurm controller.
ssh_port: SSH port on the controller.
ssh_user: SSH username.
ssh_key: Path to SSH private key, or None for keyless SSH.
ssh_proxy_command: Optional SSH proxy command.
"""
self.ssh_host = ssh_host
self.ssh_port = ssh_port
self.ssh_user = ssh_user
self.ssh_key = ssh_key
self.ssh_proxy_command = ssh_proxy_command

# Internal runner for executing Slurm CLI commands
# on the controller node.
self._runner = command_runner.SSHCommandRunner(
(ssh_host, ssh_port),
ssh_user,
ssh_key,
ssh_proxy_command=ssh_proxy_command,
)

def query_jobs(
self,
job_name: Optional[str] = None,
state_filters: Optional[List[str]] = None,
) -> List[str]:
"""Query Slurm jobs by state and optional name.

Args:
job_name: Optional job name to filter by.
state_filters: List of job states to filter by
(e.g., ['running', 'pending']). If None, returns all jobs.

Returns:
List of job IDs matching the filters.
"""
cmd = 'squeue --me -h -o "%i"'
if state_filters is not None:
state_filters_str = ','.join(state_filters)
cmd += f' --states {state_filters_str}'
if job_name is not None:
cmd += f' --name {job_name}'

rc, stdout, stderr = self._runner.run(cmd,
require_outputs=True,
stream_logs=False)
subprocess_utils.handle_returncode(rc,
cmd,
'Failed to query Slurm jobs.',
stderr=stderr)

job_ids = stdout.strip().splitlines()
return job_ids

def cancel_jobs_by_name(self,
job_name: str,
signal: Optional[str] = None,
full: bool = False) -> None:
"""Cancel Slurm job(s) by name.

Args:
job_name: Name of the job(s) to cancel.
signal: Optional signal to send to the job(s).
full: If True, signals the batch script and its children processes.
By default, signals other than SIGKILL are not sent to the
batch step (the shell script).
"""
cmd = f'scancel --name {job_name}'
if signal is not None:
cmd += f' --signal {signal}'
if full:
cmd += ' --full'
rc, stdout, stderr = self._runner.run(cmd,
require_outputs=True,
stream_logs=False)
subprocess_utils.handle_returncode(rc,
cmd,
f'Failed to cancel job {job_name}.',
stderr=stderr)
logger.debug(f'Successfully cancelled job {job_name}: {stdout}')

def info(self) -> str:
"""Get Slurm cluster information.

This is useful for checking if the cluster is accessible and
retrieving node information.

Returns:
The stdout output from sinfo.
"""
cmd = 'sinfo'
rc, stdout, stderr = self._runner.run(cmd,
require_outputs=True,
stream_logs=False)
subprocess_utils.handle_returncode(
rc, cmd, 'Failed to get Slurm cluster information.', stderr=stderr)
return stdout

def info_nodes(self) -> List[NodeInfo]:
"""Get Slurm node information.

Returns node names, states, GRES (generic resources like GPUs),
CPUs, memory (GB), and partitions.
"""
cmd = (f'sinfo -h --Node -o '
f'"%N{SEP}%t{SEP}%G{SEP}%c{SEP}%m{SEP}%P"')
rc, stdout, stderr = self._runner.run(cmd,
require_outputs=True,
stream_logs=False)
subprocess_utils.handle_returncode(
rc, cmd, 'Failed to get Slurm node information.', stderr=stderr)

nodes = []
for line in stdout.splitlines():
parts = line.split(SEP)
if len(parts) != 6:
raise RuntimeError(
f'Unexpected output format from sinfo: {line!r}')
try:
node_info = NodeInfo(node=parts[0],
state=parts[1],
gres=parts[2],
cpus=int(parts[3]),
memory_gb=int(parts[4]) / 1024.0,
partition=parts[5])
nodes.append(node_info)
except ValueError as e:
raise RuntimeError(
f'Failed to parse node info from line: {line!r}. '
f'Error: {e}') from e

return nodes

def node_details(self, node_name: str) -> Dict[str, str]:
"""Get detailed Slurm node information.

Returns:
A dictionary of node attributes.
"""

def _parse_scontrol_node_output(output: str) -> Dict[str, str]:
"""Parses the key=value output of 'scontrol show node'."""
node_info = {}
# Split by space, handling values that might have spaces
# if quoted. This is simplified; scontrol can be complex.
parts = output.split()
for part in parts:
if '=' in part:
key, value = part.split('=', 1)
# Simple quote removal, might need refinement
value = value.strip('\'"')
node_info[key] = value
return node_info

cmd = f'scontrol show node {node_name}'
rc, node_details, _ = self._runner.run(cmd,
require_outputs=True,
stream_logs=False)
subprocess_utils.handle_returncode(
rc,
cmd,
f'Failed to get detailed node information for {node_name}.',
stderr=node_details)
node_info = _parse_scontrol_node_output(node_details)
return node_info

def get_node_jobs(self, node_name: str) -> List[str]:
"""Get the list of jobs for a given node name.

Returns:
A list of job names for the current user on the node.
"""
cmd = f'squeue --me -h --nodelist {node_name} -o "%b"'
rc, stdout, stderr = self._runner.run(cmd,
require_outputs=True,
stream_logs=False)
subprocess_utils.handle_returncode(
rc, cmd, f'Failed to get jobs for node {node_name}.', stderr=stderr)
return stdout.splitlines()

def get_job_state(self, job_id: str) -> Optional[str]:
"""Get the state of a Slurm job.

Args:
job_id: The Slurm job ID.

Returns:
The job state (e.g., 'PENDING', 'RUNNING', 'COMPLETED', etc.),
or None if the job is not found.
"""
# Use --only-job-state since we only need the job state.
# This reduces the work required by slurmctld.
cmd = f'squeue -h --only-job-state --jobs {job_id} -o "%T"'
rc, stdout, stderr = self._runner.run(cmd,
require_outputs=True,
stream_logs=False)
if rc != 0:
# Job may not exist
logger.debug(f'Failed to get job state for job {job_id}: {stderr}')
return None

state = stdout.strip()
return state if state else None

@timeline.event
def get_job_reason(self, job_id: str) -> Optional[str]:
"""Get the reason a job is in its current state

Args:
job_id: The Slurm job ID.
"""
# Without --states all, squeue omits terminated jobs.
cmd = f'squeue -h --jobs {job_id} --states all -o "%r"'
rc, stdout, stderr = self._runner.run(cmd,
require_outputs=True,
stream_logs=False)
if rc != 0:
logger.debug(f'Failed to get job info for job {job_id}: {stderr}')
return None

output = stdout.strip()
if not output:
return None

return output if output != 'None' else None

@timeline.event
def wait_for_job_nodes(self, job_id: str, timeout: int) -> None:
"""Wait for a Slurm job to have nodes allocated.

Args:
job_id: The Slurm job ID.
timeout: Maximum time to wait in seconds.
"""
start_time = time.time()
last_state = None

while time.time() - start_time < timeout:
state = self.get_job_state(job_id)

if state != last_state:
logger.debug(f'Job {job_id} state: {state}')
last_state = state

if state is None:
raise RuntimeError(f'Job {job_id} not found. It may have been '
'cancelled or failed.')

if state in ('COMPLETED', 'CANCELLED', 'FAILED', 'TIMEOUT'):
raise RuntimeError(
f'Job {job_id} terminated with state {state} '
'before nodes were allocated.')
# TODO(kevin): Log reason for pending.

# Check if nodes are allocated by trying to get node list
cmd = f'squeue -h --jobs {job_id} -o "%N"'
rc, stdout, stderr = self._runner.run(cmd,
require_outputs=True,
stream_logs=False)

if rc == 0 and stdout.strip():
# Nodes are allocated
logger.debug(
f'Job {job_id} has nodes allocated: {stdout.strip()}')
return
elif rc != 0:
logger.debug(f'Failed to get nodes for job {job_id}: {stderr}')

# Wait before checking again
time.sleep(2)

raise TimeoutError(f'Job {job_id} did not get nodes allocated within '
f'{timeout} seconds. Last state: {last_state}')

@timeline.event
def get_job_nodes(
self,
job_id: str,
wait: bool = True,
timeout: Optional[int] = None) -> Tuple[List[str], List[str]]:
"""Get the list of nodes and their IPs for a given job ID.

The ordering is guaranteed to be stable for the lifetime of the job.

Args:
job_id: The Slurm job ID.
wait: If True, wait for nodes to be allocated before returning.
timeout: Maximum time to wait in seconds. Only used when wait=True.

Returns:
A tuple of (nodes, node_ips) where nodes is a list of node names
and node_ips is a list of corresponding IP addresses.
"""
# Wait for nodes to be allocated if requested
if wait:
if timeout is None:
timeout = _SLURM_DEFAULT_PROVISION_TIMEOUT
self.wait_for_job_nodes(job_id, timeout=timeout)

cmd = (
f'squeue -h --jobs {job_id} -o "%N" | tr \',\' \'\\n\' | '
f'while read node; do '
# TODO(kevin): Use json output for more robust parsing.
f'ip=$(scontrol show node=$node | grep NodeAddr= | '
f'awk -F= \'{{print $2}}\' | awk \'{{print $1}}\'); '
f'echo "$node $ip"; '
f'done')
rc, stdout, stderr = self._runner.run(cmd,
require_outputs=True,
stream_logs=False)
subprocess_utils.handle_returncode(
rc, cmd, f'Failed to get nodes for job {job_id}.', stderr=stderr)
logger.debug(f'Successfully got nodes for job {job_id}: {stdout}')

node_info = {}
for line in stdout.strip().splitlines():
line = line.strip()
if line:
parts = line.split()
if len(parts) >= 2:
node_name = parts[0]
node_ip = parts[1]
node_info[node_name] = node_ip

nodes = list(node_info.keys())
node_ips = [node_info[node] for node in nodes]
if not nodes:
raise RuntimeError(
f'No nodes found for job {job_id}. '
f'The job may have terminated or the output was empty.')
assert (len(nodes) == len(node_ips)
), f'Number of nodes and IPs do not match: {nodes} != {node_ips}'

return nodes, node_ips

def submit_job(
self,
partition: str,
job_name: str,
script_path: str,
) -> str:
"""Submit a Slurm job script.

Args:
partition: Slurm partition to submit to.
job_name: Name to give the job.
script_path: Remote path where the script will be stored.

Returns:
The job ID of the submitted job.
"""
cmd = f'sbatch --partition={partition} {script_path}'
rc, stdout, stderr = self._runner.run(cmd,
require_outputs=True,
stream_logs=False)
subprocess_utils.handle_returncode(rc,
cmd,
'Failed to submit Slurm job.',
stderr=f'{stdout}\n{stderr}')

# Parse job ID from sbatch output (format: "Submitted batch job 12345")
job_id_match = re.search(r'Submitted batch job (\d+)', stdout)
if not job_id_match:
raise RuntimeError(
f'Failed to parse job ID from sbatch output: {stdout}')

job_id = job_id_match.group(1).strip()
logger.debug(f'Successfully submitted Slurm job {job_id} with name '
f'{job_name}: {stdout}')

return job_id

def get_partitions_info(self) -> List[SlurmPartition]:
"""Get the partitions information for the Slurm cluster.

Returns:
List of SlurmPartition objects.
"""
cmd = 'scontrol show partitions -o'
rc, stdout, stderr = self._runner.run(cmd,
require_outputs=True,
stream_logs=False)
subprocess_utils.handle_returncode(rc,
cmd,
'Failed to get Slurm partitions.',
stderr=stderr)

partitions = []
for line in stdout.strip().splitlines():
is_default = False
match = _PARTITION_NAME_REGEX.search(line)
if 'Default=YES' in line:
is_default = True
if match:
partition = match.group(1).strip()
if partition:
partitions.append(
SlurmPartition(name=partition, is_default=is_default))
return partitions

def get_default_partition(self) -> Optional[str]:
"""Get the default partition name for the Slurm cluster.

Returns:
The default partition name, or None if it cannot be determined.
"""
partitions = self.get_partitions_info()
for partition in partitions:
if partition.is_default:
return partition.name
return None

def get_partitions(self) -> List[str]:
"""Get unique partition names in the Slurm cluster.

Returns:
List of partition names. The default partition will not have a '*'
at the end of the name.
"""
return [partition.name for partition in self.get_partitions_info()]

+ 45
- 4
sky/backends/backend_utils.py View File

@@ -147,6 +147,19 @@ CLUSTER_TUNNEL_LOCK_TIMEOUT_SECONDS = 10.0
# Remote dir that holds our runtime files.
_REMOTE_RUNTIME_FILES_DIR = '~/.sky/.runtime_files'

# The maximum size of a command line arguments is 128 KB, i.e. the command
# executed with /bin/sh should be less than 128KB.
# https://github.com/torvalds/linux/blob/master/include/uapi/linux/binfmts.h
#
# If a user has very long run or setup commands, the generated command may
# exceed the limit, as we directly include scripts in job submission commands.
# If the command is too long, we instead write it to a file, rsync and execute
# it.
#
# We use 100KB as a threshold to be safe for other arguments that
# might be added during ssh.
_MAX_INLINE_SCRIPT_LENGTH = 100 * 1024

_ENDPOINTS_RETRY_MESSAGE = ('If the cluster was recently started, '
'please retry after a while.')

@@ -225,6 +238,18 @@ _ACK_MESSAGE = 'ack'
_FORWARDING_FROM_MESSAGE = 'Forwarding from'


def is_command_length_over_limit(command: str) -> bool:
"""Check if the length of the command exceeds the limit.

We calculate the length of the command after quoting the command twice as
when it is executed by the CommandRunner, the command will be quoted twice
to ensure the correctness, which will add significant length to the command.
"""

quoted_length = len(shlex.quote(shlex.quote(command)))
return quoted_length > _MAX_INLINE_SCRIPT_LENGTH


def is_ip(s: str) -> bool:
"""Returns whether this string matches IP_ADDR_REGEX."""
return len(re.findall(IP_ADDR_REGEX, s)) == 1
@@ -946,6 +971,9 @@ def write_cluster_config(
'{conda_auto_activate}',
conda_auto_activate).replace('{is_custom_docker}',
is_custom_docker),
# Currently only used by Slurm. For other clouds, it is
# already part of ray_skypilot_installation_commands
'setup_sky_dirs_commands': constants.SETUP_SKY_DIRS_COMMANDS,
'ray_skypilot_installation_commands':
(constants.RAY_SKYPILOT_INSTALLATION_COMMANDS.replace(
'{sky_wheel_hash}',
@@ -1058,7 +1086,11 @@ def write_cluster_config(
with open(tmp_yaml_path, 'w', encoding='utf-8') as f:
f.write(restored_yaml_content)

config_dict['cluster_name_on_cloud'] = cluster_name_on_cloud
# Read the cluster_name_on_cloud from the restored yaml. This is a hack to
# make sure that launching on the same cluster across multiple users works
# correctly. See #8232.
yaml_config = yaml_utils.read_yaml(tmp_yaml_path)
config_dict['cluster_name_on_cloud'] = yaml_config['cluster_name']

# Make sure to do this before we optimize file mounts. Optimization is
# non-deterministic, but everything else before this point should be
@@ -1105,17 +1137,21 @@ def _add_auth_to_cluster_config(cloud: clouds.Cloud, tmp_yaml_path: str):
"""
config = yaml_utils.read_yaml(tmp_yaml_path)
# Check the availability of the cloud type.
if isinstance(cloud, (
if isinstance(
cloud,
(
clouds.AWS,
clouds.OCI,
clouds.SCP,
# TODO(jwj): Handle Slurm-specific auth logic
clouds.Slurm,
clouds.Vsphere,
clouds.Cudo,
clouds.Paperspace,
clouds.Azure,
clouds.DO,
clouds.Nebius,
)):
)):
config = auth.configure_ssh_info(config)
elif isinstance(cloud, clouds.GCP):
config = auth.setup_gcp_authentication(config)
@@ -2361,7 +2397,12 @@ def _update_cluster_status(
# remain healthy for a while before the cloud completely preempts the VMs.
# We have mitigated this by again first querying the VM state from the cloud
# provider.
if all_nodes_up and run_ray_status_to_check_ray_cluster_healthy():
cloud = handle.launched_resources.cloud

# For Slurm, skip Ray health check since it doesn't use Ray.
should_check_ray = cloud is not None and cloud.uses_ray()
if all_nodes_up and (not should_check_ray or
run_ray_status_to_check_ray_cluster_healthy()):
# NOTE: all_nodes_up calculation is fast due to calling cloud CLI;
# run_ray_status_to_check_all_nodes_up() is slow due to calling `ray get
# head-ip/worker-ips`.


+ 32
- 33
sky/backends/cloud_vm_ray_backend.py View File

@@ -192,18 +192,6 @@ _RAY_UP_WITH_MONKEY_PATCHED_HASH_LAUNCH_CONF_PATH = (
pathlib.Path(directory_utils.get_sky_dir()) / 'backends' /
'monkey_patches' / 'monkey_patch_ray_up.py')

# The maximum size of a command line arguments is 128 KB, i.e. the command
# executed with /bin/sh should be less than 128KB.
# https://github.com/torvalds/linux/blob/master/include/uapi/linux/binfmts.h
#
# If a user have very long run or setup commands, the generated command may
# exceed the limit, as we directly include scripts in job submission commands.
# If the command is too long, we instead write it to a file, rsync and execute
# it.
#
# We use 100KB as a threshold to be safe for other arguments that
# might be added during ssh.
_MAX_INLINE_SCRIPT_LENGTH = 100 * 1024
_EXCEPTION_MSG_AND_RETURNCODE_FOR_DUMP_INLINE_SCRIPT = [
('too long', 255),
('request-uri too large', 1),
@@ -218,18 +206,6 @@ _RESOURCES_UNAVAILABLE_LOG = (
_CLUSTER_LOCK_TIMEOUT = 5.0


def _is_command_length_over_limit(command: str) -> bool:
"""Check if the length of the command exceeds the limit.

We calculate the length of the command after quoting the command twice as
when it is executed by the CommandRunner, the command will be quoted twice
to ensure the correctness, which will add significant length to the command.
"""

quoted_length = len(shlex.quote(shlex.quote(command)))
return quoted_length > _MAX_INLINE_SCRIPT_LENGTH


def _is_message_too_long(returncode: int,
output: Optional[str] = None,
file_path: Optional[str] = None) -> bool:
@@ -294,6 +270,7 @@ def _get_cluster_config_template(cloud):
clouds.Lambda: 'lambda-ray.yml.j2',
clouds.IBM: 'ibm-ray.yml.j2',
clouds.SCP: 'scp-ray.yml.j2',
clouds.Slurm: 'slurm-ray.yml.j2',
clouds.OCI: 'oci-ray.yml.j2',
clouds.Paperspace: 'paperspace-ray.yml.j2',
clouds.PrimeIntellect: 'primeintellect-ray.yml.j2',
@@ -2516,7 +2493,9 @@ class CloudVmRayResourceHandle(backends.backend.ResourceHandle):
@property
def is_grpc_enabled_with_flag(self) -> bool:
"""Returns whether this handle has gRPC enabled and gRPC flag is set."""
return env_options.Options.ENABLE_GRPC.get() and self.is_grpc_enabled
return (env_options.Options.ENABLE_GRPC.get() and
self.is_grpc_enabled and
not isinstance(self.launched_resources.cloud, clouds.Slurm))

def __getstate__(self):
state = self.__dict__.copy()
@@ -3596,6 +3575,7 @@ class CloudVmRayBackend(backends.Backend['CloudVmRayResourceHandle']):

def _setup(self, handle: CloudVmRayResourceHandle, task: task_lib.Task,
detach_setup: bool) -> None:

start = time.time()

if task.setup is None:
@@ -3647,7 +3627,8 @@ class CloudVmRayBackend(backends.Backend['CloudVmRayResourceHandle']):
_dump_final_script(setup_script,
constants.PERSISTENT_SETUP_SCRIPT_PATH)

if detach_setup or _is_command_length_over_limit(encoded_script):
if (detach_setup or
backend_utils.is_command_length_over_limit(encoded_script)):
_dump_final_script(setup_script)
create_script_code = 'true'
else:
@@ -3804,7 +3785,7 @@ class CloudVmRayBackend(backends.Backend['CloudVmRayResourceHandle']):
code = job_lib.JobLibCodeGen.queue_job(job_id, job_submit_cmd)
job_submit_cmd = ' && '.join([mkdir_code, create_script_code, code])

# Should also be ealier than _is_command_length_over_limit
# Should also be earlier than is_command_length_over_limit
# Same reason as in _setup
if self._dump_final_script:
_dump_code_to_file(job_submit_cmd,
@@ -3837,7 +3818,7 @@ class CloudVmRayBackend(backends.Backend['CloudVmRayResourceHandle']):
tasks=managed_job_tasks,
user_id=managed_job_user_id)

if _is_command_length_over_limit(codegen):
if backend_utils.is_command_length_over_limit(codegen):
_dump_code_to_file(codegen)
queue_job_request = jobsv1_pb2.QueueJobRequest(
job_id=job_id,
@@ -3859,7 +3840,7 @@ class CloudVmRayBackend(backends.Backend['CloudVmRayResourceHandle']):
use_legacy = True

if use_legacy:
if _is_command_length_over_limit(job_submit_cmd):
if backend_utils.is_command_length_over_limit(job_submit_cmd):
_dump_code_to_file(codegen)
job_submit_cmd = f'{mkdir_code} && {code}'

@@ -5850,6 +5831,22 @@ class CloudVmRayBackend(backends.Backend['CloudVmRayResourceHandle']):
return task.envs[constants.USER_ID_ENV_VAR]
return None

def _get_task_codegen_class(
self, handle: CloudVmRayResourceHandle) -> task_codegen.TaskCodeGen:
"""Returns the appropriate TaskCodeGen for the given handle."""
if isinstance(handle.launched_resources.cloud, clouds.Slurm):
assert (handle.cached_cluster_info
is not None), ('cached_cluster_info must be set')
head_instance = handle.cached_cluster_info.get_head_instance()
assert (head_instance is not None), (
'Head instance not found in cached cluster info')
slurm_job_id = head_instance.tags.get('job_id')
assert (slurm_job_id
is not None), ('job_id tag not found in head instance')
return task_codegen.SlurmCodeGen(slurm_job_id=slurm_job_id)
else:
return task_codegen.RayCodeGen()

def _execute_task_one_node(self, handle: CloudVmRayResourceHandle,
task: task_lib.Task, job_id: int,
remote_log_dir: str) -> None:
@@ -5862,15 +5859,16 @@ class CloudVmRayBackend(backends.Backend['CloudVmRayResourceHandle']):

task_env_vars = self._get_task_env_vars(task, job_id, handle)

codegen = task_codegen.RayCodeGen()
codegen = self._get_task_codegen_class(handle)

codegen.add_prologue(job_id)
codegen.add_setup(
1,
resources_dict,
stable_cluster_internal_ips=internal_ips,
env_vars=task_env_vars,
log_dir=log_dir,
setup_cmd=self._setup_cmd,
setup_log_path=os.path.join(log_dir, 'setup.log'),
)

codegen.add_task(
@@ -5907,15 +5905,16 @@ class CloudVmRayBackend(backends.Backend['CloudVmRayResourceHandle']):
num_actual_nodes = task.num_nodes * handle.num_ips_per_node
task_env_vars = self._get_task_env_vars(task, job_id, handle)

codegen = task_codegen.RayCodeGen()
codegen = self._get_task_codegen_class(handle)

codegen.add_prologue(job_id)
codegen.add_setup(
num_actual_nodes,
resources_dict,
stable_cluster_internal_ips=internal_ips,
env_vars=task_env_vars,
log_dir=log_dir,
setup_cmd=self._setup_cmd,
setup_log_path=os.path.join(log_dir, 'setup.log'),
)

codegen.add_task(


+ 340
- 2
sky/backends/task_codegen.py View File

@@ -4,6 +4,7 @@ import copy
import inspect
import json
import math
import os
import textwrap
from typing import Dict, List, Optional, Tuple

@@ -181,8 +182,8 @@ class TaskCodeGen:
resources_dict: Dict[str, float],
stable_cluster_internal_ips: List[str],
env_vars: Dict[str, str],
log_dir: str,
setup_cmd: Optional[str] = None,
setup_log_path: Optional[str] = None,
) -> None:
"""Generates code to set up the task on each node.

@@ -379,13 +380,15 @@ class RayCodeGen(TaskCodeGen):
resources_dict: Dict[str, float],
stable_cluster_internal_ips: List[str],
env_vars: Dict[str, str],
log_dir: str,
setup_cmd: Optional[str] = None,
setup_log_path: Optional[str] = None,
) -> None:
assert self._has_prologue, ('Call add_prologue() before '
'add_setup().')
self._has_setup = True

setup_log_path = os.path.join(log_dir, 'setup.log')

bundles = [copy.copy(resources_dict) for _ in range(num_nodes)]
# Set CPU to avoid ray hanging the resources allocation
# for remote functions, since the task will request 1 CPU
@@ -631,3 +634,338 @@ class RayCodeGen(TaskCodeGen):
"""Generates code that waits for all tasks, then exits."""
self._code.append('returncodes, _ = get_or_fail(futures, pg)')
super().add_epilogue()


class SlurmCodeGen(TaskCodeGen):
"""Code generator for task execution on Slurm using native srun."""

def __init__(self, slurm_job_id: str):
"""Initialize SlurmCodeGen

Args:
slurm_job_id: The Slurm job ID, i.e. SLURM_JOB_ID
"""
super().__init__()
self._slurm_job_id = slurm_job_id

def add_prologue(self, job_id: int) -> None:
assert not self._has_prologue, 'add_prologue() called twice?'
self._has_prologue = True
self.job_id = job_id

self._add_common_imports()

self._code.append(
textwrap.dedent("""\
import colorama
import copy
import json
import multiprocessing
import signal
import threading
from sky.backends import backend_utils
"""))
self._add_skylet_imports()

self._add_constants()

self._add_logging_functions()

self._code.append(
textwrap.dedent(f"""\
def _cancel_slurm_job_steps():
slurm_job_id = {self._slurm_job_id!r}
assert slurm_job_id is not None, 'SLURM_JOB_ID is not set'
try:
# Query steps for this job: squeue -s -j JOBID -h -o "%i %j"
# Output format: "JOBID.STEPID STEPNAME"
# TODO(kevin): This assumes that compute node is able
# to run client commands against the controller.
# Validate this assumption.
result = subprocess.run(
['squeue', '-s', '-j', slurm_job_id, '-h', '-o', '%i %j'],
capture_output=True, text=True, check=False)
for line in result.stdout.strip().split('\\n'):
if not line:
continue
parts = line.split()
assert len(parts) >= 2, 'Expected at least 2 parts'
step_id, step_name = parts[0], parts[1]
if step_name == f'sky-{self.job_id}':
subprocess.run(['scancel', step_id],
check=False, capture_output=True)
except Exception as e:
print(f'Error in _cancel_slurm_job_steps: {{e}}', flush=True)
pass

def _slurm_cleanup_handler(signum, _frame):
_cancel_slurm_job_steps()
# Re-raise to let default handler terminate.
signal.signal(signum, signal.SIG_DFL)
os.kill(os.getpid(), signum)

signal.signal(signal.SIGTERM, _slurm_cleanup_handler)
"""))

self._code += [
'autostop_lib.set_last_active_time_to_now()',
f'job_lib.set_status({job_id!r}, job_lib.JobStatus.PENDING)',
]

self._setup_cmd: Optional[str] = None
self._setup_envs: Optional[Dict[str, str]] = None
self._setup_log_dir: Optional[str] = None
self._setup_num_nodes: Optional[int] = None

def add_setup(
self,
num_nodes: int,
resources_dict: Dict[str, float],
stable_cluster_internal_ips: List[str],
env_vars: Dict[str, str],
log_dir: str,
setup_cmd: Optional[str] = None,
) -> None:
assert self._has_prologue, ('Call add_prologue() before add_setup().')
self._has_setup = True
self._cluster_num_nodes = len(stable_cluster_internal_ips)
self._stable_cluster_ips = stable_cluster_internal_ips

self._add_waiting_for_resources_msg(num_nodes)

# Store setup information for use in add_task().
if setup_cmd is not None:
setup_envs = env_vars.copy()
setup_envs[constants.SKYPILOT_NUM_NODES] = str(num_nodes)
self._setup_cmd = setup_cmd
self._setup_envs = setup_envs
self._setup_log_dir = log_dir
self._setup_num_nodes = num_nodes

def add_task(
self,
num_nodes: int,
bash_script: Optional[str],
task_name: Optional[str],
resources_dict: Dict[str, float],
log_dir: str,
env_vars: Optional[Dict[str, str]] = None,
) -> None:
"""Generates code for invoking a bash command
using srun within the sbatch allocation.
"""
assert self._has_setup, 'Call add_setup() before add_task().'
env_vars = env_vars or {}
task_name = task_name if task_name is not None else 'task'

acc_name, acc_count = self._get_accelerator_details(resources_dict)
num_gpus = 0
if (acc_name is not None and
not accelerator_registry.is_schedulable_non_gpu_accelerator(
acc_name)):
num_gpus = int(math.ceil(acc_count))

# Slurm does not support fractional CPUs.
task_cpu_demand = int(math.ceil(resources_dict.pop('CPU')))

sky_env_vars_dict_str = [
textwrap.dedent(f"""\
sky_env_vars_dict = {{}}
sky_env_vars_dict['SKYPILOT_INTERNAL_JOB_ID'] = {self.job_id}
""")
]

if env_vars:
sky_env_vars_dict_str.extend(f'sky_env_vars_dict[{k!r}] = {v!r}'
for k, v in env_vars.items())
sky_env_vars_dict_str = '\n'.join(sky_env_vars_dict_str)

rclone_flush_script = self._get_rclone_flush_script()
streaming_msg = self._get_job_started_msg()
has_setup_cmd = self._setup_cmd is not None

self._code += [
sky_env_vars_dict_str,
textwrap.dedent(f"""\
script = {bash_script!r}
if script is None:
script = ''
rclone_flush_script = {rclone_flush_script!r}

if script or {has_setup_cmd!r}:
script += rclone_flush_script
sky_env_vars_dict['{constants.SKYPILOT_NUM_GPUS_PER_NODE}'] = {num_gpus}

# Signal files for setup/run synchronization:
# 1. alloc_signal_file: srun has acquired allocation
# 2. setup_done_signal_file: Driver has finished setup, run can proceed
#
# Signal files are stored in home directory, which is
# assumed to be on a shared NFS mount accessible by all nodes.
# To support clusters with non-NFS home directories, we would
# need to let users specify an NFS-backed "working directory"
# or use a different coordination mechanism.
alloc_signal_file = f'~/.sky_alloc_{self._slurm_job_id}_{self.job_id}'
alloc_signal_file = os.path.expanduser(alloc_signal_file)
setup_done_signal_file = f'~/.sky_setup_done_{self._slurm_job_id}_{self.job_id}'
setup_done_signal_file = os.path.expanduser(setup_done_signal_file)

# Start exclusive srun in a thread to reserve allocation (similar to ray.get(pg.ready()))
gpu_arg = f'--gpus-per-node={num_gpus}' if {num_gpus} > 0 else ''

def build_task_runner_cmd(user_script, extra_flags, log_dir, env_vars_dict,
task_name=None, is_setup=False,
alloc_signal=None, setup_done_signal=None):
env_vars_json = json.dumps(env_vars_dict)

log_dir = shlex.quote(log_dir)
env_vars = shlex.quote(env_vars_json)
cluster_ips = shlex.quote(",".join({self._stable_cluster_ips!r}))

runner_args = f'--log-dir={{log_dir}} --env-vars={{env_vars}} --cluster-num-nodes={self._cluster_num_nodes} --cluster-ips={{cluster_ips}}'

if task_name is not None:
runner_args += f' --task-name={{shlex.quote(task_name)}}'

if is_setup:
runner_args += ' --is-setup'

if alloc_signal is not None:
runner_args += f' --alloc-signal-file={{shlex.quote(alloc_signal)}}'

if setup_done_signal is not None:
runner_args += f' --setup-done-signal-file={{shlex.quote(setup_done_signal)}}'

script_path = None
prefix = 'sky_setup_' if is_setup else 'sky_task_'
if backend_utils.is_command_length_over_limit(user_script):
with tempfile.NamedTemporaryFile('w', prefix=prefix, suffix='.sh', delete=False) as f:
f.write(user_script)
script_path = f.name
runner_args += f' --script-path={{shlex.quote(script_path)}}'
else:
runner_args += f' --script={{shlex.quote(user_script)}}'

# Use /usr/bin/env explicitly to work around a Slurm quirk where
# srun's execvp() doesn't check execute permissions, failing when
# $HOME/.local/bin/env (non-executable, from uv installation)
# shadows /usr/bin/env.
job_suffix = '-setup' if is_setup else ''
srun_cmd = (
f'srun --export=ALL --quiet --unbuffered --kill-on-bad-exit --jobid={self._slurm_job_id} '
f'--job-name=sky-{self.job_id}{{job_suffix}} --ntasks-per-node=1 {{extra_flags}} '
f'{{constants.SKY_SLURM_PYTHON_CMD}} -m sky.skylet.executor.slurm {{runner_args}}'
)
return srun_cmd, script_path

def run_thread_func():
# This blocks until Slurm allocates resources (--exclusive)
# --mem=0 to match RayCodeGen's behavior where we don't explicitly request memory.
run_flags = f'--nodes={num_nodes} --cpus-per-task={task_cpu_demand} --mem=0 {{gpu_arg}} --exclusive'
srun_cmd, task_script_path = build_task_runner_cmd(
script, run_flags, {log_dir!r}, sky_env_vars_dict,
task_name={task_name!r},
alloc_signal=alloc_signal_file,
setup_done_signal=setup_done_signal_file
)

proc = subprocess.Popen(srun_cmd, shell=True,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True)
for line in proc.stdout:
print(line, end='', flush=True)
proc.wait()

if task_script_path is not None:
os.remove(task_script_path)
return {{'return_code': proc.returncode, 'pid': proc.pid}}

run_thread_result = {{'result': None}}
def run_thread_wrapper():
run_thread_result['result'] = run_thread_func()

run_thread = threading.Thread(target=run_thread_wrapper)
run_thread.start()

# Wait for allocation signal from inside srun
while not os.path.exists(alloc_signal_file):
if not run_thread.is_alive():
# srun failed before creating the signal file.
run_thread.join()
result = run_thread_result['result']
returncode = int(result.get('return_code', 1))
pid = result.get('pid', os.getpid())
msg = f'ERROR: {colorama.Fore.RED}Job {self.job_id}\\'s setup failed with return code {{returncode}} (pid={{pid}}).'
msg += f' See error logs above for more details.{colorama.Style.RESET_ALL}'
print(msg, flush=True)
returncodes = [returncode]
job_lib.set_status({self.job_id!r}, job_lib.JobStatus.FAILED_SETUP)
sys.exit(1)
time.sleep(0.1)

print({streaming_msg!r}, flush=True)

if {has_setup_cmd!r}:
job_lib.set_status({self.job_id!r}, job_lib.JobStatus.SETTING_UP)

# The schedule_step should be called after the job status is set to
# non-PENDING, otherwise, the scheduler will think the current job
# is not submitted yet, and skip the scheduling step.
job_lib.scheduler.schedule_step()

# --overlap because the srun for the run section has already secured the
# allocation; without it this srun would block and deadlock.
setup_flags = f'--overlap --nodes={self._setup_num_nodes}'
setup_srun, setup_script_path = build_task_runner_cmd(
{self._setup_cmd!r}, setup_flags, {self._setup_log_dir!r}, {self._setup_envs!r},
is_setup=True
)

# Run setup srun directly, streaming output to driver stdout
setup_proc = subprocess.Popen(setup_srun, shell=True,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True)
for line in setup_proc.stdout:
print(line, end='', flush=True)
setup_proc.wait()

if setup_script_path is not None:
os.remove(setup_script_path)

setup_returncode = setup_proc.returncode
if setup_returncode != 0:
setup_pid = setup_proc.pid
msg = f'ERROR: {colorama.Fore.RED}Job {self.job_id}\\'s setup failed with return code {{setup_returncode}} (pid={{setup_pid}}).'
msg += f' See error logs above for more details.{colorama.Style.RESET_ALL}'
print(msg, flush=True)
job_lib.set_status({self.job_id!r}, job_lib.JobStatus.FAILED_SETUP)
# Cancel the srun spawned by run_thread_func.
_cancel_slurm_job_steps()
sys.exit(1)

job_lib.set_job_started({self.job_id!r})
if not {has_setup_cmd!r}:
# Need to call schedule_step() to make sure the scheduler
# schedule the next pending job.
job_lib.scheduler.schedule_step()

# Signal run thread to proceed.
pathlib.Path(setup_done_signal_file).touch()

# Wait for run thread to complete.
run_thread.join()
result = run_thread_result['result']

# Cleanup signal files
if os.path.exists(alloc_signal_file):
os.remove(alloc_signal_file)
if os.path.exists(setup_done_signal_file):
os.remove(setup_done_signal_file)

returncodes = [int(result.get('return_code', 1))]
else:
returncodes = [0]
"""),
]
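The generated driver above reserves the allocation with an exclusive srun running in a background thread, then runs the optional setup step with --overlap, using two signal files on the (assumed NFS-shared) home directory to sequence the two. The following self-contained sketch shows just that handshake, with plain callables standing in for the real srun invocations; it illustrates the synchronization pattern, not the code SkyPilot emits:

import os
import pathlib
import tempfile
import threading
import time


def coordinate(run_workload, run_setup=None, poll_interval=0.1):
    """Minimal sketch of the setup/run handshake used by the driver above."""
    tmpdir = tempfile.mkdtemp(prefix='sky_signal_')
    alloc_file = os.path.join(tmpdir, 'alloc')       # workload holds resources
    done_file = os.path.join(tmpdir, 'setup_done')   # driver finished setup
    result = {}

    def workload():
        pathlib.Path(alloc_file).touch()             # signal: allocation acquired
        while not os.path.exists(done_file):         # wait for setup to complete
            time.sleep(poll_interval)
        result['returncode'] = run_workload()

    thread = threading.Thread(target=workload)
    thread.start()
    while not os.path.exists(alloc_file):            # wait for allocation signal
        if not thread.is_alive():
            return 1                                 # workload died before signaling
        time.sleep(poll_interval)
    if run_setup is not None and run_setup() != 0:
        return 1                                     # setup failed; the real driver also cancels job steps
    pathlib.Path(done_file).touch()                  # unblock the workload
    thread.join()
    return result.get('returncode', 1)


if __name__ == '__main__':
    print(coordinate(lambda: 0, run_setup=lambda: 0))  # -> 0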

+ 0
- 3
sky/catalog/__init__.py View File

@@ -127,12 +127,9 @@ def list_accelerator_realtime(
case_sensitive: bool = True,
) -> Tuple[Dict[str, List[int]], Dict[str, int], Dict[str, int]]:
"""Lists all accelerators offered by Sky with their realtime availability.

Realtime availability is the total number of accelerators in the cluster
and the number of accelerators available at the time of the call.

Used for fixed size cluster settings, such as Kubernetes.

Returns:
A tuple of three dictionaries mapping canonical accelerator names to:
- A list of available counts. (e.g., [1, 2, 4])


+ 12
- 4
sky/catalog/kubernetes_catalog.py View File

@@ -204,6 +204,9 @@ def _list_accelerators(
min_quantity_filter = quantity_filter if quantity_filter else 1

for node in nodes:
# Check if node is ready
node_is_ready = node.is_ready()

for key in keys:
if key in node.metadata.labels:
accelerator_name = lf.get_accelerator_from_label_value(
@@ -260,6 +263,15 @@ def _list_accelerators(
total_accelerators_capacity[
accelerator_name] += quantized_count

# Initialize the total_accelerators_available to make sure the
# key exists in the dictionary.
total_accelerators_available[accelerator_name] = (
total_accelerators_available.get(accelerator_name, 0))

# Skip availability counting for not-ready nodes
if not node_is_ready:
continue

if error_on_get_allocated_gpu_qty_by_node:
# If we can't get the allocated GPU quantity by each node,
# we can't get the GPU usage.
@@ -268,10 +280,6 @@ def _list_accelerators(

allocated_qty = allocated_qty_by_node[node.metadata.name]
accelerators_available = accelerator_count - allocated_qty
# Initialize the total_accelerators_available to make sure the
# key exists in the dictionary.
total_accelerators_available[accelerator_name] = (
total_accelerators_available.get(accelerator_name, 0))

if accelerators_available >= min_quantity_filter:
quantized_availability = min_quantity_filter * (


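The reordering above makes sure a GPU type whose nodes are all NotReady still appears in the result with zero availability instead of being dropped. A small sketch of that aggregation, using plain dicts in place of the real Kubernetes node objects (the field names here are illustrative):

from typing import Dict, List


def aggregate_available(nodes: List[dict]) -> Dict[str, int]:
    """Each node dict: {'gpu': str, 'total': int, 'allocated': int, 'ready': bool}."""
    available: Dict[str, int] = {}
    for node in nodes:
        # Initialize the key first so GPU types on not-ready nodes still show up.
        available[node['gpu']] = available.get(node['gpu'], 0)
        if not node['ready']:
            continue  # Skip availability counting for not-ready nodes.
        available[node['gpu']] += node['total'] - node['allocated']
    return available


print(aggregate_available([
    {'gpu': 'H100', 'total': 8, 'allocated': 0, 'ready': False},
    {'gpu': 'A100', 'total': 4, 'allocated': 1, 'ready': True},
]))
# -> {'H100': 0, 'A100': 3}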
+ 243
- 0
sky/catalog/slurm_catalog.py View File

@@ -0,0 +1,243 @@
"""Slurm Catalog."""

import collections
import re
from typing import Dict, List, Optional, Set, Tuple

from sky import check as sky_check
from sky import clouds as sky_clouds
from sky import sky_logging
from sky.catalog import common
from sky.clouds import cloud
from sky.provision.slurm import utils as slurm_utils
from sky.utils import resources_utils

logger = sky_logging.init_logger(__name__)

_DEFAULT_NUM_VCPUS = 2
_DEFAULT_MEMORY_CPU_RATIO = 1


def instance_type_exists(instance_type: str) -> bool:
"""Check if the given instance type is valid for Slurm."""
return slurm_utils.SlurmInstanceType.is_valid_instance_type(instance_type)


def get_default_instance_type(cpus: Optional[str] = None,
memory: Optional[str] = None,
disk_tier: Optional[
resources_utils.DiskTier] = None,
region: Optional[str] = None,
zone: Optional[str] = None) -> Optional[str]:
# Delete unused parameters.
del disk_tier, region, zone

# Slurm provisions resources via --cpus-per-task and --mem.
instance_cpus = float(
cpus.strip('+')) if cpus is not None else _DEFAULT_NUM_VCPUS
if memory is not None:
if memory.endswith('+'):
instance_mem = float(memory[:-1])
elif memory.endswith('x'):
instance_mem = float(memory[:-1]) * instance_cpus
else:
instance_mem = float(memory)
else:
instance_mem = instance_cpus * _DEFAULT_MEMORY_CPU_RATIO
virtual_instance_type = slurm_utils.SlurmInstanceType(
instance_cpus, instance_mem).name
return virtual_instance_type


def list_accelerators(
gpus_only: bool,
name_filter: Optional[str],
region_filter: Optional[str],
quantity_filter: Optional[int],
case_sensitive: bool = True,
all_regions: bool = False,
require_price: bool = True) -> Dict[str, List[common.InstanceTypeInfo]]:
"""List accelerators in Slurm clusters.

Returns a dictionary mapping GPU type to a list of InstanceTypeInfo objects.
"""
return list_accelerators_realtime(gpus_only, name_filter, region_filter,
quantity_filter, case_sensitive,
all_regions, require_price)[0]


def list_accelerators_realtime(
gpus_only: bool = True,
name_filter: Optional[str] = None,
region_filter: Optional[str] = None,
quantity_filter: Optional[int] = None,
case_sensitive: bool = True,
all_regions: bool = False,
require_price: bool = False,
) -> Tuple[Dict[str, List[common.InstanceTypeInfo]], Dict[str, int], Dict[str,
int]]:
"""Fetches real-time accelerator information from the Slurm cluster.

Uses the `slurm_utils.slurm_node_info` helper function.

Args:
gpus_only: If True, only return GPU accelerators.
name_filter: Regex filter for accelerator names (e.g., 'V100', 'gpu').
region_filter: Optional filter for Slurm partitions.
quantity_filter: Minimum number of accelerators required per node.
case_sensitive: Whether name_filter is case-sensitive.
all_regions: Unused in Slurm context.
require_price: Unused in Slurm context.

Returns:
A tuple of three dictionaries:
- qtys_map: Maps GPU type to a sorted list of InstanceTypeInfo objects
for the unique requestable counts found per node.
- total_capacity: Maps GPU type to total count across all nodes.
- total_available: Maps GPU type to total free count across all nodes.
"""
del gpus_only, all_regions, require_price

enabled_clouds = sky_check.get_cached_enabled_clouds_or_refresh(
cloud.CloudCapability.COMPUTE)
if not sky_clouds.cloud_in_iterable(sky_clouds.Slurm(), enabled_clouds):
return {}, {}, {}

if region_filter is None:
# Get the first available cluster as default
all_clusters = slurm_utils.get_all_slurm_cluster_names()
if not all_clusters:
return {}, {}, {}
slurm_cluster = all_clusters[0]
else:
slurm_cluster = region_filter

partition_filter = slurm_utils.get_cluster_default_partition(slurm_cluster)

# Call the helper function to get node info
slurm_nodes_info = slurm_utils.slurm_node_info(
slurm_cluster_name=slurm_cluster)

if not slurm_nodes_info:
# Customize error message based on filters
err_msg = 'No matching GPU nodes found in the Slurm cluster'
filters_applied = []
if name_filter:
filters_applied.append(f'gpu_name={name_filter!r}')
if quantity_filter:
filters_applied.append(f'quantity>={quantity_filter}')
if region_filter:
filters_applied.append(f'cluster={region_filter!r}')
if filters_applied:
err_msg += f' with filters ({", ".join(filters_applied)})'
err_msg += '.'
logger.error(
err_msg) # Log as error as it indicates no usable resources found
raise ValueError(err_msg)

# Aggregate results into the required format
qtys_map: Dict[str,
Set[common.InstanceTypeInfo]] = collections.defaultdict(set)
total_capacity: Dict[str, int] = collections.defaultdict(int)
total_available: Dict[str, int] = collections.defaultdict(int)

for node_info in slurm_nodes_info:
gpu_type = node_info['gpu_type']
node_total_gpus = node_info['total_gpus']
node_free_gpus = node_info['free_gpus']
partition = node_info['partition']

# Apply name filter to the determined GPU type
regex_flags = 0 if case_sensitive else re.IGNORECASE
if name_filter and not re.match(
name_filter, gpu_type, flags=regex_flags):
continue

# Apply quantity filter (total GPUs on node must meet this)
if quantity_filter and node_total_gpus < quantity_filter:
continue

# Apply partition filter if specified
# TODO(zhwu): when a node is in multiple partitions, the partition
# mapping from node to partition does not work.
# if partition_filter and partition != partition_filter:
# continue

# Create InstanceTypeInfo objects for various GPU counts
# Similar to Kubernetes, generate powers of 2 up to node_total_gpus
if node_total_gpus > 0:
count = 1
while count <= node_total_gpus:
instance_info = common.InstanceTypeInfo(
instance_type=None, # Slurm doesn't have instance types
accelerator_name=gpu_type,
accelerator_count=count,
cpu_count=node_info['vcpu_count'],
memory=node_info['memory_gb'],
price=0.0, # Slurm doesn't have price info
region=partition, # Use partition as region
cloud='slurm', # Specify cloud as 'slurm'
device_memory=0.0, # No GPU memory info from Slurm
spot_price=0.0, # Slurm doesn't have spot pricing
)
qtys_map[gpu_type].add(instance_info)
count *= 2

# Add the actual total if it's not already included
# (e.g., if a node has 12 GPUs, include counts 1, 2, 4, 8, 12)
if count // 2 != node_total_gpus:
instance_info = common.InstanceTypeInfo(
instance_type=None,
accelerator_name=gpu_type,
accelerator_count=node_total_gpus,
cpu_count=node_info['vcpu_count'],
memory=node_info['memory_gb'],
price=0.0,
region=partition,
cloud='slurm',
device_memory=0.0,
spot_price=0.0,
)
qtys_map[gpu_type].add(instance_info)

# Map of GPU type -> total count across all matched nodes
total_capacity[gpu_type] += node_total_gpus

# Map of GPU type -> total *free* count across all matched nodes
total_available[gpu_type] += node_free_gpus

# Check if any GPUs were found after applying filters
if not total_capacity:
err_msg = 'No matching GPU nodes found in the Slurm cluster'
filters_applied = []
if name_filter:
filters_applied.append(f'gpu_name={name_filter!r}')
if quantity_filter:
filters_applied.append(f'quantity>={quantity_filter}')
if partition_filter:
filters_applied.append(f'partition={partition_filter!r}')
if filters_applied:
err_msg += f' with filters ({", ".join(filters_applied)})'
err_msg += '.'
logger.error(err_msg)
raise ValueError(err_msg)

# Convert sets of InstanceTypeInfo to sorted lists
final_qtys_map = {
gpu: sorted(list(instances), key=lambda x: x.accelerator_count)
for gpu, instances in qtys_map.items()
}

logger.debug(f'Aggregated Slurm GPU Info: '
f'qtys={final_qtys_map}, '
f'capacity={dict(total_capacity)}, '
f'available={dict(total_available)}')

return final_qtys_map, dict(total_capacity), dict(total_available)


def validate_region_zone(
region_name: Optional[str],
zone_name: Optional[str],
) -> Tuple[Optional[str], Optional[str]]:
return (region_name, zone_name)

+ 14
- 3
sky/check.py View File

@@ -586,6 +586,9 @@ def _format_context_details(cloud: Union[str, sky_clouds.Cloud],
if isinstance(cloud_type, sky_clouds.SSH):
# Get the cluster names by reading from the node pools file
contexts = sky_clouds.SSH.get_ssh_node_pool_contexts()
elif isinstance(cloud_type, sky_clouds.Slurm):
# Get the cluster names from the Slurm config
contexts = sky_clouds.Slurm.existing_allowed_clusters()
else:
assert isinstance(cloud_type, sky_clouds.Kubernetes)
contexts = sky_clouds.Kubernetes.existing_allowed_contexts()
@@ -657,8 +660,12 @@ def _format_context_details(cloud: Union[str, sky_clouds.Cloud],
'to set up.'))
contexts_formatted.append(
f'\n {symbol}{cleaned_context}{text_suffix}')
identity_str = ('SSH Node Pools' if isinstance(cloud_type, sky_clouds.SSH)
else 'Allowed contexts')
if isinstance(cloud_type, sky_clouds.SSH):
identity_str = 'SSH Node Pools'
elif isinstance(cloud_type, sky_clouds.Slurm):
identity_str = 'Allowed clusters'
else:
identity_str = 'Allowed contexts'
return f'\n {identity_str}:{"".join(contexts_formatted)}'


@@ -677,7 +684,11 @@ def _format_enabled_cloud(cloud_name: str,
cloud_and_capabilities = f'{cloud_name} [{", ".join(capabilities)}]'
title = _green_color(cloud_and_capabilities)

if cloud_name in [repr(sky_clouds.Kubernetes()), repr(sky_clouds.SSH())]:
if cloud_name in [
repr(sky_clouds.Kubernetes()),
repr(sky_clouds.SSH()),
repr(sky_clouds.Slurm())
]:
return (f'{title}' + _format_context_details(
cloud_name, show_details=False, ctx2text=ctx2text))
return _green_color(cloud_and_capabilities)


+ 329
- 22
sky/client/cli/command.py View File

@@ -189,6 +189,7 @@ def _get_cluster_records_and_set_ssh_config(
# can still exist in the record, and we check for credentials to avoid
# updating the SSH config for non-existent clusters.
credentials = record['credentials']
ips = handle.cached_external_ips
if isinstance(handle.launched_resources.cloud, clouds.Kubernetes):
# Replace the proxy command to proxy through the SkyPilot API
# server with websocket.
@@ -217,10 +218,44 @@ def _get_cluster_records_and_set_ssh_config(
f'{server_common.get_server_url()} '
f'{handle.cluster_name}\"')
credentials['ssh_proxy_command'] = proxy_command
elif isinstance(handle.launched_resources.cloud, clouds.Slurm):
# TODO(kevin): This is a temporary workaround; ideally we want to
# get a shell through srun --pty bash on the existing sbatch job.

# Proxy through the controller/login node to reach the worker node.
if (handle.cached_internal_ips is None or
not handle.cached_internal_ips):
logger.debug(
f'Cluster {name} does not have cached internal IPs. '
'Skipping SSH config update.')
cluster_utils.SSHConfigHelper.remove_cluster(name)
continue

escaped_key_path = shlex.quote(
cluster_utils.SSHConfigHelper.generate_local_key_file(
handle.cluster_name, credentials))
controller_host = handle.cached_external_ips[0]

# Build jump proxy: ssh to worker via controller/login node
proxy_command = (f'ssh -tt -i {escaped_key_path} '
'-o StrictHostKeyChecking=no '
'-o UserKnownHostsFile=/dev/null '
'-o IdentitiesOnly=yes '
'-W %h:%p '
f'{handle.ssh_user}@{controller_host}')
original_proxy = credentials.get('ssh_proxy_command')
if original_proxy:
proxy_command += (
f' -o ProxyCommand={shlex.quote(original_proxy)}')

credentials['ssh_proxy_command'] = proxy_command

# For Slurm, use the worker's internal IP as the SSH target
ips = handle.cached_internal_ips

cluster_utils.SSHConfigHelper.add_cluster(
handle.cluster_name,
handle.cached_external_ips,
ips,
credentials,
handle.cached_external_ssh_ports,
handle.docker_user,
@@ -832,7 +867,19 @@ class _NaturalOrderGroup(click.Group):
"""

def list_commands(self, ctx): # pylint: disable=unused-argument
return self.commands.keys()
# Preserve definition order but hide aliases (same command object) and
# commands explicitly marked as hidden.
seen_commands = set()
names = []
for name, command in self.commands.items():
if getattr(command, 'hidden', False):
continue
command_id = id(command)
if command_id in seen_commands:
continue
seen_commands.add(command_id)
names.append(name)
return names

@usage_lib.entrypoint('sky.cli', fallback=True)
def invoke(self, ctx):
@@ -3535,6 +3582,10 @@ def show_gpus(
maximum quantities of the GPU available on a single node and the real-time
availability of the GPU across all nodes in the Kubernetes cluster.

If ``--cloud slurm`` is specified, it will show the maximum quantities of
the GPU available on a single node and the real-time availability of the
GPU across all nodes in the Slurm cluster.

Definitions of certain fields:

* ``DEVICE_MEM``: Memory of a single device; does not depend on the device
@@ -3590,6 +3641,8 @@ def show_gpus(
cloud_is_kubernetes = isinstance(
cloud_obj, clouds.Kubernetes) and not isinstance(cloud_obj, clouds.SSH)
cloud_is_ssh = isinstance(cloud_obj, clouds.SSH)
cloud_is_slurm = isinstance(cloud_obj, clouds.Slurm)

# TODO(romilb): We should move this to the backend.
kubernetes_autoscaling = skypilot_config.get_effective_region_config(
cloud='kubernetes',
@@ -3598,6 +3651,7 @@ def show_gpus(
default_value=None) is not None
kubernetes_is_enabled = clouds.Kubernetes.canonical_name() in enabled_clouds
ssh_is_enabled = clouds.SSH.canonical_name() in enabled_clouds
slurm_is_enabled = clouds.Slurm.canonical_name() in enabled_clouds
query_k8s_realtime_gpu = (kubernetes_is_enabled and
(cloud_name is None or cloud_is_kubernetes))
query_ssh_realtime_gpu = (ssh_is_enabled and
@@ -3657,8 +3711,9 @@ def show_gpus(
raise ValueError(full_err_msg)
no_permissions_str = '<no permissions>'
realtime_gpu_infos = []
# Stores per-GPU totals as [ready_capacity, available, not_ready].
total_gpu_info: Dict[str, List[int]] = collections.defaultdict(
lambda: [0, 0])
lambda: [0, 0, 0])
all_nodes_info = []

# display an aggregated table for all contexts
@@ -3669,6 +3724,33 @@ def show_gpus(

num_filtered_contexts = 0

def _count_not_ready_gpus(
nodes_info: Optional['models.KubernetesNodesInfo']
) -> Dict[str, int]:
"""Return counts of GPUs on not ready nodes keyed by GPU type."""
not_ready_counts: Dict[str, int] = collections.defaultdict(int)
if nodes_info is None:
return not_ready_counts

node_info_dict = getattr(nodes_info, 'node_info_dict', {}) or {}
for node_info in node_info_dict.values():
accelerator_type = getattr(node_info, 'accelerator_type', None)
if not accelerator_type:
continue

total_info = getattr(node_info, 'total', {})
accelerator_count = 0
if isinstance(total_info, dict):
accelerator_count = int(
total_info.get('accelerator_count', 0))
if accelerator_count <= 0:
continue

node_is_ready = getattr(node_info, 'is_ready', True)
if not node_is_ready:
not_ready_counts[accelerator_type] += accelerator_count
return not_ready_counts

if realtime_gpu_availability_lists:
for (ctx, availability_list) in realtime_gpu_availability_lists:
if not _filter_ctx(ctx):
@@ -3678,6 +3760,12 @@ def show_gpus(
else:
display_ctx = ctx
num_filtered_contexts += 1
# Collect node info for this context before building tables so
# we can exclude GPUs on not ready nodes from the totals.
nodes_info = sdk.stream_and_get(
sdk.kubernetes_node_info(context=ctx))
context_not_ready_counts = _count_not_ready_gpus(nodes_info)

realtime_gpu_table = log_utils.create_table(
['GPU', qty_header, 'UTILIZATION'])
for realtime_gpu_availability in sorted(availability_list):
@@ -3686,24 +3774,116 @@ def show_gpus(
available_qty = (gpu_availability.available
if gpu_availability.available != -1 else
no_permissions_str)
# Exclude GPUs on not ready nodes from capacity counts.
not_ready_count = min(
context_not_ready_counts.get(gpu_availability.gpu, 0),
gpu_availability.capacity)
# Ensure capacity is never below the reported available
# quantity (if available is unknown, treat as 0 for totals).
available_for_totals = max(
gpu_availability.available
if gpu_availability.available != -1 else 0, 0)
effective_capacity = max(
gpu_availability.capacity - not_ready_count,
available_for_totals)
utilization = (
f'{available_qty} of {effective_capacity} free')
if not_ready_count > 0:
utilization += f' ({not_ready_count} not ready)'
realtime_gpu_table.add_row([
gpu_availability.gpu,
_list_to_str(gpu_availability.counts),
f'{available_qty} of {gpu_availability.capacity} free',
utilization,
])
gpu = gpu_availability.gpu
capacity = gpu_availability.capacity
# we want total, so skip permission denied.
available = max(gpu_availability.available, 0)
if capacity > 0:
total_gpu_info[gpu][0] += capacity
total_gpu_info[gpu][1] += available
if effective_capacity > 0 or not_ready_count > 0:
total_gpu_info[gpu][0] += effective_capacity
total_gpu_info[gpu][1] += available_for_totals
total_gpu_info[gpu][2] += not_ready_count
realtime_gpu_infos.append((display_ctx, realtime_gpu_table))
# Collect node info for this context
nodes_info = sdk.stream_and_get(
sdk.kubernetes_node_info(context=ctx))
all_nodes_info.append((display_ctx, nodes_info))
if num_filtered_contexts > 1:
total_realtime_gpu_table = log_utils.create_table(
['GPU', 'UTILIZATION'])
for gpu, stats in total_gpu_info.items():
not_ready = stats[2]
utilization = f'{stats[1]} of {stats[0]} free'
if not_ready > 0:
utilization += f' ({not_ready} not ready)'
total_realtime_gpu_table.add_row([gpu, utilization])
else:
total_realtime_gpu_table = None

return realtime_gpu_infos, total_realtime_gpu_table, all_nodes_info

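The totals arithmetic above subtracts GPUs on NotReady nodes from the capacity, but clamps so that capacity never drops below the reported available count (with an unknown, permission-denied availability treated as 0). A worked sketch of that adjustment:

def effective_capacity(capacity: int, available: int, not_ready: int) -> tuple:
    """Returns (effective_capacity, available_for_totals, not_ready_counted)."""
    not_ready = min(not_ready, capacity)
    # -1 means "no permission to query"; treat as 0 for totals.
    available_for_totals = max(available if available != -1 else 0, 0)
    cap = max(capacity - not_ready, available_for_totals)
    return cap, available_for_totals, not_ready


print(effective_capacity(capacity=8, available=2, not_ready=8))   # -> (2, 2, 8)
print(effective_capacity(capacity=8, available=-1, not_ready=4))  # -> (4, 0, 4)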
def _get_slurm_realtime_gpu_tables(
name_filter: Optional[str] = None,
quantity_filter: Optional[int] = None
) -> Tuple[List[Tuple[str, 'prettytable.PrettyTable']],
Optional['prettytable.PrettyTable']]:
"""Get Slurm GPU availability tables.

Args:
name_filter: Filter GPUs by name.
quantity_filter: Filter GPUs by quantity.

Returns:
A tuple of (realtime_gpu_infos, total_realtime_gpu_table).
"""
if quantity_filter:
qty_header = 'QTY_FILTER'
else:
qty_header = 'REQUESTABLE_QTY_PER_NODE'

realtime_gpu_availability_lists = sdk.stream_and_get(
sdk.realtime_slurm_gpu_availability(
name_filter=name_filter, quantity_filter=quantity_filter))
if not realtime_gpu_availability_lists:
err_msg = 'No GPUs found in any Slurm partition. '
debug_msg = 'To further debug, run: sky check slurm '
if name_filter is not None:
gpu_info_msg = f' {name_filter!r}'
if quantity_filter is not None:
gpu_info_msg += (' with requested quantity'
f' {quantity_filter}')
err_msg = (f'Resources{gpu_info_msg} not found '
'in any Slurm partition. ')
debug_msg = ('To show available accelerators on Slurm,'
' run: sky show-gpus --cloud slurm ')
raise ValueError(err_msg + debug_msg)

realtime_gpu_infos = []
total_gpu_info: Dict[str, List[int]] = collections.defaultdict(
lambda: [0, 0])

for (slurm_cluster,
availability_list) in realtime_gpu_availability_lists:
realtime_gpu_table = log_utils.create_table(
['GPU', qty_header, 'UTILIZATION'])
for realtime_gpu_availability in sorted(availability_list):
gpu_availability = models.RealtimeGpuAvailability(
*realtime_gpu_availability)
# Use the counts directly from the backend, which are already
# generated in powers of 2 (plus any actual maximums)
requestable_quantities = gpu_availability.counts
realtime_gpu_table.add_row([
gpu_availability.gpu,
_list_to_str(requestable_quantities),
(f'{gpu_availability.available} of '
f'{gpu_availability.capacity} free'),
])
gpu = gpu_availability.gpu
capacity = gpu_availability.capacity
available = gpu_availability.available
if capacity > 0:
total_gpu_info[gpu][0] += capacity
total_gpu_info[gpu][1] += available
realtime_gpu_infos.append((slurm_cluster, realtime_gpu_table))

# display an aggregated table for all partitions
# if there is more than one partition with GPUs
if len(realtime_gpu_infos) > 1:
total_realtime_gpu_table = log_utils.create_table(
['GPU', 'UTILIZATION'])
for gpu, stats in total_gpu_info.items():
@@ -3712,7 +3892,7 @@ def show_gpus(
else:
total_realtime_gpu_table = None

return realtime_gpu_infos, total_realtime_gpu_table, all_nodes_info
return realtime_gpu_infos, total_realtime_gpu_table

def _format_kubernetes_node_info_combined(
contexts_info: List[Tuple[str, 'models.KubernetesNodesInfo']],
@@ -3736,11 +3916,16 @@ def show_gpus(
acc_type = node_info.accelerator_type
if acc_type is None:
acc_type = '-'
node_table.add_row([
context_name, node_name, acc_type,
f'{available} of {node_info.total["accelerator_count"]} '
'free'
])
utilization_str = (
f'{available} of '
f'{node_info.total["accelerator_count"]} free')
# Check if node is ready (defaults to True for backward
# compatibility with older server versions)
node_is_ready = getattr(node_info, 'is_ready', True)
if not node_is_ready:
utilization_str += ' (Node NotReady)'
node_table.add_row(
[context_name, node_name, acc_type, utilization_str])

k8s_per_node_acc_message = (f'{cloud_str} per-node GPU availability')
if hints:
@@ -3751,6 +3936,43 @@ def show_gpus(
f'{colorama.Style.RESET_ALL}\n'
f'{node_table.get_string()}')

def _format_slurm_node_info() -> str:
node_table = log_utils.create_table([
'CLUSTER',
'NODE',
'PARTITION',
'STATE',
'GPU',
'UTILIZATION',
])

# Get all cluster names
slurm_cluster_names = clouds.Slurm.existing_allowed_clusters()

# Query each cluster
for cluster_name in slurm_cluster_names:
nodes_info = sdk.stream_and_get(
sdk.slurm_node_info(slurm_cluster_name=cluster_name))

for node_info in nodes_info:
node_table.add_row([
cluster_name,
node_info.get('node_name'),
node_info.get('partition', '-'),
node_info.get('node_state'),
node_info.get('gpu_type') or '',
(f'{node_info.get("free_gpus", 0)} of '
f'{node_info.get("total_gpus", 0)} free'),
])

slurm_per_node_msg = 'Slurm per node accelerator availability'
# Optional: Add hint message if needed, similar to k8s

return (f'{colorama.Fore.LIGHTMAGENTA_EX}{colorama.Style.NORMAL}'
f'{slurm_per_node_msg}'
f'{colorama.Style.RESET_ALL}\n'
f'{node_table.get_string()}')

def _format_kubernetes_realtime_gpu(
total_table: Optional['prettytable.PrettyTable'],
k8s_realtime_infos: List[Tuple[str, 'prettytable.PrettyTable']],
@@ -3880,6 +4102,28 @@ def show_gpus(
return True, print_section_titles
return False, print_section_titles

def _format_slurm_realtime_gpu(
total_table, slurm_realtime_infos,
show_node_info: bool) -> Generator[str, None, None]:
# print total table
yield (f'{colorama.Fore.GREEN}{colorama.Style.BRIGHT}'
'Slurm GPUs'
f'{colorama.Style.RESET_ALL}\n')
if total_table is not None:
yield from total_table.get_string()
yield '\n'

# print individual infos.
for (partition, slurm_realtime_table) in slurm_realtime_infos:
partition_str = f'Slurm Cluster: {partition}'
yield (f'{colorama.Fore.CYAN}{colorama.Style.BRIGHT}'
f'{partition_str}'
f'{colorama.Style.RESET_ALL}\n')
yield from slurm_realtime_table.get_string()
yield '\n'
if show_node_info:
yield _format_slurm_node_info()

def _output() -> Generator[str, None, None]:
gpu_table = log_utils.create_table(
['COMMON_GPU', 'AVAILABLE_QUANTITIES'])
@@ -3897,10 +4141,12 @@ def show_gpus(
if cloud_name is None:
clouds_to_list = [
c for c in constants.ALL_CLOUDS
if c != 'kubernetes' and c != 'ssh'
if c != 'kubernetes' and c != 'ssh' and c != 'slurm'
]

k8s_messages = ''
slurm_messages = ''
k8s_printed = False
if accelerator_str is None:
# Collect k8s related messages in k8s_messages and print them at end
print_section_titles = False
@@ -3912,6 +4158,7 @@ def show_gpus(
yield '\n\n'
stop_iter_one, print_section_titles_one, k8s_messages_one = (
yield from _possibly_show_k8s_like_realtime(is_ssh))
k8s_printed = True
stop_iter = stop_iter or stop_iter_one
print_section_titles = (print_section_titles or
print_section_titles_one)
@@ -3919,11 +4166,45 @@ def show_gpus(
prev_print_section_titles = print_section_titles_one
if stop_iter:
return
# If cloud is slurm, we want to show real-time capacity
if slurm_is_enabled and (cloud_name is None or cloud_is_slurm):
try:
# If --cloud slurm is not specified, we want to catch
# the case where no GPUs are available on the cluster and
# print the warning at the end.
slurm_realtime_infos, total_table = (
_get_slurm_realtime_gpu_tables())
except ValueError as e:
if not cloud_is_slurm:
# Make it a note if cloud is not slurm
slurm_messages += 'Note: '
slurm_messages += str(e)
else:
print_section_titles = True
if k8s_printed:
yield '\n'

yield from _format_slurm_realtime_gpu(total_table,
slurm_realtime_infos,
show_node_info=True)

if cloud_is_slurm:
# Do not show clouds if --cloud slurm is specified
if not slurm_is_enabled:
yield ('Slurm is not enabled. To fix, run: '
'sky check slurm ')
yield slurm_messages
return

# For show_all, show the k8s message at the start since output is
# long and the user may not scroll to the end.
if show_all and k8s_messages:
yield k8s_messages
if show_all and (k8s_messages or slurm_messages):
if k8s_messages:
yield k8s_messages
if slurm_messages:
if k8s_messages:
yield '\n'
yield slurm_messages
yield '\n\n'

list_accelerator_counts_result = sdk.stream_and_get(
@@ -3971,9 +4252,10 @@ def show_gpus(
else:
yield ('\n\nHint: use -a/--all to see all accelerators '
'(including non-common ones) and pricing.')
if k8s_messages:
if k8s_messages or slurm_messages:
yield '\n'
yield k8s_messages
yield slurm_messages
return
else:
# Parse accelerator string
@@ -4013,6 +4295,31 @@ def show_gpus(
if stop_iter:
return

# Handle Slurm filtering by name and quantity
if (slurm_is_enabled and (cloud_name is None or cloud_is_slurm) and
not show_all):
# Print section title if not showing all and instead a specific
# accelerator is requested
print_section_titles = True
try:
slurm_realtime_infos, total_table = (
_get_slurm_realtime_gpu_tables(name_filter=name,
quantity_filter=quantity))

yield from _format_slurm_realtime_gpu(total_table,
slurm_realtime_infos,
show_node_info=False)
except ValueError as e:
# In the case of a specific accelerator, show the error message
# immediately (e.g., "Resources A10G not found ...")
yield str(e)
yield slurm_messages
if cloud_is_slurm:
# Do not show clouds if --cloud slurm is specified
if not slurm_is_enabled:
yield ('Slurm is not enabled. To fix, run: '
'sky check slurm ')
return
# For clouds other than Kubernetes, get the accelerator details
# Case-sensitive
list_accelerators_result = sdk.stream_and_get(

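For Slurm clusters, the SSH config entry above targets the worker's internal IP and jumps through the controller/login node with `ssh -W`, optionally chaining any pre-existing ProxyCommand (e.g., a bastion) in front. A standalone sketch of that command construction; the host, user, and key values are placeholders, and the actual config entry is written by `SSHConfigHelper`:

import shlex
from typing import Optional


def build_slurm_proxy_command(key_path: str, ssh_user: str, controller_host: str,
                              original_proxy: Optional[str] = None) -> str:
    """Jump through the Slurm controller/login node to reach a worker node."""
    proxy = (f'ssh -tt -i {shlex.quote(key_path)} '
             '-o StrictHostKeyChecking=no '
             '-o UserKnownHostsFile=/dev/null '
             '-o IdentitiesOnly=yes '
             '-W %h:%p '
             f'{ssh_user}@{controller_host}')
    if original_proxy:
        # Chain an existing proxy in front of the controller hop.
        proxy += f' -o ProxyCommand={shlex.quote(original_proxy)}'
    return proxy


print(build_slurm_proxy_command('~/.ssh/sky-key', 'ubuntu', 'login.example.com'))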

+ 56
- 2
sky/client/sdk.py View File

@@ -42,6 +42,7 @@ from sky.server.requests import request_names
from sky.server.requests import requests as requests_lib
from sky.skylet import autostop_lib
from sky.skylet import constants
from sky.ssh_node_pools import utils as ssh_utils
from sky.usage import usage_lib
from sky.utils import admin_policy_utils
from sky.utils import annotations
@@ -57,7 +58,6 @@ from sky.utils import status_lib
from sky.utils import subprocess_utils
from sky.utils import ux_utils
from sky.utils import yaml_utils
from sky.utils.kubernetes import ssh_utils

if typing.TYPE_CHECKING:
import base64
@@ -675,7 +675,7 @@ def _launch(
clusters = get(status_request_id)
cluster_user_hash = common_utils.get_user_hash()
cluster_user_hash_str = ''
current_user = common_utils.get_current_user_name()
current_user = common_utils.get_local_user_name()
cluster_user_name = current_user
if not clusters:
# Show the optimize log before the prompt if the cluster does not
@@ -2744,3 +2744,57 @@ def api_logout() -> None:
_clear_api_server_config()
logger.info(f'{colorama.Fore.GREEN}Logged out of SkyPilot API server.'
f'{colorama.Style.RESET_ALL}')


@usage_lib.entrypoint
@server_common.check_server_healthy_or_start
@versions.minimal_api_version(24)
@annotations.client_api
def realtime_slurm_gpu_availability(
name_filter: Optional[str] = None,
quantity_filter: Optional[int] = None) -> server_common.RequestId:
"""Gets the real-time Slurm GPU availability.

Args:
name_filter: Optional name filter for GPUs.
quantity_filter: Optional quantity filter for GPUs.

Returns:
The request ID of the Slurm GPU availability request.
"""
body = payloads.SlurmGpuAvailabilityRequestBody(
name_filter=name_filter,
quantity_filter=quantity_filter,
)
response = server_common.make_authenticated_request(
'POST',
'/slurm_gpu_availability',
json=json.loads(body.model_dump_json()),
)
return server_common.get_request_id(response)


@usage_lib.entrypoint
@server_common.check_server_healthy_or_start
@versions.minimal_api_version(24)
@annotations.client_api
def slurm_node_info(
slurm_cluster_name: Optional[str] = None) -> server_common.RequestId:
"""Gets the resource information for all nodes in the Slurm cluster.

Returns:
The request ID of the Slurm node info request.

Request Returns:
List[Dict[str, Any]]: A list of dictionaries, each containing info
for a single Slurm node (node_name, partition, node_state,
gpu_type, total_gpus, free_gpus, vcpu_count, memory_gb).
"""
body = payloads.SlurmNodeInfoRequestBody(
slurm_cluster_name=slurm_cluster_name)
response = server_common.make_authenticated_request(
'GET',
'/slurm_node_info',
json=json.loads(body.model_dump_json()),
)
return server_common.get_request_id(response)
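Like the other client SDK entrypoints, both calls above return a request ID that is resolved with `sdk.stream_and_get`, and both require an API server at version 24 or newer. A hedged usage sketch; the cluster name and GPU filter are placeholders:

from sky.client import sdk

# Real-time GPU availability across Slurm clusters, optionally filtered.
request_id = sdk.realtime_slurm_gpu_availability(name_filter='H100',
                                                 quantity_filter=2)
for slurm_cluster, availability_list in sdk.stream_and_get(request_id):
    for gpu_availability in availability_list:
        print(slurm_cluster, gpu_availability)

# Per-node resource info for one Slurm cluster ('my-hpc' is hypothetical).
nodes = sdk.stream_and_get(sdk.slurm_node_info(slurm_cluster_name='my-hpc'))
for node in nodes:
    print(node['node_name'], node.get('free_gpus', 0), 'free GPUs')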

+ 2
- 0
sky/clouds/__init__.py View File

@@ -31,6 +31,7 @@ from sky.clouds.runpod import RunPod
from sky.clouds.scp import SCP
from sky.clouds.seeweb import Seeweb
from sky.clouds.shadeform import Shadeform
from sky.clouds.slurm import Slurm
from sky.clouds.ssh import SSH
from sky.clouds.vast import Vast
from sky.clouds.vsphere import Vsphere
@@ -48,6 +49,7 @@ __all__ = [
'Paperspace',
'PrimeIntellect',
'SCP',
'Slurm',
'RunPod',
'Shadeform',
'Vast',


+ 7
- 0
sky/clouds/cloud.py View File

@@ -182,6 +182,13 @@ class Cloud:
"""
return cls._SUPPORTS_SERVICE_ACCOUNT_ON_REMOTE

@classmethod
def uses_ray(cls) -> bool:
"""Returns whether this cloud uses Ray as the distributed
execution framework.
"""
return True

#### Regions/Zones ####

@classmethod

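`Cloud.uses_ray()` defaults to True; Slurm (below) overrides it to False so the backend can pick the native srun driver instead of the Ray one. An illustrative, self-contained dispatch on the hook; the real selection logic in the backend is not shown in this diff:

class RayBackedCloud:
    @classmethod
    def uses_ray(cls) -> bool:
        return True


class SlurmLikeCloud:
    @classmethod
    def uses_ray(cls) -> bool:
        return False


def pick_executor(cloud) -> str:
    # Ray-based driver for most clouds, native srun driver otherwise.
    return 'ray' if cloud.uses_ray() else 'srun'


print(pick_executor(RayBackedCloud()), pick_executor(SlurmLikeCloud()))  # ray srun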

+ 578
- 0
sky/clouds/slurm.py View File

@@ -0,0 +1,578 @@
"""Slurm."""

import typing
from typing import Dict, Iterator, List, Optional, Tuple, Union

from sky import catalog
from sky import clouds
from sky import sky_logging
from sky import skypilot_config
from sky.adaptors import slurm
from sky.provision.slurm import utils as slurm_utils
from sky.utils import annotations
from sky.utils import common_utils
from sky.utils import registry
from sky.utils import resources_utils

if typing.TYPE_CHECKING:
from sky import resources as resources_lib
from sky.utils import volume as volume_lib

logger = sky_logging.init_logger(__name__)

CREDENTIAL_PATH = slurm_utils.DEFAULT_SLURM_PATH


@registry.CLOUD_REGISTRY.register
class Slurm(clouds.Cloud):
"""Slurm."""

_REPR = 'Slurm'
_CLOUD_UNSUPPORTED_FEATURES = {
clouds.CloudImplementationFeatures.AUTOSTOP: 'Slurm does not '
'support autostop.',
clouds.CloudImplementationFeatures.STOP: 'Slurm does not support '
'stopping instances.',
clouds.CloudImplementationFeatures.SPOT_INSTANCE: 'Spot instances are '
'not supported in '
'Slurm.',
clouds.CloudImplementationFeatures.CUSTOM_MULTI_NETWORK:
'Customized multiple network interfaces are not supported in '
'Slurm.',
clouds.CloudImplementationFeatures.OPEN_PORTS: 'Opening ports is not '
'supported in Slurm.',
clouds.CloudImplementationFeatures.HOST_CONTROLLERS:
'Running '
'controllers is not '
'well tested with '
'Slurm.',
clouds.CloudImplementationFeatures.IMAGE_ID: 'Specifying image ID is '
'not supported in Slurm.',
clouds.CloudImplementationFeatures.DOCKER_IMAGE: 'Docker image is not '
'supported in Slurm.',
}
_MAX_CLUSTER_NAME_LEN_LIMIT = 120
_regions: List[clouds.Region] = []
_INDENT_PREFIX = ' '

# Using the latest SkyPilot provisioner API to provision and check status.
PROVISIONER_VERSION = clouds.ProvisionerVersion.SKYPILOT
STATUS_VERSION = clouds.StatusVersion.SKYPILOT

@classmethod
def _unsupported_features_for_resources(
cls,
resources: 'resources_lib.Resources',
region: Optional[str] = None,
) -> Dict[clouds.CloudImplementationFeatures, str]:
del region # unused
return cls._CLOUD_UNSUPPORTED_FEATURES

@classmethod
def _max_cluster_name_length(cls) -> Optional[int]:
return cls._MAX_CLUSTER_NAME_LEN_LIMIT

@classmethod
def uses_ray(cls) -> bool:
return False

@classmethod
def get_vcpus_mem_from_instance_type(
cls,
instance_type: str,
) -> Tuple[Optional[float], Optional[float]]:
inst = slurm_utils.SlurmInstanceType.from_instance_type(instance_type)
return inst.cpus, inst.memory

@classmethod
def zones_provision_loop(
cls,
*,
region: str,
num_nodes: int,
instance_type: str,
accelerators: Optional[Dict[str, int]] = None,
use_spot: bool = False,
) -> Iterator[Optional[List[clouds.Zone]]]:
"""Iterate over partitions (zones) for provisioning with failover.

Yields one partition at a time for failover retry logic.
"""
del num_nodes # unused

regions = cls.regions_with_offering(instance_type,
accelerators,
use_spot,
region=region,
zone=None)

for r in regions:
if r.zones:
# Yield one partition at a time for failover
for zone in r.zones:
yield [zone]
else:
# No partitions discovered, use default
yield None

@classmethod
@annotations.lru_cache(scope='global', maxsize=1)
def _log_skipped_clusters_once(cls, skipped_clusters: Tuple[str,
...]) -> None:
"""Log skipped clusters for only once.

We don't directly cache the result of existing_allowed_clusters
as the config may update the allowed clusters.
"""
if skipped_clusters:
logger.warning(
f'Slurm clusters {set(skipped_clusters)!r} specified in '
'"allowed_clusters" not found in ~/.slurm/config. '
'Ignoring these clusters.')

@classmethod
def existing_allowed_clusters(cls, silent: bool = False) -> List[str]:
"""Get existing allowed clusters.

Returns clusters based on the following logic:
1. If 'allowed_clusters' is set to 'all' in ~/.sky/config.yaml,
return all clusters from ~/.slurm/config
2. If specific clusters are listed in 'allowed_clusters',
return only those that exist in ~/.slurm/config
3. If no configuration is specified, return all clusters
from ~/.slurm/config (default behavior)
"""
all_clusters = slurm_utils.get_all_slurm_cluster_names()
if len(all_clusters) == 0:
return []

all_clusters = set(all_clusters)

# Workspace-level allowed_clusters should take precedence over
# the global allowed_clusters.
allowed_clusters = skypilot_config.get_workspace_cloud('slurm').get(
'allowed_clusters', None)
if allowed_clusters is None:
allowed_clusters = skypilot_config.get_effective_region_config(
cloud='slurm',
region=None,
keys=('allowed_clusters',),
default_value=None)

allow_all_clusters = allowed_clusters == 'all'
if allow_all_clusters:
allowed_clusters = list(all_clusters)

if allowed_clusters is None:
# Default to all clusters if no configuration is specified
allowed_clusters = list(all_clusters)

existing_clusters = []
skipped_clusters = []
for cluster in allowed_clusters:
if cluster in all_clusters:
existing_clusters.append(cluster)
else:
skipped_clusters.append(cluster)

if not silent:
cls._log_skipped_clusters_once(tuple(sorted(skipped_clusters)))

return existing_clusters

@classmethod
def regions_with_offering(
cls,
instance_type: Optional[str],
accelerators: Optional[Dict[str, int]],
use_spot: bool,
region: Optional[str],
zone: Optional[str],
resources: Optional['resources_lib.Resources'] = None
) -> List[clouds.Region]:
del accelerators, use_spot, resources # unused
existing_clusters = cls.existing_allowed_clusters()

regions: List[clouds.Region] = []
for cluster in existing_clusters:
# Filter by region if specified
if region is not None and cluster != region:
continue

# Fetch partitions for this cluster and attach as zones
try:
partitions = slurm_utils.get_partitions(cluster)
if zone is not None:
# Filter by zone (partition) if specified
partitions = [p for p in partitions if p == zone]
zones = [clouds.Zone(p) for p in partitions]
except Exception as e: # pylint: disable=broad-except
logger.debug(f'Failed to get partitions for {cluster}: {e}')
zones = []

r = clouds.Region(cluster)
if zones:
r.set_zones(zones)
regions.append(r)

# Check if the requested instance type will fit in the cluster.
if instance_type is None:
return regions

regions_to_return = []
for r in regions:
cluster = r.name

# Check each partition (zone) in the cluster
partitions_to_check = [z.name for z in r.zones] if r.zones else []
valid_zones = []

# TODO(kevin): Batch this check to reduce number of roundtrips.
for partition in partitions_to_check:
fits, reason = slurm_utils.check_instance_fits(
cluster, instance_type, partition)
if fits:
if partition:
valid_zones.append(clouds.Zone(partition))
else:
logger.debug(
f'Instance type {instance_type} does not fit in '
f'{cluster}/{partition}: {reason}')

if valid_zones:
r.set_zones(valid_zones)
regions_to_return.append(r)

return regions_to_return

def instance_type_to_hourly_cost(self,
instance_type: str,
use_spot: bool,
region: Optional[str] = None,
zone: Optional[str] = None) -> float:
"""For now, we assume zero cost for Slurm clusters."""
return 0.0

def accelerators_to_hourly_cost(self,
accelerators: Dict[str, int],
use_spot: bool,
region: Optional[str] = None,
zone: Optional[str] = None) -> float:
"""Returns the hourly cost of the accelerators, in dollars/hour."""
del accelerators, use_spot, region, zone # unused
return 0.0

def get_egress_cost(self, num_gigabytes: float) -> float:
return 0.0

def __repr__(self):
return self._REPR

def is_same_cloud(self, other: clouds.Cloud) -> bool:
# Returns true if the two clouds are the same cloud type.
return isinstance(other, Slurm)

@classmethod
def get_default_instance_type(cls,
cpus: Optional[str] = None,
memory: Optional[str] = None,
disk_tier: Optional[
resources_utils.DiskTier] = None,
region: Optional[str] = None,
zone: Optional[str] = None) -> Optional[str]:
"""Returns the default instance type for Slurm."""
return catalog.get_default_instance_type(cpus=cpus,
memory=memory,
disk_tier=disk_tier,
region=region,
zone=zone,
clouds='slurm')

@classmethod
def get_accelerators_from_instance_type(
cls, instance_type: str) -> Optional[Dict[str, Union[int, float]]]:
inst = slurm_utils.SlurmInstanceType.from_instance_type(instance_type)
return {
inst.accelerator_type: inst.accelerator_count
} if (inst.accelerator_count is not None and
inst.accelerator_type is not None) else None

@classmethod
def get_zone_shell_cmd(cls) -> Optional[str]:
return None

def make_deploy_resources_variables(
self,
resources: 'resources_lib.Resources',
cluster_name: 'resources_utils.ClusterName',
region: Optional['clouds.Region'],
zones: Optional[List['clouds.Zone']],
num_nodes: int,
dryrun: bool = False,
volume_mounts: Optional[List['volume_lib.VolumeMount']] = None,
) -> Dict[str, Optional[str]]:
del cluster_name, dryrun, volume_mounts # Unused.
if region is not None:
cluster = region.name
else:
cluster = 'localcluster'
assert cluster is not None, 'No available Slurm cluster found.'

# Use zone as partition if specified, otherwise default
if zones and len(zones) > 0:
partition = zones[0].name
else:
partition = slurm_utils.get_cluster_default_partition(cluster)

# cluster is our target slurmctld host.
ssh_config = slurm_utils.get_slurm_ssh_config()
ssh_config_dict = ssh_config.lookup(cluster)

resources = resources.assert_launchable()
acc_dict = self.get_accelerators_from_instance_type(
resources.instance_type)
custom_resources = resources_utils.make_ray_custom_resources_str(
acc_dict)

# resources.memory and cpus are None if they are not explicitly set.
# We fetch the default values for the instance type in that case.
s = slurm_utils.SlurmInstanceType.from_instance_type(
resources.instance_type)
cpus = s.cpus
mem = s.memory
# Optionally populate accelerator information.
acc_count = s.accelerator_count if s.accelerator_count else 0
acc_type = s.accelerator_type if s.accelerator_type else None

deploy_vars = {
'instance_type': resources.instance_type,
'custom_resources': custom_resources,
'cpus': str(cpus),
'memory': str(mem),
'accelerator_count': str(acc_count),
'accelerator_type': acc_type,
'slurm_cluster': cluster,
'slurm_partition': partition,
# TODO(jwj): Pass SSH config in a smarter way
'ssh_hostname': ssh_config_dict['hostname'],
'ssh_port': str(ssh_config_dict.get('port', 22)),
'ssh_user': ssh_config_dict['user'],
'slurm_proxy_command': ssh_config_dict.get('proxycommand', None),
# TODO(jwj): Solve naming collision with 'ssh_private_key'.
# Please refer to slurm-ray.yml.j2 'ssh' and 'auth' sections.
'slurm_private_key': ssh_config_dict['identityfile'][0],
}

return deploy_vars

def _get_feasible_launchable_resources(
self, resources: 'resources_lib.Resources'
) -> 'resources_utils.FeasibleResources':
"""Returns a list of feasible resources for the given resources."""
if resources.instance_type is not None:
assert resources.is_launchable(), resources
# Check if the instance type is available in at least one cluster
available_regions = self.regions_with_offering(
resources.instance_type,
accelerators=None,
use_spot=resources.use_spot,
region=resources.region,
zone=resources.zone)
if not available_regions:
return resources_utils.FeasibleResources([], [], None)

# Return a single resource without region set.
# The optimizer will call make_launchables_for_valid_region_zones()
# which will create one resource per region/cluster.
resources = resources.copy(accelerators=None)
return resources_utils.FeasibleResources([resources], [], None)

def _make(instance_list):
resource_list = []
for instance_type in instance_list:
r = resources.copy(
cloud=Slurm(),
instance_type=instance_type,
accelerators=None,
)
resource_list.append(r)
return resource_list

# Currently, handle a filter on accelerators only.
accelerators = resources.accelerators

default_instance_type = Slurm.get_default_instance_type(
cpus=resources.cpus,
memory=resources.memory,
disk_tier=resources.disk_tier,
region=resources.region,
zone=resources.zone)
if default_instance_type is None:
return resources_utils.FeasibleResources([], [], None)

if accelerators is None:
chosen_instance_type = default_instance_type
else:
assert len(accelerators) == 1, resources

# Build GPU-enabled instance type.
acc_type, acc_count = list(accelerators.items())[0]

slurm_instance_type = (slurm_utils.SlurmInstanceType.
from_instance_type(default_instance_type))

gpu_task_cpus = slurm_instance_type.cpus
gpu_task_memory = slurm_instance_type.memory
# if resources.cpus is None:
# gpu_task_cpus = self._DEFAULT_NUM_VCPUS_WITH_GPU * acc_count
# gpu_task_memory = (float(resources.memory.strip('+')) if
# resources.memory is not None else
# gpu_task_cpus *
# self._DEFAULT_MEMORY_CPU_RATIO_WITH_GPU)

chosen_instance_type = (
slurm_utils.SlurmInstanceType.from_resources(
gpu_task_cpus, gpu_task_memory, acc_count, acc_type).name)

# Check the availability of the specified instance type in all
# Slurm clusters.
available_regions = self.regions_with_offering(
chosen_instance_type,
accelerators=None,
use_spot=resources.use_spot,
region=resources.region,
zone=resources.zone)
if not available_regions:
return resources_utils.FeasibleResources([], [], None)

return resources_utils.FeasibleResources(_make([chosen_instance_type]),
[], None)

@classmethod
def _check_compute_credentials(
cls) -> Tuple[bool, Optional[Union[str, Dict[str, str]]]]:
"""Checks if the user has access credentials to the Slurm cluster."""
try:
ssh_config = slurm_utils.get_slurm_ssh_config()
except FileNotFoundError:
return (
False,
f'Slurm configuration file {slurm_utils.DEFAULT_SLURM_PATH} '
'does not exist.\n'
f'{cls._INDENT_PREFIX}For more info: '
'https://docs.skypilot.co/en/latest/getting-started/'
'installation.html#slurm-installation')
except Exception as e: # pylint: disable=broad-except
return (False, 'Failed to load SSH configuration from '
f'{slurm_utils.DEFAULT_SLURM_PATH}: '
f'{common_utils.format_exception(e)}.')
existing_allowed_clusters = cls.existing_allowed_clusters()

if not existing_allowed_clusters:
return (False, 'No Slurm clusters found in ~/.slurm/config. '
'Please configure at least one Slurm cluster.')

# Check credentials for each cluster and return ctx2text mapping
ctx2text = {}
success = False
for cluster in existing_allowed_clusters:
# Retrieve the config options for a given SlurmctldHost name alias.
ssh_config_dict = ssh_config.lookup(cluster)

try:
client = slurm.SlurmClient(
ssh_config_dict['hostname'],
int(ssh_config_dict.get('port', 22)),
ssh_config_dict['user'],
ssh_config_dict['identityfile'][0],
ssh_proxy_command=ssh_config_dict.get('proxycommand', None))
info = client.info()
logger.debug(f'Slurm cluster {cluster} sinfo: {info}')
ctx2text[cluster] = 'enabled'
success = True
except Exception as e: # pylint: disable=broad-except
error_msg = (f'Credential check failed: '
f'{common_utils.format_exception(e)}')
ctx2text[cluster] = f'disabled. {error_msg}'

return success, ctx2text

def get_credential_file_mounts(self) -> Dict[str, str]:
# TODO: Return a dictionary of credential file paths to mount on the
# remote cluster.
return {}

@classmethod
def get_current_user_identity(cls) -> Optional[List[str]]:
# NOTE: Used for advanced SkyPilot functionality; can be implemented
# later if needed.
return None

def instance_type_exists(self, instance_type: str) -> bool:
return catalog.instance_type_exists(instance_type, 'slurm')

def validate_region_zone(self, region: Optional[str], zone: Optional[str]):
"""Validate region (cluster) and zone (partition).

Args:
region: Slurm cluster name.
zone: Slurm partition name (optional).

Returns:
Tuple of (region, zone) if valid.

Raises:
ValueError: If cluster or partition not found.
"""
all_clusters = slurm_utils.get_all_slurm_cluster_names()
if region and region not in all_clusters:
raise ValueError(
f'Cluster {region} not found in Slurm config. Slurm only '
'supports cluster names as regions. Available '
f'clusters: {all_clusters}')

# Validate partition (zone) if specified
if zone is not None:
if region is None:
raise ValueError(
'Cannot specify partition (zone) without specifying '
'cluster (region) for Slurm.')

partitions = slurm_utils.get_partitions(region)
if zone not in partitions:
raise ValueError(
f'Partition {zone!r} not found in cluster {region!r}. '
f'Available partitions: {partitions}')

return region, zone

def accelerator_in_region_or_zone(self,
accelerator: str,
acc_count: int,
region: Optional[str] = None,
zone: Optional[str] = None) -> bool:
del zone # unused for now
regions = catalog.get_region_zones_for_accelerators(accelerator,
acc_count,
use_spot=False,
clouds='slurm')
if not regions:
return False
if region is None:
return True
return any(r.name == region for r in regions)

@classmethod
def expand_infras(cls) -> List[str]:
"""Returns a list of enabled Slurm clusters.

Each is returned as 'Slurm/cluster-name'.
"""
infras = []
for cluster in cls.existing_allowed_clusters(silent=True):
infras.append(f'{cls.canonical_name()}/{cluster}')
return infras
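`existing_allowed_clusters` intersects the clusters discovered in `~/.slurm/config` with an optional `allowed_clusters` setting, where the workspace-level value takes precedence over the global one and `'all'` (or no setting) admits every discovered cluster. A self-contained sketch of that precedence and filtering, with the config lookups replaced by plain arguments:

from typing import List, Optional, Tuple, Union

AllowedSetting = Optional[Union[str, List[str]]]


def filter_allowed_clusters(discovered: List[str],
                            workspace_allowed: AllowedSetting = None,
                            global_allowed: AllowedSetting = None
                            ) -> Tuple[List[str], List[str]]:
    """Returns (existing, skipped) cluster names."""
    allowed = (workspace_allowed
               if workspace_allowed is not None else global_allowed)
    if allowed is None or allowed == 'all':
        # Default: every cluster found in ~/.slurm/config.
        allowed = list(discovered)
    discovered_set = set(discovered)
    existing = [c for c in allowed if c in discovered_set]
    skipped = [c for c in allowed if c not in discovered_set]
    return existing, skipped


print(filter_allowed_clusters(['hpc-a', 'hpc-b'],
                              global_allowed=['hpc-a', 'typo-cluster']))
# -> (['hpc-a'], ['typo-cluster'])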

+ 2
- 1
sky/clouds/ssh.py View File

@@ -9,6 +9,7 @@ from sky import skypilot_config
from sky.adaptors import kubernetes as kubernetes_adaptor
from sky.clouds import kubernetes
from sky.provision.kubernetes import utils as kubernetes_utils
from sky.ssh_node_pools import constants as ssh_constants
from sky.utils import annotations
from sky.utils import common_utils
from sky.utils import registry
@@ -20,7 +21,7 @@ if typing.TYPE_CHECKING:

logger = sky_logging.init_logger(__name__)

SSH_NODE_POOLS_PATH = os.path.expanduser('~/.sky/ssh_node_pools.yaml')
SSH_NODE_POOLS_PATH = ssh_constants.DEFAULT_SSH_NODE_POOLS_PATH


@registry.CLOUD_REGISTRY.register()


+ 10
- 0
sky/clouds/vast.py View File

@@ -6,6 +6,7 @@ from typing import Dict, Iterator, List, Optional, Tuple, Union

from sky import catalog
from sky import clouds
from sky import skypilot_config
from sky.adaptors import common
from sky.utils import registry
from sky.utils import resources_utils
@@ -196,11 +197,20 @@ class Vast(clouds.Cloud):
else:
image_id = resources.image_id[resources.region]

secure_only = skypilot_config.get_effective_region_config(
cloud='vast',
region=region.name,
keys=('secure_only',),
default_value=False,
override_configs=resources.cluster_config_overrides,
)

return {
'instance_type': resources.instance_type,
'custom_resources': custom_resources,
'region': region.name,
'image_id': image_id,
'secure_only': secure_only,
}
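
For context, the 'secure_only' value threaded into the deploy variables above follows the usual config precedence: per-launch resource overrides (resources.cluster_config_overrides) win over the user's 'vast' config section, and the default is False. A rough, non-authoritative sketch of that precedence (not SkyPilot's actual get_effective_region_config implementation):

def resolve_secure_only(user_config: dict, overrides: dict) -> bool:
    # Per-launch overrides > user config > default False (illustrative only).
    default = False
    from_config = user_config.get('vast', {}).get('secure_only', default)
    return bool(overrides.get('vast', {}).get('secure_only', from_config))

print(resolve_secure_only({'vast': {'secure_only': True}}, {}))   # True
print(resolve_secure_only({}, {'vast': {'secure_only': False}}))  # False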

def _get_feasible_launchable_resources(


+ 128
- 36
sky/core.py View File

@@ -1211,6 +1211,7 @@ def enabled_clouds(workspace: Optional[str] = None,
return [cloud.canonical_name() for cloud in cached_clouds]
enabled_ssh_infras = []
enabled_k8s_infras = []
enabled_slurm_infras = []
enabled_cloud_infras = []
for cloud in cached_clouds:
cloud_infra = cloud.expand_infras()
@@ -1218,10 +1219,16 @@ def enabled_clouds(workspace: Optional[str] = None,
enabled_ssh_infras.extend(cloud_infra)
elif isinstance(cloud, clouds.Kubernetes):
enabled_k8s_infras.extend(cloud_infra)
elif isinstance(cloud, clouds.Slurm):
enabled_slurm_infras.extend(cloud_infra)
else:
enabled_cloud_infras.extend(cloud_infra)
# Slurm infras are intentionally not sorted alphabetically: the default
# partition should appear first, and that ordering is enforced by the
# Slurm cloud implementation.
all_infras = sorted(enabled_ssh_infras) + sorted(
enabled_k8s_infras) + sorted(enabled_cloud_infras)
enabled_k8s_infras) + enabled_slurm_infras + sorted(
enabled_cloud_infras)
return all_infras
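
As a concrete (hypothetical) illustration of the ordering above: SSH and Kubernetes infras are sorted, Slurm infras keep their configured order, and clouds come last, sorted:

enabled_ssh = ['SSH/pool-b', 'SSH/pool-a']
enabled_k8s = ['Kubernetes/ctx-2', 'Kubernetes/ctx-1']
enabled_slurm = ['Slurm/hpc-main', 'Slurm/hpc-dev']  # keep configured order
enabled_cloud = ['GCP', 'AWS']
all_infras = (sorted(enabled_ssh) + sorted(enabled_k8s) + enabled_slurm +
              sorted(enabled_cloud))
# -> ['SSH/pool-a', 'SSH/pool-b', 'Kubernetes/ctx-1', 'Kubernetes/ctx-2',
#     'Slurm/hpc-main', 'Slurm/hpc-dev', 'AWS', 'GCP']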


@@ -1232,7 +1239,14 @@ def realtime_kubernetes_gpu_availability(
quantity_filter: Optional[int] = None,
is_ssh: Optional[bool] = None
) -> List[Tuple[str, List[models.RealtimeGpuAvailability]]]:
"""Gets the real-time Kubernetes GPU availability.

Returns:
A list of tuples, where each tuple contains:
- context (str): The Kubernetes context.
- availability_list (List[models.RealtimeGpuAvailability]): A list
of RealtimeGpuAvailability objects for that context.
"""
if context is None:
# Include contexts from both Kubernetes and SSH clouds
kubernetes_contexts = clouds.Kubernetes.existing_allowed_contexts()
@@ -1314,6 +1328,119 @@ def realtime_kubernetes_gpu_availability(
return availability_lists


def realtime_slurm_gpu_availability(
slurm_cluster_name: Optional[str] = None,
name_filter: Optional[str] = None,
quantity_filter: Optional[int] = None,
env_vars: Optional[Dict[str, str]] = None,
**kwargs) -> List[Tuple[str, List[models.RealtimeGpuAvailability]]]:
"""Gets Slurm real-time GPU availability grouped by partition.

This function calls the Slurm backend to fetch GPU info.

Args:
slurm_cluster_name: Optional Slurm cluster to query. If None, all
configured Slurm clusters are queried.
name_filter: Optional name filter for GPUs.
quantity_filter: Optional quantity filter for GPUs.
env_vars: Environment variables (may be needed for backend).
kwargs: Additional keyword arguments.

Returns:
A list of tuples, where each tuple contains:
- partition_name (str): The name of the Slurm partition.
- availability_list (List[models.RealtimeGpuAvailability]): A list
of RealtimeGpuAvailability objects for that partition.
Example structure:
[
('gpu_partition_1', [
RealtimeGpuAvailability(gpu='V100', counts=[4, 8],
capacity=16, available=10),
RealtimeGpuAvailability(gpu='A100', counts=[8],
capacity=8, available=0),
]),
('gpu_partition_2', [
RealtimeGpuAvailability(gpu='V100', counts=[4],
capacity=4, available=4),
])
]

Raises:
ValueError: If Slurm is not configured or no matching GPUs are found.
exceptions.NotSupportedError: If Slurm is not enabled or configured.
"""
del env_vars, kwargs # Currently unused

if slurm_cluster_name is None:
# Query all configured Slurm clusters.
slurm_cluster_names = clouds.Slurm.existing_allowed_clusters()
else:
slurm_cluster_names = [slurm_cluster_name]

# Optional: Check if Slurm is enabled first
# enabled = global_user_state.get_enabled_clouds(
# capability=sky_cloud.CloudCapability.COMPUTE)
# if not clouds.Slurm() in enabled:
# raise exceptions.NotSupportedError(
# "Slurm is not enabled. Run 'sky check' to enable it.")

def realtime_slurm_gpu_availability_single(
slurm_cluster_name: str) -> List[models.RealtimeGpuAvailability]:
try:
# This function now returns aggregated data per GPU type:
# Tuple[Dict[str, List[InstanceTypeInfo]], Dict[str, int],
# Dict[str, int]]
# (qtys_map, total_capacity, total_available)
accelerator_counts, total_capacity, total_available = (
catalog.list_accelerator_realtime(
gpus_only=True, # Ensure we only query for GPUs
name_filter=name_filter,
# Filter to this Slurm cluster (region); results are
# aggregated per GPU type across its partitions.
region_filter=slurm_cluster_name,
quantity_filter=quantity_filter,
clouds='slurm',
case_sensitive=False,
))
except exceptions.NotSupportedError as e:
logger.error(f'Failed to query Slurm GPU availability: {e}')
raise
except ValueError as e:
# Re-raise ValueError if no GPUs are found matching the filters
logger.error(f'Error querying Slurm GPU availability: {e}')
raise
except Exception as e:
logger.error(
'Error querying Slurm GPU availability: '
f'{common_utils.format_exception(e, use_bracket=True)}')
raise ValueError(
f'Error querying Slurm GPU availability: {e}') from e

# --- Format the output ---
realtime_gpu_availability_list: List[
models.RealtimeGpuAvailability] = []
for gpu_type, _ in sorted(accelerator_counts.items()):
realtime_gpu_availability_list.append(
models.RealtimeGpuAvailability(
gpu_type,
accelerator_counts.pop(gpu_type),
total_capacity[gpu_type],
total_available[gpu_type],
))
return realtime_gpu_availability_list

parallel_queried = subprocess_utils.run_in_parallel(
realtime_slurm_gpu_availability_single, slurm_cluster_names)
availability_lists: List[Tuple[str,
List[models.RealtimeGpuAvailability]]] = []
for slurm_cluster_name, queried in zip(slurm_cluster_names,
parallel_queried):
if len(queried) == 0:
logger.debug(f'No GPUs found in Slurm cluster {slurm_cluster_name}')
continue
availability_lists.append((slurm_cluster_name, queried))
return availability_lists
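
A hedged usage sketch of the return value (cluster and GPU names are hypothetical; the field names follow the docstring above):

for cluster, gpus in realtime_slurm_gpu_availability(name_filter='A100'):
    for avail in gpus:
        print(f'{cluster}: {avail.gpu} '
              f'{avail.available}/{avail.capacity} free, '
              f'requestable counts: {avail.counts}')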


# =================
# = Local Cluster =
# =================
@@ -1330,41 +1457,6 @@ def local_down(name: Optional[str] = None) -> None:
kubernetes_deploy_utils.teardown_local_cluster(name)


@usage_lib.entrypoint
def ssh_up(infra: Optional[str] = None, cleanup: bool = False) -> None:
"""Deploys or tears down a Kubernetes cluster on SSH targets.

Args:
infra: Name of the cluster configuration in ssh_node_pools.yaml.
If None, the first cluster in the file is used.
cleanup: If True, clean up the cluster instead of deploying.
"""
kubernetes_deploy_utils.deploy_ssh_cluster(
cleanup=cleanup,
infra=infra,
)


@usage_lib.entrypoint
def ssh_status(context_name: str) -> Tuple[bool, str]:
"""Check the status of an SSH Node Pool context.

Args:
context_name: The SSH context name (e.g., 'ssh-my-cluster')

Returns:
Tuple[bool, str]: (is_ready, reason)
- is_ready: True if the SSH Node Pool is ready, False otherwise
- reason: Explanation of the status
"""
try:
is_ready, reason = clouds.SSH.check_single_context(context_name)
return is_ready, reason
except Exception as e: # pylint: disable=broad-except
return False, ('Failed to check SSH context: '
f'{common_utils.format_exception(e)}')


def get_all_contexts() -> List[str]:
"""Get all available contexts from Kubernetes and SSH clouds.



+ 19
- 0
sky/dashboard/src/components/elements/icons.jsx View File

@@ -686,3 +686,22 @@ export function RssIcon(props) {
export function UserCircleIcon(props) {
return <UserCircle {...props} />;
}

export function KeyIcon(props) {
return (
<svg
{...props}
xmlns="http://www.w3.org/2000/svg"
width="24"
height="24"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
strokeWidth="2"
strokeLinecap="round"
strokeLinejoin="round"
>
<path d="m21 2-2 2m-7.61 7.61a5.5 5.5 0 1 1-7.778 7.778 5.5 5.5 0 0 1 7.777-7.777zm0 0L15.5 7.5m0 0 3 3L22 7l-3-3m-3.5 3.5L19 4" />
</svg>
);
}

+ 240
- 5
sky/dashboard/src/components/elements/sidebar.jsx View File

@@ -21,11 +21,24 @@ import {
UsersIcon,
StarIcon,
VolumeIcon,
KeyIcon,
} from '@/components/elements/icons';
import { Settings, User } from 'lucide-react';

// Map icon names to icon components for plugin nav links
const ICON_MAP = {
key: KeyIcon,
server: ServerIcon,
briefcase: BriefcaseIcon,
chip: ChipIcon,
book: BookDocIcon,
users: UsersIcon,
volume: VolumeIcon,
};
import { BASE_PATH, ENDPOINT } from '@/data/connectors/constants';
import { CustomTooltip } from '@/components/utils';
import { useMobile } from '@/hooks/useMobile';
import { useGroupedNavLinks, usePluginRoutes } from '@/plugins/PluginProvider';

// Create a context for sidebar state management
const SidebarContext = createContext(null);
@@ -161,9 +174,13 @@ export function TopBar() {
const { userEmail, userRole, isMobileSidebarOpen, toggleMobileSidebar } =
useSidebar();
const [isDropdownOpen, setIsDropdownOpen] = useState(false);
const [openNavDropdown, setOpenNavDropdown] = useState(null);
const { ungrouped, groups } = useGroupedNavLinks();
const pluginRoutes = usePluginRoutes();

const dropdownRef = useRef(null);
const mobileNavRef = useRef(null);
const navDropdownRef = useRef(null);

useEffect(() => {
function handleClickOutside(event) {
@@ -180,6 +197,13 @@ export function TopBar() {
toggleMobileSidebar();
}
}
// Handle navigation dropdown menu clicks outside
if (
navDropdownRef.current &&
!navDropdownRef.current.contains(event.target)
) {
setOpenNavDropdown(null);
}
}
// Bind the event listener
document.addEventListener('mousedown', handleClickOutside);
@@ -226,6 +250,190 @@ export function TopBar() {
}`;
};

const getMobileLinkClasses = (path, forceInactive = false) => {
const isActive = !forceInactive && isActivePath(path);
return `flex items-center px-4 py-3 text-sm font-medium rounded-md transition-colors ${
isActive
? 'bg-blue-50 text-blue-600'
: 'text-gray-700 hover:bg-gray-100 hover:text-blue-600'
}`;
};

const renderPluginIcon = (icon, className) => {
const IconComponent = ICON_MAP[icon];
if (IconComponent) {
return React.createElement(IconComponent, { className });
}
return icon;
};

const renderNavLabel = (link) => (
<>
{link.icon && (
<span className="text-base leading-none mr-1" aria-hidden="true">
{renderPluginIcon(link.icon, 'w-4 h-4')}
</span>
)}
<span className="inline-flex items-center gap-1">
<span>{link.label}</span>
{link.badge && (
<span className="text-[10px] uppercase tracking-wide bg-blue-100 text-blue-700 px-1.5 py-0.5 rounded-full">
{link.badge}
</span>
)}
</span>
</>
);

const resolvePluginHref = (href) => {
if (typeof href !== 'string') {
return href;
}
const route = pluginRoutes.find((entry) => entry.path === href);
if (!route || !route.path.startsWith('/plugins')) {
return href;
}
const slugSegments = route.path
.replace(/^\/+/, '')
.split('/')
.slice(1)
.filter(Boolean);
return {
pathname: '/plugins/[...slug]',
query: slugSegments.length ? { slug: slugSegments } : {},
};
};

const renderDesktopPluginNavLink = (link) => {
if (link.external) {
return (
<a
key={link.id}
href={link.href}
target={link.target}
rel={link.rel}
className="inline-flex items-center border-b-2 border-transparent px-1 pt-1 space-x-2 text-gray-700 hover:text-blue-600"
>
{renderNavLabel(link)}
</a>
);
}

return (
<Link
key={link.id}
href={resolvePluginHref(link.href)}
className={getLinkClasses(link.href)}
prefetch={false}
>
{renderNavLabel(link)}
</Link>
);
};

const renderMobilePluginNavLink = (link) => {
const content = (
<>
{link.icon && (
<span className="text-base leading-none mr-2" aria-hidden="true">
{renderPluginIcon(link.icon, 'w-5 h-5')}
</span>
)}
<span className="flex items-center gap-2">
<span>{link.label}</span>
{link.badge && (
<span className="text-[10px] uppercase tracking-wide bg-blue-100 text-blue-700 px-1.5 py-0.5 rounded-full">
{link.badge}
</span>
)}
</span>
</>
);

if (link.external) {
return (
<a
key={link.id}
href={link.href}
target={link.target}
rel={link.rel}
className={getMobileLinkClasses(link.href, true)}
onClick={toggleMobileSidebar}
>
{content}
</a>
);
}

return (
<Link
key={link.id}
href={resolvePluginHref(link.href)}
className={getMobileLinkClasses(link.href)}
onClick={toggleMobileSidebar}
prefetch={false}
>
{content}
</Link>
);
};

// Render desktop dropdown menu for grouped plugins
const renderDesktopDropdownMenu = (groupName, links) => {
const isOpen = openNavDropdown === groupName;

return (
<div className="relative" key={groupName} ref={navDropdownRef}>
<button
onClick={() => setOpenNavDropdown(isOpen ? null : groupName)}
className={`inline-flex items-center align-middle border-b-2 px-1 pt-1 space-x-1 ${
isOpen
? 'text-blue-600 border-blue-600'
: 'border-transparent text-gray-700 hover:text-blue-600'
}`}
>
<span>{groupName}</span>
<svg
className={`w-4 h-4 transition-transform ${isOpen ? 'rotate-180' : ''}`}
fill="currentColor"
viewBox="0 0 20 20"
>
<path
fillRule="evenodd"
d="M5.293 7.293a1 1 0 011.414 0L10 10.586l3.293-3.293a1 1 0 111.414 1.414l-4 4a1 1 0 01-1.414 0l-4-4a1 1 0 010-1.414z"
clipRule="evenodd"
/>
</svg>
</button>

{isOpen && (
<div className="absolute top-full left-0 mt-1 min-w-[8rem] bg-white rounded-md shadow-lg border border-gray-200 z-50">
<div className="py-1">
{links.map((link) => (
<Link
key={link.id}
href={resolvePluginHref(link.href)}
className="block px-4 py-2 text-sm text-gray-700 hover:bg-gray-100 transition-colors"
onClick={() => setOpenNavDropdown(null)}
prefetch={false}
>
<div className="flex items-center gap-2">
{link.icon && (
<span className="text-base leading-none">
{renderPluginIcon(link.icon, 'w-4 h-4')}
</span>
)}
<span>{link.label}</span>
</div>
</Link>
))}
</div>
</div>
)}
</div>
);
};

return (
<>
<div className="fixed top-0 left-0 right-0 bg-white z-30 h-14 px-4 border-b border-gray-200 shadow-sm">
@@ -342,6 +550,14 @@ export function TopBar() {
<div className="flex items-center space-x-1 ml-auto">
{!isMobile && (
<>
{/* Ungrouped plugin links - positioned on the right */}
{ungrouped.map((link) => renderDesktopPluginNavLink(link))}

{/* Grouped dropdown menus (e.g., Enterprise) - positioned on the right */}
{Object.entries(groups).map(([groupName, links]) =>
renderDesktopDropdownMenu(groupName, links)
)}

<CustomTooltip
content="Documentation"
className="text-sm text-muted-foreground"
@@ -350,10 +566,10 @@ export function TopBar() {
href="https://skypilot.readthedocs.io/en/latest/"
target="_blank"
rel="noopener noreferrer"
className="inline-flex items-center px-2 py-1 text-gray-600 hover:text-blue-600 transition-colors duration-150 cursor-pointer"
className="inline-flex items-center align-middle border-b-2 border-transparent px-1 pt-1 space-x-1 text-gray-600 hover:text-blue-600 transition-colors duration-150 cursor-pointer"
title="Docs"
>
<span className="mr-1">Docs</span>
<span className="leading-none">Docs</span>
<ExternalLinkIcon className="w-3.5 h-3.5" />
</a>
</CustomTooltip>
@@ -366,7 +582,7 @@ export function TopBar() {
href="https://github.com/skypilot-org/skypilot"
target="_blank"
rel="noopener noreferrer"
className="inline-flex items-center justify-center p-2 rounded-full text-gray-600 hover:bg-gray-100 transition-colors duration-150 cursor-pointer"
className="inline-flex items-center justify-center align-middle p-2 rounded-full text-gray-600 hover:bg-gray-100 transition-colors duration-150 cursor-pointer"
title="GitHub"
>
<GitHubIcon className="w-5 h-5" />
@@ -381,7 +597,7 @@ export function TopBar() {
href="https://slack.skypilot.co/"
target="_blank"
rel="noopener noreferrer"
className="inline-flex items-center justify-center p-2 rounded-full text-gray-600 hover:bg-gray-100 transition-colors duration-150 cursor-pointer"
className="inline-flex items-center justify-center align-middle p-2 rounded-full text-gray-600 hover:bg-gray-100 transition-colors duration-150 cursor-pointer"
title="Slack"
>
<SlackIcon className="w-5 h-5" />
@@ -396,7 +612,7 @@ export function TopBar() {
href="https://github.com/skypilot-org/skypilot/issues/new"
target="_blank"
rel="noopener noreferrer"
className="inline-flex items-center justify-center p-2 rounded-full text-gray-600 hover:bg-gray-100 transition-colors duration-150 cursor-pointer"
className="inline-flex items-center justify-center align-middle p-2 rounded-full text-gray-600 hover:bg-gray-100 transition-colors duration-150 cursor-pointer"
title="Leave Feedback"
>
<CommentFeedbackIcon className="w-5 h-5" />
@@ -492,6 +708,8 @@ export function TopBar() {
</div>
)}
</div>

<div className="border-l border-gray-200 h-6 mx-1"></div>
</div>
</div>

@@ -603,6 +821,23 @@ export function TopBar() {

<div className="border-t border-gray-200 my-4"></div>

{/* Ungrouped plugins */}
{ungrouped.map((link) => renderMobilePluginNavLink(link))}

{/* Grouped plugins (displayed flat on mobile) */}
{Object.entries(groups).map(([groupName, links]) => (
<div key={groupName}>
<div className="px-4 py-2 text-xs font-semibold text-gray-500 uppercase tracking-wider">
{groupName}
</div>
{links.map((link) => renderMobilePluginNavLink(link))}
</div>
))}

{(ungrouped.length > 0 || Object.keys(groups).length > 0) && (
<div className="border-t border-gray-200 my-4"></div>
)}

{/* External links in mobile */}
<a
href="https://skypilot.readthedocs.io/en/latest/"


+ 180
- 68
sky/dashboard/src/components/infra.jsx View File

@@ -68,6 +68,68 @@ import {
const REFRESH_INTERVAL = REFRESH_INTERVALS.REFRESH_INTERVAL;
const NAME_TRUNCATE_LENGTH = UI_CONFIG.NAME_TRUNCATE_LENGTH;

// Shared GPU utilization bar to avoid duplicating percentage math and markup
const GpuUtilizationBar = ({
gpu,
heightClass = 'h-4',
wrapperClassName = '',
}) => {
const total = gpu?.gpu_total || 0;
const notReady = gpu?.gpu_not_ready || 0;
const free = gpu?.gpu_free || 0;
const used = Math.max(0, total - free - notReady);
const notReadyLabel = `${notReady} not ready`;
const usedLabel = `${used} used`;
const freeLabel = `${free} free`;
const toPercentage = total > 0 ? (value) => (value / total) * 100 : () => 0;
const notReadyPercentage = toPercentage(notReady);
const usedPercentage = toPercentage(used);
const freePercentage = toPercentage(free);

return (
<div
className={`bg-gray-100 rounded-md flex overflow-hidden shadow-sm ${heightClass} ${wrapperClassName}`.trim()}
>
{notReadyPercentage > 0 && (
<div
style={{
width: `${notReadyPercentage}%`,
fontSize: 'clamp(8px, 1.2vw, 12px)',
}}
title={notReadyLabel}
className="bg-gray-400 h-full flex items-center justify-center text-white font-medium overflow-hidden whitespace-nowrap px-1"
>
{notReadyPercentage > 15 && notReadyLabel}
</div>
)}
{usedPercentage > 0 && (
<div
style={{
width: `${usedPercentage}%`,
fontSize: 'clamp(8px, 1.2vw, 12px)',
}}
title={usedLabel}
className="bg-yellow-500 h-full flex items-center justify-center text-white font-medium overflow-hidden whitespace-nowrap px-1"
>
{usedPercentage > 15 && usedLabel}
</div>
)}
{freePercentage > 0 && (
<div
style={{
width: `${freePercentage}%`,
fontSize: 'clamp(8px, 1.2vw, 12px)',
}}
title={freeLabel}
className="bg-green-700 h-full flex items-center justify-center text-white font-medium overflow-hidden whitespace-nowrap px-1"
>
{freePercentage > 15 && freeLabel}
</div>
)}
</div>
);
};

// Reusable component for infrastructure sections (SSH Node Pool, Kubernetes, or Slurm)
export function InfrastructureSection({
title,
@@ -82,6 +144,7 @@ export function InfrastructureSection({
jobsData = {},
isJobsDataLoading = true,
isSSH = false, // To differentiate between SSH and Kubernetes
isSlurm = false, // To differentiate Slurm clusters
actionButton = null, // Optional action button for the header
contextWorkspaceMap = {}, // Mapping of contexts to workspaces
contextErrors = {}, // Mapping of contexts to error messages
@@ -134,10 +197,14 @@ export function InfrastructureSection({
{safeContexts.length === 1
? isSSH
? 'pool'
: 'context'
: isSlurm
? 'cluster'
: 'context'
: isSSH
? 'pools'
: 'contexts'}
: isSlurm
? 'clusters'
: 'contexts'}
</span>
</div>
{actionButton}
@@ -149,7 +216,7 @@ export function InfrastructureSection({
<thead className="bg-gray-50">
<tr>
<th className="p-3 text-left font-medium text-gray-600 w-1/4">
{isSSH ? 'Node Pool' : 'Context'}
Name
</th>
<th className="p-3 text-left font-medium text-gray-600 w-1/8">
Clusters
@@ -360,27 +427,20 @@ export function InfrastructureSection({
className={`bg-white divide-y divide-gray-200 ${gpus.length > 5 ? 'max-h-[250px] overflow-y-auto block' : ''}`}
>
{gpus.map((gpu) => {
const usedGpus = gpu.gpu_total - gpu.gpu_free;
const freePercentage =
gpu.gpu_total > 0
? (gpu.gpu_free / gpu.gpu_total) * 100
: 0;
const usedPercentage =
gpu.gpu_total > 0
? (usedGpus / gpu.gpu_total) * 100
: 0;

// Find the requestable quantities from contexts
const requestableQtys = groupedPerContextGPUs
? Object.values(groupedPerContextGPUs)
.flat()
.filter(
(g) =>
g.gpu_name === gpu.gpu_name &&
(isSSH
? g.context.startsWith('ssh-')
: !g.context.startsWith('ssh-'))
)
.filter((g) => {
if (g.gpu_name !== gpu.gpu_name) return false;
if (isSlurm) return true; // For Slurm, include all
// For Kubernetes/SSH, filter by context type
const contextKey = g.context || g.cluster;
if (!contextKey) return false;
return isSSH
? contextKey.startsWith('ssh-')
: !contextKey.startsWith('ssh-');
})
.map((g) => g.gpu_requestable_qty_per_node)
.filter((qty, i, arr) => arr.indexOf(qty) === i) // Unique values
.join(', ')
@@ -396,26 +456,11 @@ export function InfrastructureSection({
</td>
<td className="p-3 w-2/3">
<div className="flex items-center gap-3">
<div className="flex-1 bg-gray-100 rounded-md h-5 flex overflow-hidden shadow-sm min-w-[100px] w-full">
{usedPercentage > 0 && (
<div
style={{ width: `${usedPercentage}%` }}
className="bg-yellow-500 h-full flex items-center justify-center text-white text-xs font-medium"
>
{usedPercentage > 15 &&
`${usedGpus} used`}
</div>
)}
{freePercentage > 0 && (
<div
style={{ width: `${freePercentage}%` }}
className="bg-green-700 h-full flex items-center justify-center text-white text-xs font-medium"
>
{freePercentage > 15 &&
`${gpu.gpu_free} free`}
</div>
)}
</div>
<GpuUtilizationBar
gpu={gpu}
heightClass="h-5"
wrapperClassName="flex-1 min-w-[100px] w-full"
/>
</div>
</td>
</tr>
@@ -571,12 +616,6 @@ export function ContextDetails({ contextName, gpusInContext, nodesInContext }) {
</div>
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-4 mb-6">
{gpusInContext.map((gpu) => {
const usedGpus = gpu.gpu_total - gpu.gpu_free;
const freePercentage =
gpu.gpu_total > 0 ? (gpu.gpu_free / gpu.gpu_total) * 100 : 0;
const usedPercentage =
gpu.gpu_total > 0 ? (usedGpus / gpu.gpu_total) * 100 : 0;

return (
<div
key={gpu.gpu_name}
@@ -593,23 +632,12 @@ export function ContextDetails({ contextName, gpusInContext, nodesInContext }) {
{gpu.gpu_free} free / {gpu.gpu_total} total
</span>
</div>
<div className="w-full bg-gray-100 rounded-md h-4 flex overflow-hidden shadow-sm">
{usedPercentage > 0 && (
<div
style={{ width: `${usedPercentage}%` }}
className="bg-yellow-500 h-full flex items-center justify-center text-white text-xs"
>
{usedPercentage > 15 && `${usedGpus} used`}
</div>
)}
{freePercentage > 0 && (
<div
style={{ width: `${freePercentage}%` }}
className="bg-green-700 h-full flex items-center justify-center text-white text-xs"
>
{freePercentage > 15 && `${gpu.gpu_free} free`}
</div>
)}
<div className="w-full">
<GpuUtilizationBar
gpu={gpu}
heightClass="h-4"
wrapperClassName="w-full"
/>
</div>
</div>
);
@@ -653,7 +681,9 @@ export function ContextDetails({ contextName, gpusInContext, nodesInContext }) {
{node.gpu_name}
</td>
<td className="p-3 whitespace-nowrap text-right text-gray-700">
{`${node.gpu_free} of ${node.gpu_total} free`}
{node.is_ready === false
? `0 of ${node.gpu_total} free (Node NotReady)`
: `${node.gpu_free} of ${node.gpu_total} free`}
</td>
</tr>
))}
@@ -1615,6 +1645,9 @@ export function GPUs() {
const [allGPUs, setAllGPUs] = useState([]);
const [perContextGPUs, setPerContextGPUs] = useState([]);
const [perNodeGPUs, setPerNodeGPUs] = useState([]);
const [allSlurmGPUs, setAllSlurmGPUs] = useState([]);
const [perClusterSlurmGPUs, setPerClusterSlurmGPUs] = useState([]);
const [perNodeSlurmGPUs, setPerNodeSlurmGPUs] = useState([]);
const [cloudInfraData, setCloudInfraData] = useState([]);
const [totalClouds, setTotalClouds] = useState(0);
const [enabledClouds, setEnabledClouds] = useState(0);
@@ -1719,6 +1752,9 @@ export function GPUs() {
allGPUs: fetchedAllGPUs,
perContextGPUs: fetchedPerContextGPUs,
perNodeGPUs: fetchedPerNodeGPUs,
allSlurmGPUs: fetchedAllSlurmGPUs,
perClusterSlurmGPUs: fetchedPerClusterSlurmGPUs,
perNodeSlurmGPUs: fetchedPerNodeSlurmGPUs,
contextStats: fetchedContextStats,
contextWorkspaceMap: fetchedContextWorkspaceMap,
contextErrors: fetchedContextErrors,
@@ -1729,6 +1765,9 @@ export function GPUs() {
setAllGPUs(fetchedAllGPUs || []);
setPerContextGPUs(fetchedPerContextGPUs || []);
setPerNodeGPUs(fetchedPerNodeGPUs || []);
setAllSlurmGPUs(fetchedAllSlurmGPUs || []);
setPerClusterSlurmGPUs(fetchedPerClusterSlurmGPUs || []);
setPerNodeSlurmGPUs(fetchedPerNodeSlurmGPUs || []);
setContextStats(fetchedContextStats || {});
setContextWorkspaceMap(fetchedContextWorkspaceMap || {});
setContextErrors(fetchedContextErrors || {});
@@ -1747,6 +1786,9 @@ export function GPUs() {
setAllGPUs([]);
setPerContextGPUs([]);
setPerNodeGPUs([]);
setAllSlurmGPUs([]);
setPerClusterSlurmGPUs([]);
setPerNodeSlurmGPUs([]);
setContextStats({});
setContextWorkspaceMap({});
setContextErrors({});
@@ -1761,6 +1803,9 @@ export function GPUs() {
setAllGPUs([]);
setPerContextGPUs([]);
setPerNodeGPUs([]);
setAllSlurmGPUs([]);
setPerClusterSlurmGPUs([]);
setPerNodeSlurmGPUs([]);
setContextStats({});
setContextWorkspaceMap({});
setContextErrors({});
@@ -2112,6 +2157,43 @@ export function GPUs() {
return allGPUs.filter((gpu) => kubeGpuNames.has(gpu.gpu_name));
}, [allGPUs, perContextGPUs]);

// Extract Slurm cluster names from perClusterSlurmGPUs
const slurmClusters = React.useMemo(() => {
if (!perClusterSlurmGPUs || !Array.isArray(perClusterSlurmGPUs)) {
return [];
}
const clusters = [
...new Set(perClusterSlurmGPUs.map((gpu) => gpu.cluster)),
];
return clusters.sort();
}, [perClusterSlurmGPUs]);

// Group perClusterSlurmGPUs by cluster
const groupedPerClusterSlurmGPUs = React.useMemo(() => {
if (!perClusterSlurmGPUs) return {};
return perClusterSlurmGPUs.reduce((acc, gpu) => {
const { cluster } = gpu;
if (!acc[cluster]) {
acc[cluster] = [];
}
acc[cluster].push(gpu);
return acc;
}, {});
}, [perClusterSlurmGPUs]);

// Group perNodeSlurmGPUs by cluster
const groupedPerNodeSlurmGPUs = React.useMemo(() => {
if (!perNodeSlurmGPUs) return {};
return perNodeSlurmGPUs.reduce((acc, node) => {
const { cluster } = node;
if (!acc[cluster]) {
acc[cluster] = [];
}
acc[cluster].push(node);
return acc;
}, {});
}, [perNodeSlurmGPUs]);

// Group perNodeGPUs by context
const groupedPerNodeGPUs = React.useMemo(() => {
if (!perNodeGPUs) return {};
@@ -2354,6 +2436,27 @@ export function GPUs() {
);
};

const renderSlurmInfrastructure = () => {
return (
<InfrastructureSection
title="Slurm"
isLoading={kubeLoading}
isDataLoaded={kubeDataLoaded}
contexts={slurmClusters}
gpus={allSlurmGPUs}
groupedPerContextGPUs={groupedPerClusterSlurmGPUs}
groupedPerNodeGPUs={groupedPerNodeSlurmGPUs}
handleContextClick={handleContextClick}
contextStats={{}}
jobsData={{}}
isJobsDataLoading={false}
isSSH={false}
isSlurm={true}
contextWorkspaceMap={{}}
/>
);
};

const renderKubernetesTab = () => {
// If a context is selected, show its details instead of the summary
if (selectedContext) {
@@ -2383,7 +2486,16 @@ export function GPUs() {
});
};

// Always add all three sections (they handle their own loading/empty states)
// Always add all sections (they handle their own loading/empty states)

// Add Slurm section (always show) - Priority 1 to show at top
const slurmHasActivity = slurmClusters.length > 0;
sections.push({
name: 'Slurm',
render: renderSlurmInfrastructure,
hasActivity: slurmHasActivity,
priority: 1, // Slurm gets priority 1 within same activity level
});

// Add Kubernetes section (always show)
// Kubernetes section is active if there are any contexts available (similar to Cloud logic)
@@ -2392,7 +2504,7 @@ export function GPUs() {
name: 'Kubernetes',
render: renderKubernetesInfrastructure,
hasActivity: kubeHasActivity,
priority: 1, // Kubernetes gets priority 1 within same activity level
priority: 2, // Kubernetes gets priority 2 within same activity level
});

// Add Cloud section (always show)
@@ -2402,7 +2514,7 @@ export function GPUs() {
name: 'Cloud',
render: renderCloudInfrastructure,
hasActivity: cloudHasActivity,
priority: 2, // Cloud gets priority 2 within same activity level
priority: 3, // Cloud gets priority 3 within same activity level
});

// Add SSH section (always show)
@@ -2412,7 +2524,7 @@ export function GPUs() {
name: 'SSH Node Pool',
render: renderSSHNodePoolInfrastructure,
hasActivity: sshHasActivity,
priority: 3, // SSH gets priority 3 within same activity level
priority: 4, // SSH gets priority 4 within same activity level
});

// Dynamic sorting: enabled/active sections move to front automatically


+ 214
- 0
sky/dashboard/src/data/connectors/infra.jsx View File

@@ -205,6 +205,9 @@ export async function getWorkspaceInfrastructure() {
allGPUs: [],
perContextGPUs: [],
perNodeGPUs: [],
allSlurmGPUs: [],
perClusterSlurmGPUs: [],
perNodeSlurmGPUs: [],
contextStats: {},
contextWorkspaceMap: {},
contextErrors: {},
@@ -334,12 +337,28 @@ export async function getWorkspaceInfrastructure() {
console.error('Error fetching Kubernetes GPUs:', error);
}

// Get Slurm GPU data
let slurmGpuData = {
allSlurmGPUs: [],
perClusterSlurmGPUs: [],
perNodeSlurmGPUs: [],
};
try {
slurmGpuData = await getSlurmServiceGPUs();
console.log('[DEBUG] Slurm GPU data in infra.jsx:', slurmGpuData);
} catch (error) {
console.error('Error fetching Slurm GPUs:', error);
}

const finalResult = {
workspaces: workspaceInfraData,
allContextNames: [...new Set(allContextsAcrossWorkspaces)].sort(),
allGPUs: gpuData.allGPUs || [],
perContextGPUs: gpuData.perContextGPUs || [],
perNodeGPUs: gpuData.perNodeGPUs || [],
allSlurmGPUs: slurmGpuData.allSlurmGPUs || [],
perClusterSlurmGPUs: slurmGpuData.perClusterSlurmGPUs || [],
perNodeSlurmGPUs: slurmGpuData.perNodeSlurmGPUs || [],
contextStats: contextStats,
contextWorkspaceMap: contextWorkspaceMap,
contextErrors: gpuData.contextErrors || {},
@@ -422,6 +441,8 @@ async function getKubernetesGPUsFromContexts(contextNames) {
const gpuName = nodeData['accelerator_type'] || '-';
const totalCount = nodeData['total']?.['accelerator_count'] || 0;
const freeCount = nodeData['free']?.['accelerators_available'] || 0;
// Check if node is ready (defaults to true for backward compatibility)
const isReady = nodeData['is_ready'] !== false;

if (totalCount > 0) {
if (!gpuToData[gpuName]) {
@@ -430,11 +451,15 @@ async function getKubernetesGPUsFromContexts(contextNames) {
gpu_requestable_qty_per_node: 0,
gpu_total: 0,
gpu_free: 0,
gpu_not_ready: 0,
context: context,
};
}
gpuToData[gpuName].gpu_total += totalCount;
gpuToData[gpuName].gpu_free += freeCount;
if (isReady === false) {
gpuToData[gpuName].gpu_not_ready += totalCount;
}
gpuToData[gpuName].gpu_requestable_qty_per_node = totalCount;
}
}
@@ -443,10 +468,13 @@ async function getKubernetesGPUsFromContexts(contextNames) {
if (gpuName in allGPUsSummary) {
allGPUsSummary[gpuName].gpu_total += gpuToData[gpuName].gpu_total;
allGPUsSummary[gpuName].gpu_free += gpuToData[gpuName].gpu_free;
allGPUsSummary[gpuName].gpu_not_ready +=
gpuToData[gpuName].gpu_not_ready;
} else {
allGPUsSummary[gpuName] = {
gpu_total: gpuToData[gpuName].gpu_total,
gpu_free: gpuToData[gpuName].gpu_free,
gpu_not_ready: gpuToData[gpuName].gpu_not_ready,
gpu_name: gpuName,
};
}
@@ -476,6 +504,8 @@ async function getKubernetesGPUsFromContexts(contextNames) {
nodeData['total']?.['accelerator_count'] ?? 0;
const freeAccelerators =
nodeData['free']?.['accelerators_available'] ?? 0;
// Check if node is ready (defaults to true for backward compatibility)
const nodeIsReady = nodeData['is_ready'] !== false;

perNodeGPUs_dict[`${context}/${nodeName}`] = {
node_name: nodeData['name'] || nodeName,
@@ -484,6 +514,7 @@ async function getKubernetesGPUsFromContexts(contextNames) {
gpu_free: freeAccelerators,
ip_address: nodeData['ip_address'] || null,
context: context,
is_ready: nodeIsReady,
};

// If this node provides a GPU type not found via GPU availability,
@@ -499,6 +530,7 @@ async function getKubernetesGPUsFromContexts(contextNames) {
allGPUsSummary[acceleratorType] = {
gpu_total: 0,
gpu_free: 0,
gpu_not_ready: 0,
gpu_name: acceleratorType,
};
}
@@ -508,6 +540,7 @@ async function getKubernetesGPUsFromContexts(contextNames) {
if (!existingGpuEntry) {
perContextGPUsData[context].push({
gpu_name: acceleratorType,
gpu_not_ready: 0,
gpu_requestable_qty_per_node: '-',
gpu_total: 0,
gpu_free: 0,
@@ -921,3 +954,184 @@ export async function getDetailedGpuInfo(filter) {
throw error;
}
}

async function getSlurmClusterGPUs() {
try {
const response = await apiClient.post(`/slurm_gpu_availability`, {});
if (!response.ok) {
const msg = `Failed to get slurm cluster GPUs with status ${response.status}`;
throw new Error(msg);
}
const id =
response.headers.get('X-Skypilot-Request-ID') ||
response.headers.get('x-request-id');
if (!id) {
const msg = 'No request ID received from server for slurm cluster GPUs';
throw new Error(msg);
}
const fetchedData = await apiClient.get(`/api/get?request_id=${id}`);
if (fetchedData.status === 500) {
try {
const data = await fetchedData.json();
if (data.detail && data.detail.error) {
try {
const error = JSON.parse(data.detail.error);
console.error('Error fetching Slurm cluster GPUs:', error.message);
} catch (jsonError) {
console.error('Error parsing JSON for Slurm error:', jsonError);
}
}
} catch (parseError) {
console.error('Error parsing JSON for Slurm 500 response:', parseError);
}
return [];
}
if (!fetchedData.ok) {
const msg = `Failed to get slurm cluster GPUs result with status ${fetchedData.status}`;
throw new Error(msg);
}
const data = await fetchedData.json();
const clusterGPUs = data.return_value ? JSON.parse(data.return_value) : [];
return clusterGPUs;
} catch (error) {
console.error('Error fetching Slurm cluster GPUs:', error);
return [];
}
}

async function getSlurmPerNodeGPUs() {
try {
const response = await apiClient.get(`/slurm_node_info`);
if (!response.ok) {
const msg = `Failed to get slurm node info with status ${response.status}`;
throw new Error(msg);
}
const id =
response.headers.get('X-Skypilot-Request-ID') ||
response.headers.get('x-request-id');
if (!id) {
const msg = 'No request ID received from server for slurm node info';
throw new Error(msg);
}
const fetchedData = await apiClient.get(`/api/get?request_id=${id}`);
if (fetchedData.status === 500) {
try {
const data = await fetchedData.json();
if (data.detail && data.detail.error) {
try {
const error = JSON.parse(data.detail.error);
console.error('Error fetching Slurm per node GPUs:', error.message);
} catch (jsonError) {
console.error(
'Error parsing JSON for Slurm node error:',
jsonError
);
}
}
} catch (parseError) {
console.error(
'Error parsing JSON for Slurm node 500 response:',
parseError
);
}
return [];
}
if (!fetchedData.ok) {
const msg = `Failed to get slurm node info result with status ${fetchedData.status}`;
throw new Error(msg);
}
const data = await fetchedData.json();
const nodeInfo = data.return_value ? JSON.parse(data.return_value) : [];
return nodeInfo;
} catch (error) {
console.error('Error fetching Slurm per node GPUs:', error);
return [];
}
}

async function getSlurmServiceGPUs() {
try {
const clusterGPUsRaw = await getSlurmClusterGPUs();
const nodeGPUsRaw = await getSlurmPerNodeGPUs();

const allSlurmGPUs = {};
const perClusterSlurmGPUs = {}; // Similar to perContextGPUs for Kubernetes
const perNodeSlurmGPUs = {}; // { 'cluster/node_name': { ... } }

// Process cluster GPUs (similar to Kubernetes context GPUs)
// clusterGPUsRaw is expected to be like: [ [cluster_name, [ [gpu_name, counts, capacity, available], ... ] ], ... ]
for (const clusterData of clusterGPUsRaw) {
const clusterName = clusterData[0];
const gpusInCluster = clusterData[1];

for (const gpuRaw of gpusInCluster) {
const gpuName = gpuRaw[0];
// gpuRaw[1] is counts (list of requestable quantities), e.g., [1, 2, 4]
const gpuRequestableQtyPerNode = gpuRaw[1].join(', ');
const gpuTotal = gpuRaw[2]; // capacity
const gpuFree = gpuRaw[3]; // available

// Aggregate for allSlurmGPUs
if (gpuName in allSlurmGPUs) {
allSlurmGPUs[gpuName].gpu_total += gpuTotal;
allSlurmGPUs[gpuName].gpu_free += gpuFree;
} else {
allSlurmGPUs[gpuName] = {
gpu_total: gpuTotal,
gpu_free: gpuFree,
gpu_name: gpuName,
};
}

// Store for perClusterSlurmGPUs (similar to perContextGPUs)
const clusterGpuKey = `${clusterName}#${gpuName}`; // Unique key for cluster-gpu combo
perClusterSlurmGPUs[clusterGpuKey] = {
gpu_name: gpuName,
gpu_requestable_qty_per_node: gpuRequestableQtyPerNode,
gpu_total: gpuTotal,
gpu_free: gpuFree,
cluster: clusterName,
};
}
}

// Process node GPUs
// nodeGPUsRaw is expected to be like: [ {node_name, slurm_cluster_name, partition, gpu_type, total_gpus, free_gpus}, ... ]
for (const node of nodeGPUsRaw) {
const clusterName = node.slurm_cluster_name || 'default';
const key = `${clusterName}/${node.node_name}/${node.gpu_type || '-'}`;
perNodeSlurmGPUs[key] = {
node_name: node.node_name,
gpu_name: node.gpu_type || '-', // gpu_type might be null
gpu_total: node.total_gpus || 0,
gpu_free: node.free_gpus || 0,
cluster: clusterName,
partition: node.partition || 'default', // partition might be null
};
}

return {
allSlurmGPUs: Object.values(allSlurmGPUs).sort((a, b) =>
a.gpu_name.localeCompare(b.gpu_name)
),
perClusterSlurmGPUs: Object.values(perClusterSlurmGPUs).sort(
(a, b) =>
a.cluster.localeCompare(b.cluster) ||
a.gpu_name.localeCompare(b.gpu_name)
),
perNodeSlurmGPUs: Object.values(perNodeSlurmGPUs).sort(
(a, b) =>
(a.cluster || '').localeCompare(b.cluster || '') ||
(a.node_name || '').localeCompare(b.node_name || '') ||
(a.gpu_name || '').localeCompare(b.gpu_name || '')
),
};
} catch (error) {
console.error('Error fetching Slurm GPUs:', error);
return {
allSlurmGPUs: [],
perClusterSlurmGPUs: [],
perNodeSlurmGPUs: [],
};
}
}

+ 15
- 5
sky/dashboard/src/pages/_app.js View File

@@ -1,18 +1,26 @@
'use client';

import React from 'react';
import ReactDOM from 'react-dom/client';
import dynamic from 'next/dynamic';
import PropTypes from 'prop-types';
import '@/app/globals.css';
import { useEffect } from 'react';
import { BASE_PATH } from '@/data/connectors/constants';
import { TourProvider } from '@/hooks/useTour';
import { PluginProvider } from '@/plugins/PluginProvider';

const Layout = dynamic(
() => import('@/components/elements/layout').then((mod) => mod.Layout),
{ ssr: false }
);

// Expose React and ReactDOM to window for plugins to use
if (typeof window !== 'undefined') {
window.React = React;
window.ReactDOM = ReactDOM;
}

function App({ Component, pageProps }) {
useEffect(() => {
const link = document.createElement('link');
@@ -22,11 +30,13 @@ function App({ Component, pageProps }) {
}, []);

return (
<TourProvider>
<Layout highlighted={pageProps.highlighted}>
<Component {...pageProps} />
</Layout>
</TourProvider>
<PluginProvider>
<TourProvider>
<Layout highlighted={pageProps.highlighted}>
<Component {...pageProps} />
</Layout>
</TourProvider>
</PluginProvider>
);
}



+ 139
- 0
sky/dashboard/src/pages/plugins/[...slug].js View File

@@ -0,0 +1,139 @@
import React, { useEffect, useRef, useState } from 'react';
import Head from 'next/head';
import { useRouter } from 'next/router';
import { CircularProgress } from '@mui/material';
import { usePluginRoute } from '@/plugins/PluginProvider';

function normalizeSlug(slug) {
if (!slug) {
return null;
}
const segments = Array.isArray(slug) ? slug : [slug];
const filtered = segments.filter(Boolean);
if (!filtered.length) {
return null;
}
return `/plugins/${filtered.join('/')}`;
}

function stripBasePath(pathname = '', basePath = '') {
if (!basePath) {
return pathname;
}
if (pathname === basePath) {
return '/';
}
if (pathname.startsWith(basePath)) {
const stripped = pathname.slice(basePath.length);
return stripped.startsWith('/') ? stripped : `/${stripped}`;
}
return pathname;
}

function derivePathname(router) {
const slugPath = normalizeSlug(router?.query?.slug);
if (slugPath) {
return slugPath;
}
const asPath = router?.asPath;
if (!asPath || typeof asPath !== 'string') {
return null;
}
const withoutQuery = asPath.split('?')[0];
const normalized = stripBasePath(withoutQuery, router?.basePath || '');
if (normalized && normalized.startsWith('/plugins')) {
return normalized;
}
return null;
}

export default function PluginRoutePage() {
const router = useRouter();
const containerRef = useRef(null);
const [mountError, setMountError] = useState(null);
const pathname = derivePathname(router);
const route = usePluginRoute(pathname);

useEffect(() => {
const container = containerRef.current;
if (!route || !container) {
return undefined;
}
setMountError(null);

let cleanup;
try {
cleanup = route.mount({
container,
route,
});
} catch (error) {
console.error(
'[SkyDashboardPlugin] Failed to mount plugin route:',
route.id,
error
);
setMountError(
'Failed to render the plugin page. Check the browser console for details.'
);
}

return () => {
if (typeof cleanup === 'function') {
try {
cleanup();
} catch (error) {
console.warn(
'[SkyDashboardPlugin] Error during plugin route cleanup:',
error
);
}
} else if (route.unmount) {
try {
route.unmount({ container, route });
} catch (error) {
console.warn(
'[SkyDashboardPlugin] Error during plugin unmount:',
error
);
}
}
if (container) {
container.innerHTML = '';
}
};
}, [route, pathname]);

const title = route?.title
? `${route.title} | SkyPilot Dashboard`
: 'Plugin | SkyPilot Dashboard';

return (
<>
<Head>
<title>{title}</title>
</Head>
<div className="min-h-[50vh]">
{mountError ? (
<div className="max-w-3xl mx-auto p-6 bg-red-50 text-red-700 rounded-lg border border-red-200">
{mountError}
</div>
) : (
<>
{!route && (
<div className="flex justify-center items-center h-64">
<CircularProgress size={20} />
<span className="ml-2 text-gray-500">
{router.isReady
? 'Loading plugin resources...'
: 'Preparing plugin route...'}
</span>
</div>
)}
<div ref={containerRef} />
</>
)}
</div>
</>
);
}

+ 345
- 0
sky/dashboard/src/plugins/PluginProvider.jsx View File

@@ -0,0 +1,345 @@
'use client';

import React, {
createContext,
useContext,
useEffect,
useMemo,
useReducer,
} from 'react';
import { BASE_PATH, ENDPOINT } from '@/data/connectors/constants';
import { apiClient } from '@/data/connectors/client';

const PluginContext = createContext({
topNavLinks: [],
routes: [],
});

const initialState = {
topNavLinks: [],
routes: [],
};

const actions = {
REGISTER_TOP_NAV_LINK: 'REGISTER_TOP_NAV_LINK',
REGISTER_ROUTE: 'REGISTER_ROUTE',
};

function pluginReducer(state, action) {
switch (action.type) {
case actions.REGISTER_TOP_NAV_LINK:
return {
...state,
topNavLinks: upsertById(state.topNavLinks, action.payload),
};
case actions.REGISTER_ROUTE:
return {
...state,
routes: upsertById(state.routes, action.payload),
};
default:
return state;
}
}

function upsertById(collection, item) {
const index = collection.findIndex((entry) => entry.id === item.id);
if (index === -1) {
return [...collection, item];
}
const next = [...collection];
next[index] = item;
return next;
}

const pluginScriptPromises = new Map();

function resolveScriptUrl(jsPath) {
if (!jsPath || typeof jsPath !== 'string') {
return null;
}
if (/^https?:\/\//.test(jsPath)) {
return jsPath;
}
if (typeof window === 'undefined') {
return jsPath;
}
try {
return new URL(jsPath, window.location.origin).toString();
} catch (error) {
console.warn(
'[SkyDashboardPlugin] Failed to resolve plugin script path:',
jsPath,
error
);
return null;
}
}

function loadPluginScript(jsPath) {
if (typeof window === 'undefined') {
return null;
}
const resolved = resolveScriptUrl(jsPath);
if (!resolved) {
return null;
}
if (pluginScriptPromises.has(resolved)) {
return pluginScriptPromises.get(resolved);
}

console.log('Loading plugin script:', resolved);
const promise = new Promise((resolve) => {
const script = document.createElement('script');
script.type = 'text/javascript';
script.async = true;
script.src = resolved;
script.onload = () => resolve();
script.onerror = (error) => {
console.warn(
'[SkyDashboardPlugin] Failed to load plugin script:',
resolved,
error
);
resolve();
};
document.head.appendChild(script);
});

pluginScriptPromises.set(resolved, promise);
return promise;
}

async function fetchPluginManifest() {
try {
const response = await apiClient.get(`/api/plugins`);
if (!response.ok) {
console.warn(
'[SkyDashboardPlugin] Failed to fetch plugin manifest:',
response.status,
response.statusText
);
return [];
}
const payload = await response.json();
if (!payload || !Array.isArray(payload.plugins)) {
return [];
}
console.log('Plugin manifest:', payload.plugins);
return payload.plugins;
} catch (error) {
console.warn('[SkyDashboardPlugin] Error fetching plugin manifest:', error);
return [];
}
}

function extractJsPath(pluginDescriptor) {
if (!pluginDescriptor || typeof pluginDescriptor !== 'object') {
return null;
}
if (pluginDescriptor.js_extension_path) {
console.log(
'Extracting JS extension path:',
pluginDescriptor.js_extension_path
);
return pluginDescriptor.js_extension_path;
}
return null;
}

function normalizeNavLink(link) {
if (!link || !link.id || !link.label || !link.href) {
console.warn(
'[SkyDashboardPlugin] Invalid top nav link registration:',
link
);
return null;
}

const normalized = {
id: String(link.id),
label: String(link.label),
href: String(link.href),
order: Number.isFinite(link.order) ? link.order : 0,
group: link.group ? String(link.group) : null,
target: link.target === '_blank' ? '_blank' : '_self',
rel:
link.rel ??
(link.target === '_blank' || /^https?:\/\//.test(String(link.href))
? 'noopener noreferrer'
: undefined),
external:
link.external ??
(/^(https?:)?\/\//.test(String(link.href)) || link.target === '_blank'),
badge: typeof link.badge === 'string' ? link.badge : null,
icon: typeof link.icon === 'string' ? link.icon : null,
description:
typeof link.description === 'string' ? link.description : undefined,
};

return normalized;
}

function normalizeRoute(route) {
if (
!route ||
typeof route !== 'object' ||
!route.id ||
!route.path ||
typeof route.mount !== 'function'
) {
console.warn('[SkyDashboardPlugin] Invalid route registration:', route);
return null;
}

const normalizedPath = String(route.path);
const pathname = normalizedPath.startsWith('/')
? normalizedPath
: `/${normalizedPath}`;

return {
id: String(route.id),
path: pathname,
title: typeof route.title === 'string' ? route.title : undefined,
description:
typeof route.description === 'string' ? route.description : undefined,
mount: route.mount,
unmount: typeof route.unmount === 'function' ? route.unmount : undefined,
context:
route.context && typeof route.context === 'object'
? route.context
: undefined,
};
}

function createPluginApi(dispatch) {
return {
registerTopNavLink(link) {
const normalized = normalizeNavLink(link);
if (!normalized) {
return null;
}
dispatch({
type: actions.REGISTER_TOP_NAV_LINK,
payload: normalized,
});
return normalized.id;
},
registerRoute(route) {
const normalized = normalizeRoute(route);
if (!normalized) {
return null;
}
dispatch({
type: actions.REGISTER_ROUTE,
payload: normalized,
});
return normalized.id;
},
getContext() {
return {
basePath: BASE_PATH,
apiEndpoint: ENDPOINT,
};
},
};
}

export function PluginProvider({ children }) {
const [state, dispatch] = useReducer(pluginReducer, initialState);

useEffect(() => {
if (typeof window === 'undefined') {
return;
}

let cancelled = false;
const api = createPluginApi(dispatch);
window.SkyDashboardPluginAPI = api;
window.dispatchEvent(
new CustomEvent('skydashboard:plugins-ready', { detail: api })
);
const bootstrapPlugins = async () => {
const manifest = await fetchPluginManifest();
if (cancelled) {
return;
}
manifest
.map((pluginDescriptor) => extractJsPath(pluginDescriptor))
.filter(Boolean)
.forEach((jsPath) => {
if (!cancelled) {
loadPluginScript(jsPath);
}
});
};
void bootstrapPlugins();

return () => {
cancelled = true;
if (window.SkyDashboardPluginAPI === api) {
delete window.SkyDashboardPluginAPI;
}
};
}, []);

const value = useMemo(() => state, [state]);

return (
<PluginContext.Provider value={value}>{children}</PluginContext.Provider>
);
}

export function usePluginState() {
return useContext(PluginContext);
}

export function useTopNavLinks() {
const { topNavLinks } = usePluginState();
return useMemo(
() =>
[...topNavLinks].sort((a, b) => {
return a.order - b.order;
}),
[topNavLinks]
);
}

export function useGroupedNavLinks() {
const { topNavLinks } = usePluginState();

return useMemo(() => {
const sorted = [...topNavLinks].sort((a, b) => a.order - b.order);

// Separate links with and without group
const ungrouped = sorted.filter((link) => !link.group);
const grouped = sorted.filter((link) => link.group);

// Categorize by group
const groups = grouped.reduce((acc, link) => {
const groupName = link.group;
if (!acc[groupName]) {
acc[groupName] = [];
}
acc[groupName].push(link);
return acc;
}, {});

return { ungrouped, groups };
}, [topNavLinks]);
}

export function usePluginRoutes() {
const { routes } = usePluginState();
return routes;
}

export function usePluginRoute(pathname) {
const routes = usePluginRoutes();
return useMemo(() => {
if (!pathname) {
return null;
}
return routes.find((route) => route.path === pathname) || null;
}, [pathname, routes]);
}

+ 16
- 2
sky/data/mounting_utils.py View File

@@ -223,7 +223,10 @@ def get_gcs_mount_cmd(bucket_name: str,
"""Returns a command to mount a GCS bucket using gcsfuse."""
bucket_sub_path_arg = f'--only-dir {_bucket_sub_path} '\
if _bucket_sub_path else ''
mount_cmd = ('gcsfuse -o allow_other '
log_file = '$(mktemp -t gcsfuse.XXXX.log)'
mount_cmd = (f'gcsfuse --log-file {log_file} '
'--debug_fuse_errors '
'-o allow_other '
'--implicit-dirs '
f'--stat-cache-capacity {_STAT_CACHE_CAPACITY} '
f'--stat-cache-ttl {_STAT_CACHE_TTL} '
@@ -646,8 +649,19 @@ def get_mounting_script(
else
echo "No goofys log file found in /tmp"
fi
elif [ "$MOUNT_BINARY" = "gcsfuse" ]; then
echo "Looking for gcsfuse log files..."
# Find gcsfuse log files in /tmp (created by mktemp -t gcsfuse.XXXX.log)
GCSFUSE_LOGS=$(ls -t /tmp/gcsfuse.*.log 2>/dev/null | head -1)
if [ -n "$GCSFUSE_LOGS" ]; then
echo "=== GCSFuse log file contents ==="
cat "$GCSFUSE_LOGS"
echo "=== End of gcsfuse log file ==="
else
echo "No gcsfuse log file found in /tmp"
fi
fi
# TODO(kevin): Print logs from rclone, etc too for observability.
# TODO(kevin): Print logs from rclone, blobfuse2, etc too for observability.
exit $MOUNT_EXIT_CODE
fi
echo "Mounting done."


+ 3
- 3
sky/global_user_state.py View File

@@ -2241,7 +2241,7 @@ def get_volumes(is_ephemeral: Optional[bool] = None) -> List[Dict[str, Any]]:
rows = session.query(volume_table).all()
else:
rows = session.query(volume_table).filter_by(
is_ephemeral=is_ephemeral).all()
is_ephemeral=int(is_ephemeral)).all()
records = []
for row in rows:
records.append({
@@ -2253,7 +2253,7 @@ def get_volumes(is_ephemeral: Optional[bool] = None) -> List[Dict[str, Any]]:
'last_attached_at': row.last_attached_at,
'last_use': row.last_use,
'status': status_lib.VolumeStatus[row.status],
'is_ephemeral': row.is_ephemeral,
'is_ephemeral': bool(row.is_ephemeral),
})
return records

@@ -2316,7 +2316,7 @@ def add_volume(
last_attached_at=last_attached_at,
last_use=last_use,
status=status.value,
is_ephemeral=is_ephemeral,
is_ephemeral=int(is_ephemeral),
)
do_update_stmt = insert_stmnt.on_conflict_do_nothing()
session.execute(do_update_stmt)
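
The int()/bool() casts above suggest the is_ephemeral column is stored as an integer rather than a native boolean, so filters and round-trips go through explicit conversions. A small standalone illustration of that round-trip with SQLite (SkyPilot's default local store), independent of the actual schema:

import sqlite3

conn = sqlite3.connect(':memory:')
conn.execute('CREATE TABLE volumes (name TEXT, is_ephemeral INTEGER)')
conn.execute('INSERT INTO volumes VALUES (?, ?)', ('v1', int(True)))
row = conn.execute(
    'SELECT is_ephemeral FROM volumes WHERE is_ephemeral = ?',
    (int(True),)).fetchone()
print(bool(row[0]))  # True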


+ 2
- 0
sky/models.py View File

@@ -68,6 +68,8 @@ class KubernetesNodeInfo:
free: Dict[str, int]
# IP address of the node (external IP preferred, fallback to internal IP)
ip_address: Optional[str] = None
# Whether the node is ready (its 'Ready' condition reports 'True')
is_ready: bool = True


@dataclasses.dataclass


+ 6
- 5
sky/optimizer.py View File

@@ -781,7 +781,7 @@ class Optimizer:
def _instance_type_str(resources: 'resources_lib.Resources') -> str:
instance_type = resources.instance_type
assert instance_type is not None, 'Instance type must be specified'
if isinstance(resources.cloud, clouds.Kubernetes):
if isinstance(resources.cloud, (clouds.Kubernetes, clouds.Slurm)):
instance_type = '-'
if resources.use_spot:
instance_type = ''
@@ -865,11 +865,12 @@ class Optimizer:
'use_spot': resources.use_spot
}

# Handle special case for Kubernetes and SSH clouds
if isinstance(resources.cloud, clouds.Kubernetes):
# Handle special case for Kubernetes, SSH, and SLURM clouds
if isinstance(resources.cloud, (clouds.Kubernetes, clouds.Slurm)):
# Region for Kubernetes-like clouds (SSH, Kubernetes) is the
# context name, i.e. different Kubernetes clusters. We add
# region to the key to show all the Kubernetes clusters in the
# context name, i.e. different Kubernetes clusters.
# Region for SLURM is the cluster name.
# We add region to the key to show all the clusters in the
# optimizer table for better UX.

if resources.cloud.__class__.__name__ == 'SSH':


+ 1
- 0
sky/provision/__init__.py View File

@@ -29,6 +29,7 @@ from sky.provision import runpod
from sky.provision import scp
from sky.provision import seeweb
from sky.provision import shadeform
from sky.provision import slurm
from sky.provision import ssh
from sky.provision import vast
from sky.provision import vsphere


+ 20
- 0
sky/provision/common.py View File

@@ -6,6 +6,7 @@ import os
from typing import Any, Dict, List, Optional, Tuple

from sky import sky_logging
from sky.utils import config_utils
from sky.utils import env_options
from sky.utils import resources_utils

@@ -36,6 +37,13 @@ class StopFailoverError(Exception):
"""


# These fields are sensitive and should be redacted from the config for logging
# purposes.
SENSITIVE_FIELDS = [
('docker_config', 'docker_login_config', 'password'),
]


@dataclasses.dataclass
class ProvisionConfig:
"""Configuration for provisioning."""
@@ -56,6 +64,18 @@ class ProvisionConfig:
# Optional ports to open on launch of the cluster.
ports_to_open_on_launch: Optional[List[int]]

def get_redacted_config(self) -> Dict[str, Any]:
"""Get the redacted config."""
config = dataclasses.asdict(self)

config_copy = config_utils.Config(config)

for field_list in SENSITIVE_FIELDS:
val = config_copy.get_nested(field_list, default_value=None)
if val is not None:
config_copy.set_nested(field_list, '<redacted>')
return dict(**config_copy)
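
A dependency-free sketch of the nested-field redaction pattern used above, where each sensitive field is addressed by a tuple of keys; this is illustrative only, not the config_utils implementation:

import copy

SENSITIVE = [('docker_config', 'docker_login_config', 'password')]

def redact(config: dict) -> dict:
    redacted = copy.deepcopy(config)
    for *parents, leaf in SENSITIVE:
        node = redacted
        for key in parents:
            node = node.get(key) or {}
        if leaf in node:
            node[leaf] = '<redacted>'
    return redacted

print(redact({'docker_config': {'docker_login_config': {'password': 'x'}}}))
# -> {'docker_config': {'docker_login_config': {'password': '<redacted>'}}}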


# -------------------- output data model -------------------- #



+ 15
- 2
sky/provision/docker_utils.py View File

@@ -176,6 +176,17 @@ def _with_interactive(cmd):
return ['bash', '--login', '-c', '-i', shlex.quote(force_interactive)]


def _redact_docker_password(cmd: str) -> str:
parts = shlex.split(cmd)
for i, part in enumerate(parts):
if part.startswith('--password'):
if part.startswith('--password='):
parts[i] = '--password=<redacted>'
elif i + 1 < len(parts):
parts[i + 1] = '<redacted>'
return ' '.join(parts)
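
Illustrative only (the credentials and registry are made up): the helper above rewrites both forms of the password flag.

>>> _redact_docker_password('docker login -u me --password hunter2 registry.example.com')
'docker login -u me --password <redacted> registry.example.com'
>>> _redact_docker_password('docker login --password=hunter2 registry.example.com')
'docker login --password=<redacted> registry.example.com'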


# SkyPilot: New class to initialize docker containers on a remote node.
# Adopted from ray.autoscaler._private.command_runner.DockerCommandRunner.
class DockerInitializer:
@@ -219,7 +230,9 @@ class DockerInitializer:
cmd = (f'flock {flock_args} /tmp/{flock_name} '
f'-c {shlex.quote(cmd)}')

logger.debug(f'+ {cmd}')
# Redact the password in the login command.
redacted_cmd = _redact_docker_password(cmd)
logger.debug(f'+ {redacted_cmd}')
start = time.time()
while True:
rc, stdout, stderr = self.runner.run(
@@ -251,7 +264,7 @@ class DockerInitializer:
break
subprocess_utils.handle_returncode(
rc,
cmd,
redacted_cmd,
error_msg='Failed to run docker setup commands.',
stderr=stdout + stderr,
# Print out the error message if the command failed.


+ 42
- 6
sky/provision/kubernetes/utils.py View File

@@ -1205,15 +1205,24 @@ class V1NodeAddress:
address: str


@dataclasses.dataclass
class V1NodeCondition:
"""Represents a Kubernetes node condition."""
type: str
status: str


@dataclasses.dataclass
class V1NodeStatus:
allocatable: Dict[str, str]
capacity: Dict[str, str]
addresses: List[V1NodeAddress]
conditions: List[V1NodeCondition]


@dataclasses.dataclass
class V1Node:
"""Represents a Kubernetes node."""
metadata: V1ObjectMeta
status: V1NodeStatus

@@ -1231,8 +1240,24 @@ class V1Node:
V1NodeAddress(type=addr['type'],
address=addr['address'])
for addr in data['status'].get('addresses', [])
],
conditions=[
V1NodeCondition(type=cond['type'],
status=cond['status'])
for cond in data['status'].get('conditions', [])
]))

def is_ready(self) -> bool:
"""Check if the node is ready based on its conditions.

A node is considered ready if it has a 'Ready' condition with
status 'True'.
"""
for condition in self.status.conditions:
if condition.type == 'Ready':
return condition.status == 'True'
return False


@annotations.lru_cache(scope='request', maxsize=10)
@_retry_on_error(resource_type='node')
@@ -1451,11 +1476,12 @@ def check_instance_fits(context: Optional[str],
return False, str(e)
# Get the set of nodes that have the GPU type
gpu_nodes = [
node for node in nodes if gpu_label_key in node.metadata.labels and
node for node in nodes
if node.is_ready() and gpu_label_key in node.metadata.labels and
node.metadata.labels[gpu_label_key] in gpu_label_values
]
if not gpu_nodes:
return False, f'No GPU nodes found with {acc_type} on the cluster'
return False, f'No ready GPU nodes found with {acc_type} on the cluster'
if is_tpu_on_gke(acc_type):
# If requested accelerator is a TPU type, check if the cluster
# has sufficient TPU resource to meet the requirement.
@@ -1479,7 +1505,9 @@ def check_instance_fits(context: Optional[str],
f'enough CPU (> {k8s_instance_type.cpus} CPUs) and/or '
f'memory (> {k8s_instance_type.memory} G). ')
else:
candidate_nodes = nodes
candidate_nodes = [node for node in nodes if node.is_ready()]
if not candidate_nodes:
return False, 'No ready nodes found in the cluster.'
not_fit_reason_prefix = (f'No nodes found with enough '
f'CPU (> {k8s_instance_type.cpus} CPUs) '
'and/or memory '
@@ -3078,16 +3106,23 @@ def get_kubernetes_node_info(

accelerator_count = get_node_accelerator_count(context,
node.status.allocatable)
# Check if node is ready
node_is_ready = node.is_ready()

if accelerator_count == 0:
node_info_dict[node.metadata.name] = models.KubernetesNodeInfo(
name=node.metadata.name,
accelerator_type=accelerator_name,
total={'accelerator_count': 0},
free={'accelerators_available': 0},
ip_address=node_ip)
ip_address=node_ip,
is_ready=node_is_ready)
continue

if not has_accelerator_nodes or error_on_get_allocated_gpu_qty_by_node:
if not node_is_ready:
# If node is not ready, report 0 available GPUs
accelerators_available = 0
elif not has_accelerator_nodes or error_on_get_allocated_gpu_qty_by_node:
accelerators_available = -1
else:
allocated_qty = allocated_qty_by_node[node.metadata.name]
@@ -3105,7 +3140,8 @@ def get_kubernetes_node_info(
accelerator_type=accelerator_name,
total={'accelerator_count': int(accelerator_count)},
free={'accelerators_available': int(accelerators_available)},
ip_address=node_ip)
ip_address=node_ip,
is_ready=node_is_ready)
hint = ''
if has_multi_host_tpu:
hint = ('(Note: Multi-host TPUs are detected and excluded from the '
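
A minimal sketch (not part of this diff) of the readiness gating this change introduces: only nodes carrying a 'Ready' condition with status 'True' remain candidates. Plain dicts stand in for the parsed Kubernetes node objects; node names are illustrative.

from typing import Any, Dict, List


def node_is_ready(node: Dict[str, Any]) -> bool:
    """A node is ready iff it has a 'Ready' condition with status 'True'."""
    for cond in node['status'].get('conditions', []):
        if cond['type'] == 'Ready':
            return cond['status'] == 'True'
    return False


nodes: List[Dict[str, Any]] = [
    {'metadata': {'name': 'gpu-node-1'},
     'status': {'conditions': [{'type': 'Ready', 'status': 'True'}]}},
    {'metadata': {'name': 'gpu-node-2'},
     'status': {'conditions': [{'type': 'Ready', 'status': 'False'}]}},
]
candidate_nodes = [n for n in nodes if node_is_ready(n)]
print([n['metadata']['name'] for n in candidate_nodes])  # ['gpu-node-1']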


+ 15
- 6
sky/provision/provisioner.py View File

@@ -157,9 +157,9 @@ def bulk_provision(
logger.debug(f'SkyPilot version: {sky.__version__}; '
f'commit: {sky.__commit__}')
logger.debug(_TITLE.format('Provisioning'))
logger.debug(
'Provision config:\n'
f'{json.dumps(dataclasses.asdict(bootstrap_config), indent=2)}')
redacted_config = bootstrap_config.get_redacted_config()
logger.debug('Provision config:\n'
f'{json.dumps(redacted_config, indent=2)}')
return _bulk_provision(cloud, region, cluster_name,
bootstrap_config)
except exceptions.NoClusterLaunchedError:
@@ -635,10 +635,15 @@ def _post_provision_setup(
status.update(
runtime_preparation_str.format(step=3, step_name='runtime'))

skip_ray_setup = False
ray_port = constants.SKY_REMOTE_RAY_PORT
head_ray_needs_restart = True
ray_cluster_healthy = False
if (not provision_record.is_instance_just_booted(
if (launched_resources.cloud is not None and
not launched_resources.cloud.uses_ray()):
skip_ray_setup = True
logger.debug('Skip Ray cluster setup as cloud does not use Ray.')
elif (not provision_record.is_instance_just_booted(
head_instance.instance_id)):
# Check if head node Ray is alive
(ray_port, ray_cluster_healthy,
@@ -663,7 +668,9 @@ def _post_provision_setup(
'async setup to complete...')
time.sleep(1)

if head_ray_needs_restart:
if skip_ray_setup:
logger.debug('Skip Ray cluster setup on the head node.')
elif head_ray_needs_restart:
logger.debug('Starting Ray on the entire cluster.')
instance_setup.start_ray_on_head_node(
cluster_name.name_on_cloud,
@@ -686,7 +693,9 @@ def _post_provision_setup(
# We don't need to restart ray on worker nodes if the ray cluster is
# already healthy, i.e. the head node has expected number of nodes
# connected to the ray cluster.
if cluster_info.num_instances > 1 and not ray_cluster_healthy:
if skip_ray_setup:
logger.debug('Skip Ray cluster setup on the worker nodes.')
elif cluster_info.num_instances > 1 and not ray_cluster_healthy:
instance_setup.start_ray_on_worker_nodes(
cluster_name.name_on_cloud,
no_restart=not head_ray_needs_restart,


+ 12
- 0
sky/provision/slurm/__init__.py View File

@@ -0,0 +1,12 @@
"""Slurm provisioner for SkyPilot."""

from sky.provision.slurm.config import bootstrap_instances
from sky.provision.slurm.instance import cleanup_ports
from sky.provision.slurm.instance import get_cluster_info
from sky.provision.slurm.instance import get_command_runners
from sky.provision.slurm.instance import open_ports
from sky.provision.slurm.instance import query_instances
from sky.provision.slurm.instance import run_instances
from sky.provision.slurm.instance import stop_instances
from sky.provision.slurm.instance import terminate_instances
from sky.provision.slurm.instance import wait_instances

+ 13
- 0
sky/provision/slurm/config.py View File

@@ -0,0 +1,13 @@
"""Slrum-specific configuration for the provisioner."""
import logging

from sky.provision import common

logger = logging.getLogger(__name__)


def bootstrap_instances(
region: str, cluster_name: str,
config: common.ProvisionConfig) -> common.ProvisionConfig:
del region, cluster_name # unused
return config

+ 572
- 0
sky/provision/slurm/instance.py View File

@@ -0,0 +1,572 @@
"""Slurm instance provisioning."""

import tempfile
import textwrap
import time
from typing import Any, cast, Dict, List, Optional, Tuple

from sky import sky_logging
from sky import skypilot_config
from sky.adaptors import slurm
from sky.provision import common
from sky.provision import constants
from sky.provision.slurm import utils as slurm_utils
from sky.utils import command_runner
from sky.utils import common_utils
from sky.utils import status_lib
from sky.utils import subprocess_utils
from sky.utils import timeline

logger = sky_logging.init_logger(__name__)

# TODO(kevin): This assumes $HOME is in a shared filesystem.
# We should probably make it configurable, and add a check
# during sky check.
SHARED_ROOT_SKY_DIRECTORY = '~/.sky_clusters'
PROVISION_SCRIPTS_DIRECTORY_NAME = '.sky_provision'
PROVISION_SCRIPTS_DIRECTORY = f'~/{PROVISION_SCRIPTS_DIRECTORY_NAME}'

POLL_INTERVAL_SECONDS = 2
# Default KillWait is 30 seconds, so we add some buffer time here.
_JOB_TERMINATION_TIMEOUT_SECONDS = 60
_SKY_DIR_CREATION_TIMEOUT_SECONDS = 30


def _sky_cluster_home_dir(cluster_name_on_cloud: str) -> str:
"""Returns the SkyPilot cluster's home directory path on the Slurm cluster.

This path is assumed to be on a shared NFS mount accessible by all nodes.
To support clusters with non-NFS home directories, we would need to let
users specify an NFS-backed "working directory" or use a different
coordination mechanism.
"""
return f'{SHARED_ROOT_SKY_DIRECTORY}/{cluster_name_on_cloud}'


def _sbatch_provision_script_path(filename: str) -> str:
"""Returns the path to the sbatch provision script on the login node."""
# Put sbatch script in $HOME instead of /tmp as there can be
# multiple login nodes, and different SSH connections
# can land on different login nodes.
return f'{PROVISION_SCRIPTS_DIRECTORY}/{filename}'


def _skypilot_runtime_dir(cluster_name_on_cloud: str) -> str:
"""Returns the SkyPilot runtime directory path on the Slurm cluster."""
return f'/tmp/{cluster_name_on_cloud}'


@timeline.event
def _create_virtual_instance(
region: str, cluster_name_on_cloud: str,
config: common.ProvisionConfig) -> common.ProvisionRecord:
"""Creates a Slurm virtual instance from the config.

A Slurm virtual instance is created by submitting a long-running
job with sbatch, to mimic a cloud VM.
"""
provider_config = config.provider_config
ssh_config_dict = provider_config['ssh']
ssh_host = ssh_config_dict['hostname']
ssh_port = int(ssh_config_dict['port'])
ssh_user = ssh_config_dict['user']
ssh_key = ssh_config_dict['private_key']
ssh_proxy_command = ssh_config_dict.get('proxycommand', None)
partition = slurm_utils.get_partition_from_config(provider_config)

client = slurm.SlurmClient(
ssh_host,
ssh_port,
ssh_user,
ssh_key,
ssh_proxy_command=ssh_proxy_command,
)

# COMPLETING state occurs when a job is being terminated - during this
# phase, slurmd sends SIGTERM to tasks, waits for KillWait period, sends
# SIGKILL if needed, runs epilog scripts, and notifies slurmctld. This
# typically happens when a previous job with the same name is being
# cancelled or has finished. Jobs can get stuck in COMPLETING if epilog
# scripts hang or tasks don't respond to signals, so we wait with a
# timeout.
completing_jobs = client.query_jobs(
cluster_name_on_cloud,
['completing'],
)
start_time = time.time()
while (completing_jobs and
time.time() - start_time < _JOB_TERMINATION_TIMEOUT_SECONDS):
logger.debug(f'Found {len(completing_jobs)} completing jobs. '
f'Waiting for them to finish: {completing_jobs}')
time.sleep(POLL_INTERVAL_SECONDS)
completing_jobs = client.query_jobs(
cluster_name_on_cloud,
['completing'],
)
if completing_jobs:
# TODO(kevin): Automatically handle this, following the suggestions in
# https://slurm.schedmd.com/troubleshoot.html#completing
raise RuntimeError(f'Found {len(completing_jobs)} jobs still in '
'completing state after '
f'{_JOB_TERMINATION_TIMEOUT_SECONDS}s. '
'This is typically due to non-killable processes '
'associated with the job.')

# Check if job already exists
existing_jobs = client.query_jobs(
cluster_name_on_cloud,
['pending', 'running'],
)

# Get provision_timeout from config. If not specified, use None,
# which will use the default timeout specified in the Slurm adaptor.
provision_timeout = skypilot_config.get_effective_region_config(
cloud='slurm',
region=region,
keys=('provision_timeout',),
default_value=None)

if existing_jobs:
assert len(existing_jobs) == 1, (
f'Multiple jobs found with name {cluster_name_on_cloud}: '
f'{existing_jobs}')

job_id = existing_jobs[0]
logger.debug(f'Job with name {cluster_name_on_cloud} already exists '
f'(JOBID: {job_id})')

# Wait for nodes to be allocated (job might be in PENDING state)
nodes, _ = client.get_job_nodes(job_id,
wait=True,
timeout=provision_timeout)
return common.ProvisionRecord(provider_name='slurm',
region=region,
zone=partition,
cluster_name=cluster_name_on_cloud,
head_instance_id=slurm_utils.instance_id(
job_id, nodes[0]),
resumed_instance_ids=[],
created_instance_ids=[])

resources = config.node_config

# Note: By default Slurm terminates the entire job allocation if any node
# fails in its range of allocated nodes.
# In the future we can consider running sbatch with --no-kill to not
# automatically terminate a job if one of the nodes it has been
# allocated fails.
num_nodes = config.count

accelerator_type = resources.get('accelerator_type')
accelerator_count_raw = resources.get('accelerator_count')
try:
accelerator_count = int(
accelerator_count_raw) if accelerator_count_raw is not None else 0
except (TypeError, ValueError):
accelerator_count = 0

skypilot_runtime_dir = _skypilot_runtime_dir(cluster_name_on_cloud)
sky_home_dir = _sky_cluster_home_dir(cluster_name_on_cloud)
ready_signal = f'{sky_home_dir}/.sky_sbatch_ready'

# Build the sbatch script
gpu_directive = ''
if (accelerator_type is not None and accelerator_type.upper() != 'NONE' and
accelerator_count > 0):
gpu_directive = (f'#SBATCH --gres=gpu:{accelerator_type.lower()}:'
f'{accelerator_count}')

# By default stdout and stderr will be written to $HOME/slurm-%j.out
# (because we invoke sbatch from $HOME). Redirect elsewhere to not pollute
# the home directory.
provision_script = textwrap.dedent(f"""\
#!/bin/bash
#SBATCH --job-name={cluster_name_on_cloud}
#SBATCH --output={PROVISION_SCRIPTS_DIRECTORY_NAME}/slurm-%j.out
#SBATCH --error={PROVISION_SCRIPTS_DIRECTORY_NAME}/slurm-%j.out
#SBATCH --nodes={num_nodes}
#SBATCH --wait-all-nodes=1
# Let the job be terminated rather than requeued implicitly.
#SBATCH --no-requeue
#SBATCH --cpus-per-task={int(resources["cpus"])}
#SBATCH --mem={int(resources["memory"])}G
{gpu_directive}

# Cleanup function to remove cluster dirs on job termination.
cleanup() {{
# The Skylet is daemonized, so it is not automatically terminated when
# the Slurm job is terminated; we need to kill it manually.
echo "Terminating Skylet..."
if [ -f "{skypilot_runtime_dir}/.sky/skylet_pid" ]; then
kill $(cat "{skypilot_runtime_dir}/.sky/skylet_pid") 2>/dev/null || true
fi
echo "Cleaning up sky directories..."
# Clean up sky runtime directory on each node.
# NOTE: We can do this because --nodes for both this srun and the
# sbatch is the same number. Otherwise, there are no guarantees
# that this srun will run on the same subset of nodes as the srun
# that created the sky directories.
srun --nodes={num_nodes} rm -rf {skypilot_runtime_dir}
rm -rf {sky_home_dir}
}}
trap cleanup TERM

# Create sky home directory for the cluster.
mkdir -p {sky_home_dir}
# Create sky runtime directory on each node.
srun --nodes={num_nodes} mkdir -p {skypilot_runtime_dir}
# Suppress login messages.
touch {sky_home_dir}/.hushlogin
# Signal that the sbatch script has completed setup.
touch {ready_signal}
sleep infinity
""")

# To bootstrap, we use SSHCommandRunner first; SlurmCommandRunner can
# only be used after the virtual instances are created.
login_node_runner = command_runner.SSHCommandRunner(
(ssh_host, ssh_port),
ssh_user,
ssh_key,
ssh_proxy_command=ssh_proxy_command,
)

cmd = f'mkdir -p {PROVISION_SCRIPTS_DIRECTORY}'
rc, stdout, stderr = login_node_runner.run(cmd,
require_outputs=True,
stream_logs=False)
subprocess_utils.handle_returncode(
rc,
cmd,
'Failed to create provision scripts directory on login node.',
stderr=f'{stdout}\n{stderr}')
# Rsync the provision script to the login node
with tempfile.NamedTemporaryFile(mode='w', suffix='.sh', delete=True) as f:
f.write(provision_script)
f.flush()
src_path = f.name
tgt_path = _sbatch_provision_script_path(f'{cluster_name_on_cloud}.sh')
login_node_runner.rsync(src_path, tgt_path, up=True, stream_logs=False)

job_id = client.submit_job(partition, cluster_name_on_cloud, tgt_path)
logger.debug(f'Successfully submitted Slurm job {job_id} to partition '
f'{partition} for cluster {cluster_name_on_cloud} '
f'with {num_nodes} nodes')

nodes, _ = client.get_job_nodes(job_id,
wait=True,
timeout=provision_timeout)
created_instance_ids = [
slurm_utils.instance_id(job_id, node) for node in nodes
]

# Wait for the sbatch script to create the cluster's sky directories,
# to avoid a race condition where post-provision commands try to
# access the directories before they are created.
ready_check_cmd = (f'end=$((SECONDS+{_SKY_DIR_CREATION_TIMEOUT_SECONDS})); '
f'while [ ! -f {ready_signal} ]; do '
'if (( SECONDS >= end )); then '
'exit 1; fi; '
'sleep 0.5; '
'done')
rc, stdout, stderr = login_node_runner.run(ready_check_cmd,
require_outputs=True,
stream_logs=False)
subprocess_utils.handle_returncode(
rc,
ready_check_cmd,
'Failed to verify sky directories creation.',
stderr=f'{stdout}\n{stderr}')

return common.ProvisionRecord(provider_name='slurm',
region=region,
zone=partition,
cluster_name=cluster_name_on_cloud,
head_instance_id=created_instance_ids[0],
resumed_instance_ids=[],
created_instance_ids=created_instance_ids)


@common_utils.retry
def query_instances(
cluster_name: str,
cluster_name_on_cloud: str,
provider_config: Optional[Dict[str, Any]] = None,
non_terminated_only: bool = True,
retry_if_missing: bool = False,
) -> Dict[str, Tuple[Optional[status_lib.ClusterStatus], Optional[str]]]:
"""See sky/provision/__init__.py"""
del cluster_name, retry_if_missing # Unused for Slurm
assert provider_config is not None, (cluster_name_on_cloud, provider_config)

ssh_config_dict = provider_config['ssh']
ssh_host = ssh_config_dict['hostname']
ssh_port = int(ssh_config_dict['port'])
ssh_user = ssh_config_dict['user']
ssh_key = ssh_config_dict['private_key']
ssh_proxy_command = ssh_config_dict.get('proxycommand', None)

client = slurm.SlurmClient(
ssh_host,
ssh_port,
ssh_user,
ssh_key,
ssh_proxy_command=ssh_proxy_command,
)

# Map Slurm job states to SkyPilot ClusterStatus
# Slurm states:
# https://slurm.schedmd.com/squeue.html#SECTION_JOB-STATE-CODES
# TODO(kevin): Include more states here.
status_map = {
'pending': status_lib.ClusterStatus.INIT,
'running': status_lib.ClusterStatus.UP,
'completing': status_lib.ClusterStatus.UP,
'completed': None,
'cancelled': None,
# NOTE: Jobs that get cancelled (from sky down) will go to failed state
# with the reason 'NonZeroExitCode' and remain in the squeue output for
# a while.
'failed': None,
'node_fail': None,
}

statuses: Dict[str, Tuple[Optional[status_lib.ClusterStatus],
Optional[str]]] = {}
for state, sky_status in status_map.items():
jobs = client.query_jobs(
cluster_name_on_cloud,
[state],
)

for job_id in jobs:
if state in ('pending', 'failed', 'node_fail', 'cancelled',
'completed'):
reason = client.get_job_reason(job_id)
if non_terminated_only and sky_status is None:
# TODO(kevin): For better UX, we should also find out
# which node(s) exactly that failed if it's a node_fail
# state.
logger.debug(f'Job {job_id} is terminated, but '
'query_instances is called with '
f'non_terminated_only=True. State: {state}, '
f'Reason: {reason}')
continue
statuses[job_id] = (sky_status, reason)
else:
nodes, _ = client.get_job_nodes(job_id, wait=False)
for node in nodes:
instance_id = slurm_utils.instance_id(job_id, node)
statuses[instance_id] = (sky_status, None)

# TODO(kevin): Query sacct too to get more historical job info.
# squeue only includes completed jobs that finished in the last
# MinJobAge seconds (default 300s), or earlier if it reaches
# MaxJobCount first (default 10_000).

return statuses


def run_instances(
region: str,
cluster_name: str, # pylint: disable=unused-argument
cluster_name_on_cloud: str,
config: common.ProvisionConfig) -> common.ProvisionRecord:
"""Run instances for the given cluster (Slurm in this case)."""
return _create_virtual_instance(region, cluster_name_on_cloud, config)


def wait_instances(region: str, cluster_name_on_cloud: str,
state: Optional[status_lib.ClusterStatus]) -> None:
"""See sky/provision/__init__.py"""
del region, cluster_name_on_cloud, state
# We already wait for the instances to be running in run_instances.
# So we don't need to wait here.


def get_cluster_info(
region: str,
cluster_name_on_cloud: str,
provider_config: Optional[Dict[str, Any]] = None) -> common.ClusterInfo:
del region
assert provider_config is not None, cluster_name_on_cloud

# The SSH host is the remote machine running slurmctld daemon.
# Cross-cluster operations are supported by interacting with
# the current controller. For details, please refer to
# https://slurm.schedmd.com/multi_cluster.html.
ssh_config_dict = provider_config['ssh']
ssh_host = ssh_config_dict['hostname']
ssh_port = int(ssh_config_dict['port'])
ssh_user = ssh_config_dict['user']
ssh_key = ssh_config_dict['private_key']
ssh_proxy_command = ssh_config_dict.get('proxycommand', None)

client = slurm.SlurmClient(
ssh_host,
ssh_port,
ssh_user,
ssh_key,
ssh_proxy_command=ssh_proxy_command,
)

# Find running job for this cluster
running_jobs = client.query_jobs(
cluster_name_on_cloud,
['running'],
)

if not running_jobs:
# No running jobs found - cluster may be in pending or terminated state
return common.ClusterInfo(
instances={},
head_instance_id=None,
ssh_user=ssh_user,
provider_name='slurm',
provider_config=provider_config,
)
assert len(running_jobs) == 1, (
f'Multiple running jobs found for cluster {cluster_name_on_cloud}: '
f'{running_jobs}')

job_id = running_jobs[0]
# Running jobs should already have nodes allocated, so don't wait
nodes, node_ips = client.get_job_nodes(job_id, wait=False)

instances = {
f'{slurm_utils.instance_id(job_id, node)}': [
common.InstanceInfo(
instance_id=slurm_utils.instance_id(job_id, node),
internal_ip=node_ip,
external_ip=ssh_host,
ssh_port=ssh_port,
tags={
constants.TAG_SKYPILOT_CLUSTER_NAME: cluster_name_on_cloud,
'job_id': job_id,
'node': node,
},
)
] for node, node_ip in zip(nodes, node_ips)
}

return common.ClusterInfo(
instances=instances,
head_instance_id=slurm_utils.instance_id(job_id, nodes[0]),
ssh_user=ssh_user,
provider_name='slurm',
provider_config=provider_config,
)


def stop_instances(
cluster_name_on_cloud: str,
provider_config: Optional[Dict[str, Any]] = None,
worker_only: bool = False,
) -> None:
"""Keep the Slurm virtual instances running."""
raise NotImplementedError()


def terminate_instances(
cluster_name_on_cloud: str,
provider_config: Optional[Dict[str, Any]] = None,
worker_only: bool = False,
) -> None:
"""See sky/provision/__init__.py"""
assert provider_config is not None, cluster_name_on_cloud

if worker_only:
logger.warning(
'worker_only=True is not supported for Slurm, this is a no-op.')
return

ssh_config_dict = provider_config['ssh']
ssh_host = ssh_config_dict['hostname']
ssh_port = int(ssh_config_dict['port'])
ssh_user = ssh_config_dict['user']
ssh_private_key = ssh_config_dict['private_key']
# Check if we are running inside a Slurm job (only happens with autodown,
# where the Skylet invokes terminate_instances on the remote cluster),
# in which case we assume SSH between nodes has been set up in each
# node's ssh config.
# TODO(kevin): Validate this assumption. Another way would be to
# mount the private key to the remote cluster, like we do with
# other clouds' API keys.
if slurm_utils.is_inside_slurm_job():
logger.debug('Running inside a Slurm job, using machine\'s ssh config')
ssh_private_key = None
ssh_proxy_command = ssh_config_dict.get('proxycommand', None)

client = slurm.SlurmClient(
ssh_host,
ssh_port,
ssh_user,
ssh_private_key,
ssh_proxy_command=ssh_proxy_command,
)
client.cancel_jobs_by_name(
cluster_name_on_cloud,
signal='TERM',
full=True,
)


def open_ports(
cluster_name_on_cloud: str,
ports: List[str],
provider_config: Optional[Dict[str, Any]] = None,
) -> None:
"""See sky/provision/__init__.py"""
del cluster_name_on_cloud, ports, provider_config
pass


def cleanup_ports(
cluster_name_on_cloud: str,
ports: List[str],
provider_config: Optional[Dict[str, Any]] = None,
) -> None:
"""See sky/provision/__init__.py"""
del cluster_name_on_cloud, ports, provider_config
pass


def get_command_runners(
cluster_info: common.ClusterInfo,
**credentials: Dict[str, Any],
) -> List[command_runner.SlurmCommandRunner]:
"""Get a command runner for the given cluster."""
assert cluster_info.provider_config is not None, cluster_info

if cluster_info.head_instance_id is None:
# No running job found
return []

head_instance = cluster_info.get_head_instance()
assert head_instance is not None, 'Head instance not found'
cluster_name_on_cloud = head_instance.tags.get(
constants.TAG_SKYPILOT_CLUSTER_NAME, None)
assert cluster_name_on_cloud is not None, cluster_info

# There can only be one InstanceInfo per instance_id.
instances = [
instance_infos[0] for instance_infos in cluster_info.instances.values()
]

# Note: For Slurm, the external IP is the same for all instances:
# it is the login node's. The internal IP is the private IP of the node.
ssh_user = cast(str, credentials.pop('ssh_user'))
ssh_private_key = cast(str, credentials.pop('ssh_private_key'))
runners = [
command_runner.SlurmCommandRunner(
(instance_info.external_ip or '', instance_info.ssh_port),
ssh_user,
ssh_private_key,
sky_dir=_sky_cluster_home_dir(cluster_name_on_cloud),
skypilot_runtime_dir=_skypilot_runtime_dir(cluster_name_on_cloud),
job_id=instance_info.tags['job_id'],
slurm_node=instance_info.tags['node'],
**credentials) for instance_info in instances
]

return runners
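
A minimal sketch (not part of this diff) of how a single multi-node Slurm job maps onto SkyPilot "virtual instances", following the instance_id scheme (f'job{job_id}-{node}') and the ClusterInfo construction above. The job ID, node names, and IPs are illustrative.

job_id = '4217'
nodes = ['node01', 'node02']
node_ips = ['10.0.0.11', '10.0.0.12']
login_node = 'slurm-login.example.com'  # external IP is the login node for all


def instance_id(job_id: str, node: str) -> str:
    return f'job{job_id}-{node}'


instances = {
    instance_id(job_id, node): {
        'internal_ip': ip,
        'external_ip': login_node,
        'tags': {'job_id': job_id, 'node': node},
    } for node, ip in zip(nodes, node_ips)
}
head_instance_id = instance_id(job_id, nodes[0])
print(head_instance_id)    # job4217-node01
print(sorted(instances))   # ['job4217-node01', 'job4217-node02']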

+ 583
- 0
sky/provision/slurm/utils.py View File

@@ -0,0 +1,583 @@
"""Slurm utilities for SkyPilot."""
import math
import os
import re
from typing import Any, Dict, List, Optional, Tuple, Union

from paramiko.config import SSHConfig

from sky import exceptions
from sky import sky_logging
from sky.adaptors import slurm
from sky.utils import annotations
from sky.utils import common_utils

logger = sky_logging.init_logger(__name__)

# TODO(jwj): Choose commonly used default values.
DEFAULT_SLURM_PATH = '~/.slurm/config'
DEFAULT_CLUSTER_NAME = 'localcluster'
DEFAULT_PARTITION = 'dev'


def get_slurm_ssh_config() -> SSHConfig:
"""Get the Slurm SSH config."""
slurm_config_path = os.path.expanduser(DEFAULT_SLURM_PATH)
slurm_config = SSHConfig.from_path(slurm_config_path)
return slurm_config


class SlurmInstanceType:
"""Class to represent the "Instance Type" in a Slurm cluster.

Since Slurm does not have a notion of instances, we generate
virtual instance types that represent the resources requested by a
Slurm worker node.

This name captures the following resource requests:
- CPU
- Memory
- Accelerators

The name format is "{n}CPU--{k}GB" where n is the number of vCPUs and
k is the amount of memory in GB. Accelerators can be specified by
appending "--{type}:{a}" where type is the accelerator type and a
is the number of accelerators.
CPU and memory can be specified as floats. Accelerator count must be int.

Examples:
- 4CPU--16GB
- 0.5CPU--1.5GB
- 4CPU--16GB--V100:1
"""

def __init__(self,
cpus: float,
memory: float,
accelerator_count: Optional[int] = None,
accelerator_type: Optional[str] = None):
self.cpus = cpus
self.memory = memory
self.accelerator_count = accelerator_count
self.accelerator_type = accelerator_type

@property
def name(self) -> str:
"""Returns the name of the instance."""
assert self.cpus is not None
assert self.memory is not None
name = (f'{common_utils.format_float(self.cpus)}CPU--'
f'{common_utils.format_float(self.memory)}GB')
if self.accelerator_count is not None:
# Replace spaces with underscores in accelerator type to make it a
# valid logical instance type name.
assert self.accelerator_type is not None, self.accelerator_count
acc_name = self.accelerator_type.replace(' ', '_')
name += f'--{acc_name}:{self.accelerator_count}'
return name

@staticmethod
def is_valid_instance_type(name: str) -> bool:
"""Returns whether the given name is a valid instance type."""
pattern = re.compile(
r'^(\d+(\.\d+)?CPU--\d+(\.\d+)?GB)(--[\w\d-]+:\d+)?$')
return bool(pattern.match(name))

@classmethod
def _parse_instance_type(
cls,
name: str) -> Tuple[float, float, Optional[int], Optional[str]]:
"""Parses and returns resources from the given InstanceType name.

Returns:
cpus | float: Number of CPUs
memory | float: Amount of memory in GB
accelerator_count | Optional[int]: Number of accelerators
accelerator_type | Optional[str]: Type of accelerator
"""
pattern = re.compile(
r'^(?P<cpus>\d+(\.\d+)?)CPU--(?P<memory>\d+(\.\d+)?)GB(?:--(?P<accelerator_type>[\w\d-]+):(?P<accelerator_count>\d+))?$' # pylint: disable=line-too-long
)
match = pattern.match(name)
if match is not None:
cpus = float(match.group('cpus'))
memory = float(match.group('memory'))
accelerator_count = match.group('accelerator_count')
accelerator_type = match.group('accelerator_type')
if accelerator_count is not None:
accelerator_count = int(accelerator_count)
# This is to revert the accelerator types with spaces back to
# the original format.
accelerator_type = str(accelerator_type).replace(' ', '_')
else:
accelerator_count = None
accelerator_type = None
return cpus, memory, accelerator_count, accelerator_type
else:
raise ValueError(f'Invalid instance name: {name}')

@classmethod
def from_instance_type(cls, name: str) -> 'SlurmInstanceType':
"""Returns an instance name object from the given name."""
if not cls.is_valid_instance_type(name):
raise ValueError(f'Invalid instance name: {name}')
cpus, memory, accelerator_count, accelerator_type = \
cls._parse_instance_type(name)
return cls(cpus=cpus,
memory=memory,
accelerator_count=accelerator_count,
accelerator_type=accelerator_type)

@classmethod
def from_resources(cls,
cpus: float,
memory: float,
accelerator_count: Union[float, int] = 0,
accelerator_type: str = '') -> 'SlurmInstanceType':
"""Returns an instance name object from the given resources.

If accelerator_count is not an int, it will be rounded up since GPU
requests in Slurm must be int.

NOTE: Should we take MIG management into account? See
https://slurm.schedmd.com/gres.html#MIG_Management.
"""
name = f'{cpus}CPU--{memory}GB'
# Round up accelerator_count if it is not an int.
accelerator_count = math.ceil(accelerator_count)
if accelerator_count > 0:
name += f'--{accelerator_type}:{accelerator_count}'
return cls(cpus=cpus,
memory=memory,
accelerator_count=accelerator_count,
accelerator_type=accelerator_type)

def __str__(self):
return self.name

def __repr__(self):
return (f'SlurmInstanceType(cpus={self.cpus!r}, '
f'memory={self.memory!r}, '
f'accelerator_count={self.accelerator_count!r}, '
f'accelerator_type={self.accelerator_type!r})')
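
A small, self-contained sketch (not part of this diff) of the instance-type name format described above, using the same regular expression to parse an example name back into its components.

import re

name = '4CPU--16GB--V100:1'
pattern = re.compile(
    r'^(?P<cpus>\d+(\.\d+)?)CPU--(?P<memory>\d+(\.\d+)?)GB'
    r'(?:--(?P<accelerator_type>[\w\d-]+):(?P<accelerator_count>\d+))?$')
match = pattern.match(name)
assert match is not None
print(float(match.group('cpus')),             # 4.0
      float(match.group('memory')),           # 16.0
      match.group('accelerator_type'),        # V100
      int(match.group('accelerator_count')))  # 1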


def instance_id(job_id: str, node: str) -> str:
"""Generates the SkyPilot-defined instance ID for Slurm.

A (job id, node) pair is unique within a Slurm cluster.
"""
return f'job{job_id}-{node}'


def get_cluster_name_from_config(provider_config: Dict[str, Any]) -> str:
"""Return the cluster name from the provider config.

The concept of cluster can be mapped to a cloud region.
"""
return provider_config.get('cluster', DEFAULT_CLUSTER_NAME)


def get_partition_from_config(provider_config: Dict[str, Any]) -> str:
"""Return the partition from the provider config.

The concept of partition can be mapped to a cloud zone.
"""
return provider_config.get('partition', DEFAULT_PARTITION)


@annotations.lru_cache(scope='request')
def get_cluster_default_partition(cluster_name: str) -> str:
"""Get the default partition for a Slurm cluster.

Queries the Slurm cluster for the partition marked with an asterisk (*)
in sinfo output. Falls back to DEFAULT_PARTITION if the query fails or
no default partition is found.

Args:
cluster_name: Name of the Slurm cluster.

Returns:
The default partition name for the cluster.
"""
try:
ssh_config = get_slurm_ssh_config()
ssh_config_dict = ssh_config.lookup(cluster_name)
except Exception as e:
raise ValueError(
f'Failed to load SSH configuration from {DEFAULT_SLURM_PATH}: '
f'{common_utils.format_exception(e)}') from e

client = slurm.SlurmClient(
ssh_config_dict['hostname'],
int(ssh_config_dict.get('port', 22)),
ssh_config_dict['user'],
ssh_config_dict['identityfile'][0],
ssh_proxy_command=ssh_config_dict.get('proxycommand', None),
)

default_partition = client.get_default_partition()
if default_partition is None:
# TODO(kevin): Have a way to specify default partition in
# ~/.sky/config.yaml if needed, in case a Slurm cluster
# really does not have a default partition.
raise ValueError('No default partition found for cluster '
f'{cluster_name}.')
return default_partition


def get_all_slurm_cluster_names() -> List[str]:
"""Get all Slurm cluster names available in the environment.

Returns:
List[str]: The list of Slurm cluster names if available,
an empty list otherwise.
"""
try:
ssh_config = get_slurm_ssh_config()
except FileNotFoundError:
return []
except Exception as e:
raise ValueError(
f'Failed to load SSH configuration from {DEFAULT_SLURM_PATH}: '
f'{common_utils.format_exception(e)}') from e

cluster_names = []
for cluster in ssh_config.get_hostnames():
if cluster == '*':
continue

cluster_names.append(cluster)

return cluster_names


def _check_cpu_mem_fits(
candidate_instance_type: SlurmInstanceType,
node_list: List[slurm.NodeInfo]) -> Tuple[bool, Optional[str]]:
"""Checks if instance fits on candidate nodes based on CPU and memory.

We check capacity (not allocatable) because availability can change
during scheduling, and we want to let the Slurm scheduler handle that.
"""
# Track the max CPU and memory found among the candidate nodes so the
# failure reason can report them.
max_cpu = 0
max_mem_gb = 0.0

for node_info in node_list:
node_cpus = node_info.cpus
node_mem_gb = node_info.memory_gb

if node_cpus > max_cpu:
max_cpu = node_cpus
max_mem_gb = node_mem_gb

if (node_cpus >= candidate_instance_type.cpus and
node_mem_gb >= candidate_instance_type.memory):
return True, None

return False, (f'Max found: {max_cpu} CPUs, '
f'{common_utils.format_float(max_mem_gb)}G memory')


def check_instance_fits(
cluster: str,
instance_type: str,
partition: Optional[str] = None) -> Tuple[bool, Optional[str]]:
"""Check if the given instance type fits in the given cluster/partition.

Args:
cluster: Name of the Slurm cluster.
instance_type: The instance type to check.
partition: Optional partition name. If None, checks all partitions.

Returns:
Tuple of (fits, reason) where fits is True if available.
"""
# Get Slurm node list in the given cluster (region).
try:
ssh_config = get_slurm_ssh_config()
except FileNotFoundError:
return (False, f'Could not query Slurm cluster {cluster} '
f'because the Slurm configuration file '
f'{DEFAULT_SLURM_PATH} does not exist.')
except Exception as e: # pylint: disable=broad-except
return (False, f'Could not query Slurm cluster {cluster} '
f'because Slurm SSH configuration at {DEFAULT_SLURM_PATH} '
f'could not be loaded: {common_utils.format_exception(e)}.')
ssh_config_dict = ssh_config.lookup(cluster)

client = slurm.SlurmClient(
ssh_config_dict['hostname'],
int(ssh_config_dict.get('port', 22)),
ssh_config_dict['user'],
ssh_config_dict['identityfile'][0],
ssh_proxy_command=ssh_config_dict.get('proxycommand', None),
)

nodes = client.info_nodes()
default_partition = get_cluster_default_partition(cluster)

def is_default_partition(node_partition: str) -> bool:
# info_nodes does not strip the '*' from the default partition name.
# But non-default partition names can also end with '*',
# so we need to check whether the partition name without the '*'
# is the same as the default partition name.
return (node_partition.endswith('*') and
node_partition[:-1] == default_partition)

partition_suffix = ''
if partition is not None:
filtered = []
for node_info in nodes:
node_partition = node_info.partition
if is_default_partition(node_partition):
# Strip '*' from default partition name.
node_partition = node_partition[:-1]
if node_partition == partition:
filtered.append(node_info)
nodes = filtered
partition_suffix = f' in partition {partition}'

slurm_instance_type = SlurmInstanceType.from_instance_type(instance_type)
acc_count = (slurm_instance_type.accelerator_count
if slurm_instance_type.accelerator_count is not None else 0)
acc_type = slurm_instance_type.accelerator_type
candidate_nodes = nodes
not_fit_reason_prefix = (
f'No nodes found with enough '
f'CPU (> {slurm_instance_type.cpus} CPUs) and/or '
f'memory (> {slurm_instance_type.memory} G){partition_suffix}. ')
if acc_type is not None:
assert acc_count is not None, (acc_type, acc_count)

gpu_nodes = []
# GRES string format: 'gpu:acc_type:acc_count(optional_extra_info)'
# Examples:
# - gpu:nvidia_h100_80gb_hbm3:8(S:0-1)
# - gpu:a10g:8
# - gpu:l4:1
gres_pattern = re.compile(r'^gpu:([^:]+):(\d+)')
for node_info in nodes:
gres_str = node_info.gres
# Extract the GPU type and count from the GRES string
match = gres_pattern.match(gres_str)
if not match:
continue

node_acc_type = match.group(1).lower()
node_acc_count = int(match.group(2))

# TODO(jwj): Handle status check.

# Check if the node has the requested GPU type and at least the
# requested count
if (node_acc_type == acc_type.lower() and
node_acc_count >= acc_count):
gpu_nodes.append(node_info)
if len(gpu_nodes) == 0:
return (False,
f'No GPU nodes found with at least {acc_type}:{acc_count} '
f'on the cluster.')

candidate_nodes = gpu_nodes
not_fit_reason_prefix = (
f'GPU nodes with {acc_type}{partition_suffix} do not have '
f'enough CPU (> {slurm_instance_type.cpus} CPUs) and/or '
f'memory (> {slurm_instance_type.memory} G). ')

# Check if CPU and memory requirements are met on at least one
# candidate node.
fits, reason = _check_cpu_mem_fits(slurm_instance_type, candidate_nodes)
if not fits and reason is not None:
reason = not_fit_reason_prefix + reason
return fits, reason
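
A small sketch (not part of this diff) of the GRES parsing used above when filtering GPU nodes; the example GRES strings mirror the ones listed in the comment, plus one illustrative CPU-only value.

import re

gres_pattern = re.compile(r'^gpu:([^:]+):(\d+)')
examples = [
    'gpu:nvidia_h100_80gb_hbm3:8(S:0-1)',
    'gpu:a10g:8',
    'gpu:l4:1',
    '(null)',  # illustrative value for a node with no GPU GRES
]
for gres in examples:
    match = gres_pattern.match(gres)
    if match:
        print(gres, '->', match.group(1).lower(), int(match.group(2)))
    else:
        print(gres, '-> no GPUs, skipped')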


def _get_slurm_node_info_list(
slurm_cluster_name: Optional[str] = None) -> List[Dict[str, Any]]:
"""Gathers detailed information about each node in the Slurm cluster.

Raises:
FileNotFoundError: If the Slurm configuration file does not exist.
ValueError: If no Slurm cluster name is found in the Slurm
configuration file.
"""
# 1. Get node state and GRES using sinfo

# can raise FileNotFoundError if config file does not exist.
slurm_config = get_slurm_ssh_config()
if slurm_cluster_name is None:
slurm_cluster_names = get_all_slurm_cluster_names()
if slurm_cluster_names:
slurm_cluster_name = slurm_cluster_names[0]
if slurm_cluster_name is None:
raise ValueError(
f'No Slurm cluster name found in the {DEFAULT_SLURM_PATH} '
f'configuration.')
slurm_config_dict = slurm_config.lookup(slurm_cluster_name)
logger.debug(f'Slurm config dict: {slurm_config_dict}')
slurm_client = slurm.SlurmClient(
slurm_config_dict['hostname'],
int(slurm_config_dict.get('port', 22)),
slurm_config_dict['user'],
slurm_config_dict['identityfile'][0],
ssh_proxy_command=slurm_config_dict.get('proxycommand', None),
)
node_infos = slurm_client.info_nodes()

if not node_infos:
logger.warning(
f'`sinfo -N` returned no output on cluster {slurm_cluster_name}. '
f'No nodes found?')
return []

# 2. Process each node, aggregating partitions per node
slurm_nodes_info: Dict[str, Dict[str, Any]] = {}
gres_gpu_pattern = re.compile(r'((gpu)(?::([^:]+))?:(\d+))')

for node_info in node_infos:
node_name = node_info.node
state = node_info.state
gres_str = node_info.gres
partition = node_info.partition

if node_name in slurm_nodes_info:
slurm_nodes_info[node_name]['partitions'].append(partition)
continue

# Extract GPU info from GRES
gres_match = gres_gpu_pattern.search(gres_str)

total_gpus = 0
gpu_type_from_sinfo = None # Default to None for CPU-only nodes
if gres_match:
try:
total_gpus = int(gres_match.group(4))
if gres_match.group(3):
gpu_type_from_sinfo = gres_match.group(3).upper()
# If total_gpus > 0 but no type, default to 'GPU'
elif total_gpus > 0:
gpu_type_from_sinfo = 'GPU'
except ValueError:
logger.warning(
f'Could not parse GPU count from GRES for {node_name}.')

# Get allocated GPUs via squeue
allocated_gpus = 0
# TODO(zhwu): move to enum
if state in ('alloc', 'mix', 'drain', 'drng', 'drained', 'resv',
'comp'):
try:
node_jobs = slurm_client.get_node_jobs(node_name)
if node_jobs:
job_gres_pattern = re.compile(r'gpu(?::[^:]+)*:(\d+)')
for job_line in node_jobs:
gres_job_match = job_gres_pattern.search(job_line)
if gres_job_match:
allocated_gpus += int(gres_job_match.group(1))
except Exception as e: # pylint: disable=broad-except
if state == 'alloc':
# We can infer allocated GPUs only if the node is
# in 'alloc' state.
allocated_gpus = total_gpus
else:
# Otherwise, just raise the error.
raise e
elif state == 'idle':
allocated_gpus = 0

free_gpus = total_gpus - allocated_gpus if state not in ('down',
'drain',
'drng',
'maint') else 0
free_gpus = max(0, free_gpus)

# Get CPU/Mem info via scontrol
vcpu_total = 0
mem_gb = 0.0
try:
node_details = slurm_client.node_details(node_name)
vcpu_total = int(node_details.get('CPUTot', '0'))
mem_gb = float(node_details.get('RealMemory', '0')) / 1024.0
except Exception as e: # pylint: disable=broad-except
logger.warning(
f'Failed to get CPU/memory info for {node_name}: {e}')

slurm_nodes_info[node_name] = {
'node_name': node_name,
'slurm_cluster_name': slurm_cluster_name,
'partitions': [partition],
'node_state': state,
'gpu_type': gpu_type_from_sinfo,
'total_gpus': total_gpus,
'free_gpus': free_gpus,
'vcpu_count': vcpu_total,
'memory_gb': round(mem_gb, 2),
}

for node_info in slurm_nodes_info.values():
partitions = node_info.pop('partitions')
node_info['partition'] = ','.join(str(p) for p in partitions)

return list(slurm_nodes_info.values())


def slurm_node_info(
slurm_cluster_name: Optional[str] = None) -> List[Dict[str, Any]]:
"""Gets detailed information for each node in the Slurm cluster.

Returns:
List[Dict[str, Any]]: A list of dictionaries, each containing node info.
"""
try:
node_list = _get_slurm_node_info_list(
slurm_cluster_name=slurm_cluster_name)
except (RuntimeError, exceptions.NotSupportedError) as e:
logger.debug(f'Could not retrieve Slurm node info: {e}')
return []
return node_list


def is_inside_slurm_job() -> bool:
return os.environ.get('SLURM_JOB_ID') is not None


def get_partitions(cluster_name: str) -> List[str]:
"""Get unique partition names available in a Slurm cluster.

Args:
cluster_name: Name of the Slurm cluster.

Returns:
List of unique partition names available in the cluster.
The default partition appears first,
and the rest are sorted alphabetically.
"""
try:
slurm_config = SSHConfig.from_path(
os.path.expanduser(DEFAULT_SLURM_PATH))
slurm_config_dict = slurm_config.lookup(cluster_name)

client = slurm.SlurmClient(
slurm_config_dict['hostname'],
int(slurm_config_dict.get('port', 22)),
slurm_config_dict['user'],
slurm_config_dict['identityfile'][0],
ssh_proxy_command=slurm_config_dict.get('proxycommand', None),
)

partitions_info = client.get_partitions_info()
default_partitions = []
other_partitions = []
for partition in partitions_info:
if partition.is_default:
default_partitions.append(partition.name)
else:
other_partitions.append(partition.name)
return default_partitions + sorted(other_partitions)
except Exception as e: # pylint: disable=broad-except
logger.warning(
f'Failed to get partitions for cluster {cluster_name}: {e}')
# Fall back to default partition if query fails
return [DEFAULT_PARTITION]

+ 4
- 1
sky/provision/vast/instance.py View File

@@ -89,6 +89,7 @@ def run_instances(region: str, cluster_name: str, cluster_name_on_cloud: str,
resumed_instance_ids=[],
created_instance_ids=[])

secure_only = config.provider_config.get('secure_only', False)
for _ in range(to_start_count):
node_type = 'head' if head_instance_id is None else 'worker'
try:
@@ -99,7 +100,9 @@ def run_instances(region: str, cluster_name: str, cluster_name_on_cloud: str,
disk_size=config.node_config['DiskSize'],
preemptible=config.node_config['Preemptible'],
image_name=config.node_config['ImageId'],
ports=config.ports_to_open_on_launch)
ports=config.ports_to_open_on_launch,
secure_only=secure_only,
)
except Exception as e: # pylint: disable=broad-except
logger.warning(f'run_instances error: {e}')
raise


+ 10
- 6
sky/provision/vast/utils.py View File

@@ -34,8 +34,8 @@ def list_instances() -> Dict[str, Dict[str, Any]]:


def launch(name: str, instance_type: str, region: str, disk_size: int,
image_name: str, ports: Optional[List[int]],
preemptible: bool) -> str:
image_name: str, ports: Optional[List[int]], preemptible: bool,
secure_only: bool) -> str:
"""Launches an instance with the given parameters.

Converts the instance_type to the Vast GPU name, finds the specs for the
@@ -87,7 +87,7 @@ def launch(name: str, instance_type: str, region: str, disk_size: int,
gpu_name = instance_type.split('-')[1].replace('_', ' ')
num_gpus = int(instance_type.split('-')[0].replace('x', ''))

query = ' '.join([
query = [
'chunked=true',
'georegion=true',
f'geolocation="{region[-2:]}"',
@@ -95,13 +95,17 @@ def launch(name: str, instance_type: str, region: str, disk_size: int,
f'num_gpus={num_gpus}',
f'gpu_name="{gpu_name}"',
f'cpu_ram>="{cpu_ram}"',
])
]
if secure_only:
query.append('datacenter=true')
query_str = ' '.join(query)

instance_list = vast.vast().search_offers(query=query)
instance_list = vast.vast().search_offers(query=query_str)

if isinstance(instance_list, int) or len(instance_list) == 0:
raise RuntimeError('Failed to create instances, could not find an '
f'offer that satisfies the requirements "{query}".')
'offer that satisfies the requirements '
f'"{query_str}".')

instance_touse = instance_list[0]



+ 1
- 1
sky/serve/server/impl.py View File

@@ -517,7 +517,7 @@ def update(
f'{workers} is not supported. Ignoring the update.')

# Load the existing task configuration from the service's YAML file
yaml_content = service_record['yaml_content']
yaml_content = service_record['pool_yaml']

# Load the existing task configuration
task = task_lib.Task.from_yaml_str(yaml_content)


+ 1
- 1
sky/server/constants.py View File

@@ -10,7 +10,7 @@ from sky.skylet import constants
# based on version info is needed.
# For more details and code guidelines, refer to:
# https://docs.skypilot.co/en/latest/developers/CONTRIBUTING.html#backward-compatibility-guidelines
API_VERSION = 24
API_VERSION = 25

# The minimum peer API version that the code should still work with.
# Notes (dev):


+ 222
- 0
sky/server/plugins.py View File

@@ -0,0 +1,222 @@
"""Load plugins for the SkyPilot API server."""
import abc
import dataclasses
import importlib
import os
from typing import Dict, List, Optional, Tuple

from fastapi import FastAPI

from sky import sky_logging
from sky.skylet import constants as skylet_constants
from sky.utils import common_utils
from sky.utils import config_utils
from sky.utils import yaml_utils

logger = sky_logging.init_logger(__name__)

_DEFAULT_PLUGINS_CONFIG_PATH = '~/.sky/plugins.yaml'
_PLUGINS_CONFIG_ENV_VAR = (
f'{skylet_constants.SKYPILOT_SERVER_ENV_VAR_PREFIX}PLUGINS_CONFIG')


class ExtensionContext:
"""Context provided to plugins during installation.

Attributes:
app: The FastAPI application instance.
rbac_rules: List of RBAC rules registered by the plugin.
Example:
[
('user', RBACRule(path='/plugins/api/xx/*', method='POST')),
('user', RBACRule(path='/plugins/api/xx/*', method='DELETE'))
]
"""

def __init__(self, app: Optional[FastAPI] = None):
self.app = app
self.rbac_rules: List[Tuple[str, RBACRule]] = []

def register_rbac_rule(self,
path: str,
method: str,
description: Optional[str] = None,
role: str = 'user') -> None:
"""Register an RBAC rule for this plugin.

This method allows plugins to declare which endpoints should be
restricted to admin users during the install phase.

Args:
path: The path pattern to restrict (supports wildcards with
keyMatch2).
Example: '/plugins/api/credentials/*'
method: The HTTP method to restrict. Example: 'POST', 'DELETE'
description: Optional description of what this rule protects.
role: The role to add this rule to (default: 'user').
Rules added to 'user' role block regular users but allow
admins.

Example:
def install(self, ctx: ExtensionContext):
# Only admin can upload credentials
ctx.register_rbac_rule(
path='/plugins/api/credentials/*',
method='POST',
description='Only admin can upload credentials'
)
"""
rule = RBACRule(path=path, method=method, description=description)
self.rbac_rules.append((role, rule))
logger.debug(f'Registered RBAC rule for {role}: {method} {path}'
f'{f" - {description}" if description else ""}')


@dataclasses.dataclass
class RBACRule:
"""RBAC rule for a plugin endpoint.

Attributes:
path: The path pattern to match (supports wildcards with keyMatch2).
Example: '/plugins/api/credentials/*'
method: The HTTP method to restrict. Example: 'POST', 'DELETE'
description: Optional description of what this rule protects.
"""
path: str
method: str
description: Optional[str] = None


class BasePlugin(abc.ABC):
"""Base class for all SkyPilot server plugins."""

@property
def js_extension_path(self) -> Optional[str]:
"""Optional API route to the JavaScript extension to load."""
return None

@abc.abstractmethod
def install(self, extension_context: ExtensionContext):
"""Hook called by API server to let the plugin install itself."""
raise NotImplementedError

def shutdown(self):
"""Hook called by API server to let the plugin shutdown."""
pass
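
A minimal, hypothetical plugin sketch (not part of this diff), based on the BasePlugin interface and the register_rbac_rule example in the docstring above. The module path, route, and parameter are illustrative only.

from sky.server import plugins


class HelloPlugin(plugins.BasePlugin):
    """Adds a trivial endpoint and restricts writes to admins."""

    def __init__(self, greeting: str = 'hello'):
        # Populated from the 'parameters' mapping in the plugins config.
        self._greeting = greeting

    def install(self, extension_context: plugins.ExtensionContext):
        app = extension_context.app
        if app is not None:

            @app.get('/plugins/api/hello')
            async def hello():  # pylint: disable=unused-variable
                return {'message': self._greeting}

        # Regular users are blocked from POSTing to this plugin's
        # endpoints; admins are still allowed.
        extension_context.register_rbac_rule(
            path='/plugins/api/hello/*',
            method='POST',
            description='Only admin can modify hello state')

Such a plugin would then be listed under the plugins key of the plugins config file (default ~/.sky/plugins.yaml, or the path given by the plugins-config environment variable), with class set to its import path and an optional parameters mapping, matching the schema below.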


def _config_schema():
plugin_schema = {
'type': 'object',
'required': ['class'],
'additionalProperties': False,
'properties': {
'class': {
'type': 'string',
},
'parameters': {
'type': 'object',
'required': [],
'additionalProperties': True,
},
},
}
return {
'type': 'object',
'required': [],
'additionalProperties': False,
'properties': {
'plugins': {
'type': 'array',
'items': plugin_schema,
'default': [],
},
},
}


def _load_plugin_config() -> Optional[config_utils.Config]:
"""Load plugin config."""
config_path = os.getenv(_PLUGINS_CONFIG_ENV_VAR,
_DEFAULT_PLUGINS_CONFIG_PATH)
config_path = os.path.expanduser(config_path)
if not os.path.exists(config_path):
return None
config = yaml_utils.read_yaml(config_path) or {}
common_utils.validate_schema(config,
_config_schema(),
err_msg_prefix='Invalid plugins config: ')
return config_utils.Config.from_dict(config)


_PLUGINS: Dict[str, BasePlugin] = {}
_EXTENSION_CONTEXT: Optional[ExtensionContext] = None


def load_plugins(extension_context: ExtensionContext):
"""Load and initialize plugins from the config."""
global _EXTENSION_CONTEXT
_EXTENSION_CONTEXT = extension_context

config = _load_plugin_config()
if not config:
return

for plugin_config in config.get('plugins', []):
class_path = plugin_config['class']
module_path, class_name = class_path.rsplit('.', 1)
try:
module = importlib.import_module(module_path)
except ImportError as e:
raise ImportError(
f'Failed to import plugin module: {module_path}. '
'Please check if the module is installed in your Python '
'environment.') from e
try:
plugin_cls = getattr(module, class_name)
except AttributeError as e:
raise AttributeError(
f'Could not find plugin {class_name} class in module '
f'{module_path}. ') from e
if not issubclass(plugin_cls, BasePlugin):
raise TypeError(
f'Plugin {class_path} must inherit from BasePlugin.')
parameters = plugin_config.get('parameters') or {}
plugin = plugin_cls(**parameters)
plugin.install(extension_context)
_PLUGINS[class_path] = plugin


def get_plugins() -> List[BasePlugin]:
"""Return shallow copies of the registered plugins."""
return list(_PLUGINS.values())


def get_plugin_rbac_rules() -> Dict[str, List[Dict[str, str]]]:
"""Collect RBAC rules from all loaded plugins.

Collects rules from the ExtensionContext.

Returns:
Dictionary mapping role names to lists of blocklist rules.
Example:
{
'user': [
{'path': '/plugins/api/credentials/*', 'method': 'POST'},
{'path': '/plugins/api/credentials/*', 'method': 'DELETE'}
]
}
"""
rules_by_role: Dict[str, List[Dict[str, str]]] = {}

# Collect rules registered via ExtensionContext
if _EXTENSION_CONTEXT:
for role, rule in _EXTENSION_CONTEXT.rbac_rules:
if role not in rules_by_role:
rules_by_role[role] = []
rules_by_role[role].append({
'path': rule.path,
'method': rule.method,
})

return rules_by_role

+ 5
- 2
sky/server/requests/executor.py View File

@@ -44,6 +44,7 @@ from sky.server import common as server_common
from sky.server import config as server_config
from sky.server import constants as server_constants
from sky.server import metrics as metrics_lib
from sky.server import plugins
from sky.server.requests import payloads
from sky.server.requests import preconditions
from sky.server.requests import process
@@ -159,6 +160,8 @@ queue_backend = server_config.QueueBackend.MULTIPROCESSING
def executor_initializer(proc_group: str):
setproctitle.setproctitle(f'SkyPilot:executor:{proc_group}:'
f'{multiprocessing.current_process().pid}')
# Load plugins for executor process.
plugins.load_plugins(plugins.ExtensionContext())
# Executor never stops, unless the whole process is killed.
threading.Thread(target=metrics_lib.process_monitor,
args=(f'worker:{proc_group}', threading.Event()),
@@ -533,8 +536,8 @@ def _request_execution_wrapper(request_id: str,
# so that the "Request xxxx failed due to ..." log message will be
# written to the original stdout and stderr file descriptors.
_restore_output()
logger.info(f'Request {request_id} failed due to '
f'{common_utils.format_exception(e)}')
logger.error(f'Request {request_id} failed due to '
f'{common_utils.format_exception(e)}')
return
else:
api_requests.set_request_succeeded(


+ 12
- 1
sky/server/requests/payloads.py View File

@@ -82,7 +82,7 @@ def request_body_env_vars() -> dict:
if common.is_api_server_local() and env_var in EXTERNAL_LOCAL_ENV_VARS:
env_vars[env_var] = os.environ[env_var]
env_vars[constants.USER_ID_ENV_VAR] = common_utils.get_user_hash()
env_vars[constants.USER_ENV_VAR] = common_utils.get_current_user_name()
env_vars[constants.USER_ENV_VAR] = common_utils.get_local_user_name()
env_vars[
usage_constants.USAGE_RUN_ID_ENV_VAR] = usage_lib.messages.usage.run_id
if not common.is_api_server_local():
@@ -670,6 +670,11 @@ class KubernetesNodeInfoRequestBody(RequestBody):
context: Optional[str] = None


class SlurmNodeInfoRequestBody(RequestBody):
"""The request body for the slurm node info endpoint."""
slurm_cluster_name: Optional[str] = None


class ListAcceleratorsBody(RequestBody):
"""The request body for the list accelerators endpoint."""
gpus_only: bool = True
@@ -854,3 +859,9 @@ class RequestPayload(BasePayload):
status_msg: Optional[str] = None
should_retry: bool = False
finished_at: Optional[float] = None


class SlurmGpuAvailabilityRequestBody(RequestBody):
"""Request body for getting Slurm real-time GPU availability."""
name_filter: Optional[str] = None
quantity_filter: Optional[int] = None

+ 2
- 0
sky/server/requests/request_names.py View File

@@ -10,6 +10,8 @@ class RequestName(str, enum.Enum):
REALTIME_KUBERNETES_GPU_AVAILABILITY = (
'realtime_kubernetes_gpu_availability')
KUBERNETES_NODE_INFO = 'kubernetes_node_info'
REALTIME_SLURM_GPU_AVAILABILITY = 'realtime_slurm_gpu_availability'
SLURM_NODE_INFO = 'slurm_node_info'
STATUS_KUBERNETES = 'status_kubernetes'
LIST_ACCELERATORS = 'list_accelerators'
LIST_ACCELERATOR_COUNTS = 'list_accelerator_counts'


+ 5
- 1
sky/server/requests/requests.py View File

@@ -33,6 +33,7 @@ from sky.server import daemons
from sky.server.requests import payloads
from sky.server.requests.serializers import decoders
from sky.server.requests.serializers import encoders
from sky.server.requests.serializers import return_value_serializers
from sky.utils import asyncio_utils
from sky.utils import common_utils
from sky.utils import ux_utils
@@ -231,13 +232,16 @@ class Request:
assert isinstance(self.request_body,
payloads.RequestBody), (self.name, self.request_body)
try:
# Use version-aware serializer to handle backward compatibility
# for old clients that don't recognize new fields.
serializer = return_value_serializers.get_serializer(self.name)
return payloads.RequestPayload(
request_id=self.request_id,
name=self.name,
entrypoint=encoders.pickle_and_encode(self.entrypoint),
request_body=encoders.pickle_and_encode(self.request_body),
status=self.status.value,
return_value=orjson.dumps(self.return_value).decode('utf-8'),
return_value=serializer(self.return_value),
error=orjson.dumps(self.error).decode('utf-8'),
pid=self.pid,
created_at=self.created_at,


+ 17
- 0
sky/server/requests/serializers/encoders.py View File

@@ -266,6 +266,23 @@ def encode_realtime_gpu_availability(
return encoded


@register_encoder('realtime_slurm_gpu_availability')
def encode_realtime_slurm_gpu_availability(
return_value: List[Tuple[str,
List[Any]]]) -> List[Tuple[str, List[List[Any]]]]:
# Convert RealtimeGpuAvailability namedtuples to lists
# for JSON serialization.
encoded = []
for context, gpu_list in return_value:
converted_gpu_list = []
for gpu in gpu_list:
assert isinstance(gpu, models.RealtimeGpuAvailability), (
f'Expected RealtimeGpuAvailability, got {type(gpu)}')
converted_gpu_list.append(list(gpu))
encoded.append((context, converted_gpu_list))
return encoded


@register_encoder('list_accelerators')
def encode_list_accelerators(
return_value: Dict[str, List[Any]]) -> Dict[str, Any]:


+ 60
- 0
sky/server/requests/serializers/return_value_serializers.py View File

@@ -0,0 +1,60 @@
"""Version-aware serializers for request return values.

These serializers run at encode() time when remote_api_version is available,
to handle backward compatibility for old clients.

The existing encoders.py handles object -> dict conversion at set_return_value()
time. This module handles dict -> JSON string serialization at encode() time,
with version-aware field filtering for backward compatibility.
"""
from typing import Any, Callable, Dict

import orjson

from sky.server import constants as server_constants
from sky.server import versions

handlers: Dict[str, Callable[[Any], str]] = {}


def register_serializer(*names: str):
"""Decorator to register a version-aware serializer."""

def decorator(func):
for name in names:
if name != server_constants.DEFAULT_HANDLER_NAME:
name = server_constants.REQUEST_NAME_PREFIX + name
if name in handlers:
raise ValueError(f'Serializer {name} already registered: '
f'{handlers[name]}')
handlers[name] = func
return func

return decorator


def get_serializer(name: str) -> Callable[[Any], str]:
"""Get the serializer for a request name."""
return handlers.get(name, handlers[server_constants.DEFAULT_HANDLER_NAME])


@register_serializer(server_constants.DEFAULT_HANDLER_NAME)
def default_serializer(return_value: Any) -> str:
"""The default serializer."""
return orjson.dumps(return_value).decode('utf-8')


@register_serializer('kubernetes_node_info')
def serialize_kubernetes_node_info(return_value: Dict[str, Any]) -> str:
"""Serialize kubernetes node info with version compatibility.

The is_ready field was added in API version 25. Remove it for old clients
that don't recognize it.
"""
remote_api_version = versions.get_remote_api_version()
if (return_value and remote_api_version is not None and
remote_api_version < 25):
# Remove is_ready field for old clients that don't recognize it
for node_info in return_value.get('node_info_dict', {}).values():
node_info.pop('is_ready', None)
return orjson.dumps(return_value).decode('utf-8')
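
To illustrate the registration pattern above, here is a hypothetical serializer for a made-up request name; the request name ('example_request'), the field ('new_field'), and the version cutoff (99) are all illustrative and not part of this change.

```python
# Hypothetical example of registering a version-aware serializer that drops
# a field old clients (API version < 99) would not recognize. The request
# name, field, and version number are made up for illustration.
from typing import Any, Dict

import orjson

from sky.server import versions
from sky.server.requests.serializers.return_value_serializers import (
    register_serializer)


@register_serializer('example_request')
def serialize_example_request(return_value: Dict[str, Any]) -> str:
    remote_api_version = versions.get_remote_api_version()
    if (return_value and remote_api_version is not None and
            remote_api_version < 99):
        # Old clients do not know about the hypothetical 'new_field'.
        return_value.pop('new_field', None)
    return orjson.dumps(return_value).decode('utf-8')
```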

+ 78
- 8
sky/server/server.py View File

@@ -20,7 +20,7 @@ import struct
import sys
import threading
import traceback
from typing import Dict, List, Literal, Optional, Set, Tuple
from typing import Any, Dict, List, Literal, Optional, Set, Tuple
import uuid
import zipfile

@@ -48,6 +48,7 @@ from sky.jobs.server import server as jobs_rest
from sky.metrics import utils as metrics_utils
from sky.provision import metadata_utils
from sky.provision.kubernetes import utils as kubernetes_utils
from sky.provision.slurm import utils as slurm_utils
from sky.schemas.api import responses
from sky.serve.server import server as serve_rest
from sky.server import common
@@ -56,6 +57,8 @@ from sky.server import constants as server_constants
from sky.server import daemons
from sky.server import metrics
from sky.server import middleware_utils
from sky.server import plugins
from sky.server import server_utils
from sky.server import state
from sky.server import stream_utils
from sky.server import versions
@@ -470,7 +473,8 @@ async def schedule_on_boot_check_async():
await executor.schedule_request_async(
request_id='skypilot-server-on-boot-check',
request_name=request_names.RequestName.CHECK,
request_body=payloads.CheckBody(),
request_body=server_utils.build_body_at_server(
request=None, body_type=payloads.CheckBody),
func=sky_check.check,
schedule_type=requests_lib.ScheduleType.SHORT,
is_skypilot_system=True,
@@ -493,7 +497,8 @@ async def lifespan(app: fastapi.FastAPI): # pylint: disable=redefined-outer-nam
await executor.schedule_request_async(
request_id=event.id,
request_name=event.name,
request_body=payloads.RequestBody(),
request_body=server_utils.build_body_at_server(
request=None, body_type=payloads.RequestBody),
func=event.run_event,
schedule_type=requests_lib.ScheduleType.SHORT,
is_skypilot_system=True,
@@ -652,6 +657,17 @@ app.add_middleware(BearerTokenMiddleware)
# middleware above.
app.add_middleware(InitializeRequestAuthUserMiddleware)
app.add_middleware(RequestIDMiddleware)

# Load plugins after all the middlewares are added, to keep the core
# middleware stack intact if a plugin adds new middlewares.
# Note: server.py is imported twice in the server process: once as the
# top-level entrypoint module and once by uvicorn. We only load plugins
# when server.py is imported by uvicorn for the server process.
# TODO(aylei): move uvicorn app out of the top-level module to avoid
# duplicate app initialization.
if __name__ == 'sky.server.server':
plugins.load_plugins(plugins.ExtensionContext(app=app))

app.include_router(jobs_rest.router, prefix='/jobs', tags=['jobs'])
app.include_router(serve_rest.router, prefix='/serve', tags=['serve'])
app.include_router(users_rest.router, prefix='/users', tags=['users'])
@@ -746,8 +762,11 @@ async def enabled_clouds(request: fastapi.Request,
await executor.schedule_request_async(
request_id=request.state.request_id,
request_name=request_names.RequestName.ENABLED_CLOUDS,
request_body=payloads.EnabledCloudsBody(workspace=workspace,
expand=expand),
request_body=server_utils.build_body_at_server(
request=request,
body_type=payloads.EnabledCloudsBody,
workspace=workspace,
expand=expand),
func=core.enabled_clouds,
schedule_type=requests_lib.ScheduleType.SHORT,
)
@@ -784,6 +803,35 @@ async def kubernetes_node_info(
)


@app.post('/slurm_gpu_availability')
async def slurm_gpu_availability(
request: fastapi.Request,
slurm_gpu_availability_body: payloads.SlurmGpuAvailabilityRequestBody
) -> None:
"""Gets real-time Slurm GPU availability."""
await executor.schedule_request_async(
request_id=request.state.request_id,
request_name=request_names.RequestName.REALTIME_SLURM_GPU_AVAILABILITY,
request_body=slurm_gpu_availability_body,
func=core.realtime_slurm_gpu_availability,
schedule_type=requests_lib.ScheduleType.SHORT,
)


@app.get('/slurm_node_info')
async def slurm_node_info(
request: fastapi.Request,
slurm_node_info_body: payloads.SlurmNodeInfoRequestBody) -> None:
"""Gets detailed information for each node in the Slurm cluster."""
await executor.schedule_request_async(
request_id=request.state.request_id,
request_name=request_names.RequestName.SLURM_NODE_INFO,
request_body=slurm_node_info_body,
func=slurm_utils.slurm_node_info,
schedule_type=requests_lib.ScheduleType.SHORT,
)


@app.get('/status_kubernetes')
async def status_kubernetes(request: fastapi.Request) -> None:
"""[Experimental] Get all SkyPilot resources (including from other '
@@ -791,7 +839,8 @@ async def status_kubernetes(request: fastapi.Request) -> None:
await executor.schedule_request_async(
request_id=request.state.request_id,
request_name=request_names.RequestName.STATUS_KUBERNETES,
request_body=payloads.RequestBody(),
request_body=server_utils.build_body_at_server(
request=request, body_type=payloads.RequestBody),
func=core.status_kubernetes,
schedule_type=requests_lib.ScheduleType.SHORT,
)
@@ -1460,7 +1509,8 @@ async def storage_ls(request: fastapi.Request) -> None:
await executor.schedule_request_async(
request_id=request.state.request_id,
request_name=request_names.RequestName.STORAGE_LS,
request_body=payloads.RequestBody(),
request_body=server_utils.build_body_at_server(
request=request, body_type=payloads.RequestBody),
func=core.storage_ls,
schedule_type=requests_lib.ScheduleType.SHORT,
)
@@ -1752,6 +1802,15 @@ async def api_status(
return encoded_request_tasks


@app.get('/api/plugins', response_class=fastapi_responses.ORJSONResponse)
async def list_plugins() -> Dict[str, List[Dict[str, Any]]]:
"""Return metadata about loaded backend plugins."""
plugin_info = [{
'js_extension_path': plugin.js_extension_path,
} for plugin in plugins.get_plugins()]
return {'plugins': plugin_info}


@app.get(
'/api/health',
# response_model_exclude_unset omits unset fields
@@ -2007,7 +2066,8 @@ async def all_contexts(request: fastapi.Request) -> None:
await executor.schedule_request_async(
request_id=request.state.request_id,
request_name=request_names.RequestName.ALL_CONTEXTS,
request_body=payloads.RequestBody(),
request_body=server_utils.build_body_at_server(
request=request, body_type=payloads.RequestBody),
func=core.get_all_contexts,
schedule_type=requests_lib.ScheduleType.SHORT,
)
@@ -2057,6 +2117,14 @@ async def serve_dashboard(full_path: str):
if os.path.isfile(file_path):
return fastapi.responses.FileResponse(file_path)

# Serve plugin catch-all page for any /plugins/* paths so client-side
# routing can bootstrap correctly.
if full_path == 'plugins' or full_path.startswith('plugins/'):
plugin_catchall = os.path.join(server_constants.DASHBOARD_DIR,
'plugins', '[...slug].html')
if os.path.isfile(plugin_catchall):
return fastapi.responses.FileResponse(plugin_catchall)

# Serve index.html for client-side routing
# e.g. /clusters, /jobs
index_path = os.path.join(server_constants.DASHBOARD_DIR, 'index.html')
@@ -2220,6 +2288,8 @@ if __name__ == '__main__':

for gt in global_tasks:
gt.cancel()
for plugin in plugins.get_plugins():
plugin.shutdown()
subprocess_utils.run_in_parallel(lambda worker: worker.cancel(),
workers,
num_threads=len(workers))


+ 30
- 0
sky/server/server_utils.py View File

@@ -0,0 +1,30 @@
"""Utilities for the API server."""

from typing import Optional, Type, TypeVar

import fastapi

from sky.server.requests import payloads
from sky.skylet import constants

_BodyT = TypeVar('_BodyT', bound=payloads.RequestBody)


# TODO(aylei): remove this and disable request body construction at server-side
def build_body_at_server(request: Optional[fastapi.Request],
body_type: Type[_BodyT], **data) -> _BodyT:
"""Builds the request body at the server.

For historical reasons, some handlers construct a request body on the
server side to match the executor's interface. This causes issues where
client info, such as the user identity, is not respected in those handlers.
This helper builds the request body server-side with the auth user
overridden.
"""
request_body = body_type(**data)
if request is not None:
auth_user = getattr(request.state, 'auth_user', None)
if auth_user:
request_body.env_vars[constants.USER_ID_ENV_VAR] = auth_user.id
request_body.env_vars[constants.USER_ENV_VAR] = auth_user.name
return request_body

+ 17
- 6
sky/setup_files/dependencies.py View File

@@ -84,6 +84,7 @@ install_requires = [
'bcrypt==4.0.1',
'pyjwt',
'gitpython',
'paramiko',
'types-paramiko',
'alembic',
'aiohttp',
@@ -203,12 +204,21 @@ cloud_dependencies: Dict[str, List[str]] = {
'ssh': kubernetes_dependencies,
# For the container registry auth api. Reference:
# https://github.com/runpod/runpod-python/releases/tag/1.6.1
# RunPod needs a TOML parser to read ~/.runpod/config.toml. On Python 3.11+
# stdlib provides tomllib; on lower versions we depend on tomli explicitly.
# Instead of installing tomli conditionally, we install it explicitly.
# This is because the conditional installation of tomli does not work
# with controller package installation code.
'runpod': ['runpod>=1.6.1', 'tomli'],
'runpod': [
# For the container registry auth api. Reference:
# https://github.com/runpod/runpod-python/releases/tag/1.6.1
'runpod>=1.6.1',
# RunPod needs a TOML parser to read ~/.runpod/config.toml. On Python
# 3.11+ stdlib provides tomllib; on lower versions we depend on tomli
# explicitly. Instead of installing tomli conditionally, we install it
# explicitly. This is because the conditional installation of tomli does
# not work with controller package installation code.
'tomli',
# runpod installs aiodns (via aiohttp[speedups]), which is incompatible
# with pycares 5.0.0 due to deprecations.
# See https://github.com/aio-libs/aiodns/issues/214
'pycares<5',
],
'fluidstack': [], # No dependencies needed for fluidstack
'cudo': ['cudo-compute>=0.1.10'],
'paperspace': [], # No dependencies needed for paperspace
@@ -234,6 +244,7 @@ cloud_dependencies: Dict[str, List[str]] = {
'hyperbolic': [], # No dependencies needed for hyperbolic
'seeweb': ['ecsapi==0.4.0'],
'shadeform': [], # No dependencies needed for shadeform
'slurm': [], # No dependencies needed for slurm
}

# Calculate which clouds should be included in the [all] installation.


+ 13
- 3
sky/skylet/attempt_skylet.py View File

@@ -9,6 +9,7 @@ import psutil

from sky.skylet import constants
from sky.skylet import runtime_utils
from sky.utils import common_utils

VERSION_FILE = runtime_utils.get_runtime_dir_path(constants.SKYLET_VERSION_FILE)
SKYLET_LOG_FILE = runtime_utils.get_runtime_dir_path(constants.SKYLET_LOG_FILE)
@@ -97,8 +98,13 @@ def restart_skylet():
for pid in _find_running_skylet_pids():
try:
os.kill(pid, signal.SIGKILL)
except (OSError, ProcessLookupError):
# Process died between detection and kill
# Wait until process fully terminates so its socket gets released.
# Without this, find_free_port may race with the kernel closing the
# socket and fail to bind to the port that's supposed to be free.
psutil.Process(pid).wait(timeout=5)
except (OSError, ProcessLookupError, psutil.NoSuchProcess,
psutil.TimeoutExpired):
# Process died between detection and kill, or timed out waiting for it.
pass
# Clean up the PID file
try:
@@ -106,7 +112,11 @@ def restart_skylet():
except OSError:
pass # Best effort cleanup

port = constants.SKYLET_GRPC_PORT
# TODO(kevin): Handle race conditions here. Race conditions can only
# happen on Slurm, where there could be multiple clusters running in
# one network namespace. For other clouds, the behaviour will be that
# it always gets port 46590 (default port).
port = common_utils.find_free_port(constants.SKYLET_GRPC_PORT)
subprocess.run(
# We have made sure that `attempt_skylet.py` is executed with the
# skypilot runtime env activated, so that skylet can access the cloud
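
For context on the TODO above: a free-port helper typically probes candidate ports by binding and immediately releasing them, which is exactly why a race window exists between picking the port and skylet binding it later. A rough sketch follows; the actual behavior of `common_utils.find_free_port` is an assumption here.

```python
# Sketch only (assumed behavior, not the actual implementation): probe ports
# starting at a default until one binds. The socket is closed before
# returning, leaving the race window mentioned in the TODO above.
import socket


def find_free_port_sketch(start_port: int, max_tries: int = 100) -> int:
    for port in range(start_port, start_port + max_tries):
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            try:
                s.bind(('', port))
                return port  # Closing the socket frees the port again.
            except OSError:
                continue
    raise RuntimeError(f'No free port found starting at {start_port}')
```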


+ 34
- 9
sky/skylet/constants.py View File

@@ -25,6 +25,7 @@ SKY_RUNTIME_DIR_ENV_VAR_KEY = 'SKY_RUNTIME_DIR'
# them be in $HOME makes it more convenient.
SKY_LOGS_DIRECTORY = '~/sky_logs'
SKY_REMOTE_WORKDIR = '~/sky_workdir'
SKY_TEMPLATES_DIRECTORY = '~/sky_templates'
SKY_IGNORE_FILE = '.skyignore'
GIT_IGNORE_FILE = '.gitignore'

@@ -67,10 +68,23 @@ SKY_PIP_CMD = f'{SKY_PYTHON_CMD} -m pip'
# #!/opt/conda/bin/python3
SKY_RAY_CMD = (f'{SKY_PYTHON_CMD} $([ -s {SKY_RAY_PATH_FILE} ] && '
f'cat {SKY_RAY_PATH_FILE} 2> /dev/null || which ray)')

# Use $(which env) to find env, falling back to /usr/bin/env if which is
# unavailable. This works around a Slurm quirk where srun's execvp() doesn't
# check execute permissions, failing when $HOME/.local/bin/env (non-executable,
# from uv installation) shadows /usr/bin/env.
SKY_SLURM_UNSET_PYTHONPATH = ('$(which env 2>/dev/null || echo /usr/bin/env) '
'-u PYTHONPATH')
SKY_SLURM_PYTHON_CMD = (f'{SKY_SLURM_UNSET_PYTHONPATH} '
f'$({SKY_GET_PYTHON_PATH_CMD})')

# Separate env for SkyPilot runtime dependencies.
SKY_REMOTE_PYTHON_ENV_NAME = 'skypilot-runtime'
SKY_REMOTE_PYTHON_ENV: str = f'{SKY_RUNTIME_DIR}/{SKY_REMOTE_PYTHON_ENV_NAME}'
ACTIVATE_SKY_REMOTE_PYTHON_ENV = f'source {SKY_REMOTE_PYTHON_ENV}/bin/activate'
# Place the conda root in the runtime directory, as installing to $HOME
# on an NFS takes too long (1-2m slower).
SKY_CONDA_ROOT = f'{SKY_RUNTIME_DIR}/miniconda3'
# uv is used for venv and pip, much faster than python implementations.
SKY_UV_INSTALL_DIR = '"$HOME/.local/bin"'
# set UV_SYSTEM_PYTHON to false in case the
@@ -162,6 +176,10 @@ DISABLE_GPU_ECC_COMMAND = (
'{ sudo reboot || echo "Failed to reboot. ECC mode may not be disabled"; } '
'|| true; ')

SETUP_SKY_DIRS_COMMANDS = (f'mkdir -p ~/sky_workdir && '
f'mkdir -p ~/.sky/sky_app && '
f'mkdir -p {SKY_RUNTIME_DIR}/.sky;')

# Install conda on the remote cluster if it is not already installed.
# We use conda with python 3.10 to be consistent across multiple clouds with
# best effort.
@@ -178,8 +196,9 @@ CONDA_INSTALLATION_COMMANDS = (
# because for some images, conda is already installed, but not initialized.
# In this case, we need to initialize conda and set auto_activate_base to
# true.
'{ bash Miniconda3-Linux.sh -b || true; '
'eval "$(~/miniconda3/bin/conda shell.bash hook)" && conda init && '
'{ '
f'bash Miniconda3-Linux.sh -b -p "{SKY_CONDA_ROOT}" || true; '
f'eval "$({SKY_CONDA_ROOT}/bin/conda shell.bash hook)" && conda init && '
# Caller should replace {conda_auto_activate} with either true or false.
'conda config --set auto_activate_base {conda_auto_activate} && '
'conda activate base; }; '
@@ -222,7 +241,7 @@ _sky_version = str(version.parse(sky.__version__))
RAY_STATUS = f'RAY_ADDRESS=127.0.0.1:{SKY_REMOTE_RAY_PORT} {SKY_RAY_CMD} status'
RAY_INSTALLATION_COMMANDS = (
f'{SKY_UV_INSTALL_CMD};'
'mkdir -p ~/sky_workdir && mkdir -p ~/.sky/sky_app;'
f'{SETUP_SKY_DIRS_COMMANDS}'
# Print the PATH in provision.log to help debug PATH issues.
'echo PATH=$PATH; '
# Install setuptools<=69.5.1 to avoid the issue with the latest setuptools
@@ -256,7 +275,7 @@ RAY_INSTALLATION_COMMANDS = (
#
# Here, we add ~/.local/bin to the end of the PATH to make sure the issues
# mentioned above are resolved.
'export PATH=$PATH:$HOME/.local/bin; '
f'export PATH=$PATH:{SKY_RUNTIME_DIR}/.local/bin; '
# Writes ray path to file if it does not exist or the file is empty.
f'[ -s {SKY_RAY_PATH_FILE} ] || '
f'{{ {SKY_UV_RUN_CMD} '
@@ -264,18 +283,23 @@ RAY_INSTALLATION_COMMANDS = (

# Copy SkyPilot templates from the installed wheel to ~/sky_templates.
# This must run after the skypilot wheel is installed.
# Note: We remove ~/sky_templates first to avoid import conflicts where Python
# would import from ~/sky_templates instead of site-packages (because
# sky_templates itself is a package), leading to src == dst error when
# launching on an existing cluster.
COPY_SKYPILOT_TEMPLATES_COMMANDS = (
f'rm -rf {SKY_TEMPLATES_DIRECTORY}; '
f'{ACTIVATE_SKY_REMOTE_PYTHON_ENV}; '
f'{SKY_PYTHON_CMD} -c \''
'import sky_templates, shutil, os; '
'src = os.path.dirname(sky_templates.__file__); '
'dst = os.path.expanduser(\"~/sky_templates\"); '
f'dst = os.path.expanduser(\"{SKY_TEMPLATES_DIRECTORY}\"); '
'print(f\"Copying templates from {src} to {dst}...\"); '
'shutil.copytree(src, dst, dirs_exist_ok=True); '
'shutil.copytree(src, dst); '
'print(f\"Templates copied successfully\")\'; '
# Make scripts executable.
'find ~/sky_templates -type f ! -name "*.py" ! -name "*.md" '
'-exec chmod +x {} \\; ')
f'find {SKY_TEMPLATES_DIRECTORY} -type f ! -name "*.py" ! -name "*.md" '
'-exec chmod +x {} + ; ')

SKYPILOT_WHEEL_INSTALLATION_COMMANDS = (
f'{SKY_UV_INSTALL_CMD};'
@@ -438,6 +462,7 @@ OVERRIDEABLE_CONFIG_KEYS_IN_TASK: List[Tuple[str, ...]] = [
('gcp', 'enable_gvnic'),
('gcp', 'enable_gpu_direct'),
('gcp', 'placement_policy'),
('vast', 'secure_only'),
('active_workspace',),
]
# When overriding the SkyPilot configs on the API server with the client one,
@@ -532,7 +557,7 @@ CATALOG_SCHEMA_VERSION = 'v8'
CATALOG_DIR = '~/.sky/catalogs'
ALL_CLOUDS = ('aws', 'azure', 'gcp', 'ibm', 'lambda', 'scp', 'oci',
'kubernetes', 'runpod', 'vast', 'vsphere', 'cudo', 'fluidstack',
'paperspace', 'primeintellect', 'do', 'nebius', 'ssh',
'paperspace', 'primeintellect', 'do', 'nebius', 'ssh', 'slurm',
'hyperbolic', 'seeweb', 'shadeform')
# END constants used for service catalog.



+ 10
- 4
sky/skylet/events.py View File

@@ -236,7 +236,7 @@ class AutostopEvent(SkyletEvent):
RAY_PROVISIONER_SKYPILOT_TERMINATOR):
logger.info('Using new provisioner to stop the cluster.')
self._stop_cluster_with_new_provisioner(autostop_config, config,
provider_name)
provider_name, cloud)
return
logger.info('Not using new provisioner to stop the cluster. '
f'Cloud of this cluster: {provider_name}')
@@ -314,7 +314,8 @@ class AutostopEvent(SkyletEvent):
raise NotImplementedError

def _stop_cluster_with_new_provisioner(self, autostop_config,
cluster_config, provider_name):
cluster_config, provider_name,
cloud):
# pylint: disable=import-outside-toplevel
from sky import provision as provision_lib
autostop_lib.set_autostopping_started()
@@ -334,8 +335,13 @@ class AutostopEvent(SkyletEvent):

# Stop the ray autoscaler to avoid scaling up, during
# stopping/terminating of the cluster.
logger.info('Stopping the ray cluster.')
subprocess.run(f'{constants.SKY_RAY_CMD} stop', shell=True, check=True)
if not cloud.uses_ray():
logger.info('Skipping ray stop as cloud does not use Ray.')
else:
logger.info('Stopping the ray cluster.')
subprocess.run(f'{constants.SKY_RAY_CMD} stop',
shell=True,
check=True)

operation_fn = provision_lib.stop_instances
if autostop_config.down:


+ 52
- 0
sky/skylet/executor/README.md View File

@@ -0,0 +1,52 @@
# SkyPilot Task Executors

This module contains task executors for running user scripts on cluster nodes.

## Concepts

- **Code Generator**: A `TaskCodeGen` subclass (e.g., `RayCodeGen`, `SlurmCodeGen`) that generates the job driver script. Lives in `sky/backends/task_codegen.py`.
- **Job Driver**: The generated Python script (`~/.sky/sky_app/sky_job_<id>`) that runs on the head node and orchestrates distributed execution across all nodes.
- **Task Executor**: A module that runs on each cluster node to execute the user's script. Handles environment setup, logging, and coordination with the job driver.

## Architecture

```
┌─────────────────────────────────────────────────────────┐
│                     Code Generator                      │
│     (RayCodeGen / SlurmCodeGen in task_codegen.py)      │
│             Generates the job driver script             │
└─────────────────────────────────────────────────────────┘
                             │
                             ▼
┌─────────────────────────────────────────────────────────┐
│                       Job Driver                        │
│    (~/.sky/sky_app/sky_job_<id> - runs on head node)    │
└─────────────────────────────────────────────────────────┘
         ┌───────────────────┼───────────────────┐
         ▼                   ▼                   ▼
┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐
│  Task Executor  │ │  Task Executor  │ │  Task Executor  │
│     (head)      │ │    (worker1)    │ │    (worker2)    │
└─────────────────┘ └─────────────────┘ └─────────────────┘
```

## Executors

### `slurm.py` - Slurm Task Executor

Invoked on each Slurm compute node via:
```bash
srun python -m sky.skylet.executor.slurm --script=<user_script> --log-dir=<path> ...
```

Handles Slurm-specific concerns:
- Determines node identity from `SLURM_PROCID` and cluster IP mapping
- Coordinates setup/run phases via signal files on shared NFS (see the sketch below)
- Writes and streams logs to unique per-node log files
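
To make the signal-file coordination concrete, below is a minimal executor-side sketch. It mirrors the behavior implemented in `slurm.py`; the assumption (not shown in this diff) is that the job driver creates the setup-done file once setup has finished on all nodes.

```python
# Executor-side sketch of the signal-file handshake. Driver-side behavior
# (creating the setup-done file) is an assumption, not part of this diff.
import os
import pathlib
import time


def wait_for_setup(rank: int, alloc_signal_file: str,
                   setup_done_signal_file: str) -> None:
    if rank == 0:
        # Tell the job driver that the srun allocation has been acquired.
        pathlib.Path(alloc_signal_file).touch()
    # Every rank blocks until the driver signals that setup is complete.
    while not os.path.exists(setup_done_signal_file):
        time.sleep(0.1)
```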

### Ray (no separate executor)

Ray uses `ray.remote()` to dispatch tasks directly to worker nodes. The execution
logic is inlined in the generated driver script rather than in a separate module,
since Ray can execute Python functions directly.

+ 1
- 0
sky/skylet/executor/__init__.py View File

@@ -0,0 +1 @@
"""Task Executors"""

+ 189
- 0
sky/skylet/executor/slurm.py View File

@@ -0,0 +1,189 @@
"""Slurm distributed task executor for SkyPilot.

This module is invoked on each Slurm compute node via:
srun python -m sky.skylet.executor.slurm --script=... --log-dir=...
"""
import argparse
import json
import os
import pathlib
import socket
import subprocess
import sys
import time

import colorama

from sky.skylet.log_lib import run_bash_command_with_log


def _get_ip_address() -> str:
"""Get the IP address of the current node."""
ip_result = subprocess.run(['hostname', '-I'],
capture_output=True,
text=True,
check=False)
return ip_result.stdout.strip().split(
)[0] if ip_result.returncode == 0 else 'unknown'


def _get_job_node_ips() -> str:
"""Get IPs of all nodes in the current Slurm job."""
nodelist = os.environ.get('SLURM_JOB_NODELIST', '')
assert nodelist, 'SLURM_JOB_NODELIST is not set'

# Expand compressed nodelist (e.g., "node[1-3,5]"
# -> "node1\nnode2\nnode3\nnode5")
result = subprocess.run(['scontrol', 'show', 'hostnames', nodelist],
capture_output=True,
text=True,
check=False)
if result.returncode != 0:
raise RuntimeError(f'Failed to get hostnames for: {nodelist}')

hostnames = result.stdout.strip().split('\n')
ips = []
for hostname in hostnames:
try:
ip = socket.gethostbyname(hostname)
ips.append(ip)
except socket.gaierror as e:
raise RuntimeError('Failed to get IP for hostname: '
f'{hostname}') from e

return '\n'.join(ips)


def main():
parser = argparse.ArgumentParser(
description='SkyPilot Slurm task runner for distributed execution')
parser.add_argument('--script', help='User script (inline, shell-quoted)')
parser.add_argument('--script-path',
help='Path to script file (if too long for inline)')
parser.add_argument('--env-vars',
default='{}',
help='JSON-encoded environment variables')
parser.add_argument('--log-dir',
required=True,
help='Directory for log files')
parser.add_argument('--cluster-num-nodes',
type=int,
required=True,
help='Total number of nodes in the cluster')
parser.add_argument('--cluster-ips',
required=True,
help='Comma-separated list of cluster node IPs')
parser.add_argument('--task-name',
default=None,
help='Task name for single-node log prefix')
parser.add_argument(
'--is-setup',
action='store_true',
help=
'Whether this is a setup command (affects logging prefix and filename)')
parser.add_argument('--alloc-signal-file',
help='Path to allocation signal file')
parser.add_argument('--setup-done-signal-file',
help='Path to setup-done signal file')
args = parser.parse_args()

assert args.script is not None or args.script_path is not None, (
'Either '
'--script or --script-path must be provided')

# Task rank, different from index of the node in the cluster.
rank = int(os.environ['SLURM_PROCID'])
num_nodes = int(os.environ.get('SLURM_NNODES', 1))
is_single_node_cluster = (args.cluster_num_nodes == 1)

# Determine node index from IP (like Ray's cluster_ips_to_node_id)
cluster_ips = args.cluster_ips.split(',')
ip_addr = _get_ip_address()
try:
node_idx = cluster_ips.index(ip_addr)
except ValueError as e:
raise RuntimeError(f'IP address {ip_addr} not found in '
f'cluster IPs: {cluster_ips}') from e
node_name = 'head' if node_idx == 0 else f'worker{node_idx}'

# Log files are written to a shared filesystem, so each node must use a
# unique filename to avoid collisions.
if args.is_setup:
# TODO(kevin): This is inconsistent with other clouds, where it is
# simply called 'setup.log'. On Slurm that is not possible, since the
# ~/sky_logs directory is shared by all nodes and 'setup.log' would be
# overwritten by other nodes.
# Perhaps we should apply this naming convention to other clouds.
log_filename = f'setup-{node_name}.log'
elif is_single_node_cluster:
log_filename = 'run.log'
else:
log_filename = f'{rank}-{node_name}.log'
log_path = os.path.join(args.log_dir, log_filename)

if args.script_path:
with open(args.script_path, 'r', encoding='utf-8') as f:
script = f.read()
else:
script = args.script

# Parse env vars and add SKYPILOT environment variables
env_vars = json.loads(args.env_vars)
if not args.is_setup:
# For setup, env vars are set in CloudVmRayBackend._setup.
env_vars['SKYPILOT_NODE_RANK'] = str(rank)
env_vars['SKYPILOT_NUM_NODES'] = str(num_nodes)
env_vars['SKYPILOT_NODE_IPS'] = _get_job_node_ips()

# Signal file coordination for setup/run synchronization
# Rank 0 touches the allocation signal to indicate resources acquired
if args.alloc_signal_file is not None and rank == 0:
pathlib.Path(args.alloc_signal_file).touch()

# Wait for setup to complete.
while args.setup_done_signal_file is not None and not os.path.exists(
args.setup_done_signal_file):
time.sleep(0.1)

# Build log prefix
# For setup on head: (setup pid={pid})
# For setup on workers: (setup pid={pid}, ip=1.2.3.4)
# For single-node cluster: (task_name, pid={pid})
# For multi-node on head: (head, rank=0, pid={pid})
# For multi-node on workers: (worker1, rank=1, pid={pid}, ip=1.2.3.4)
# The {pid} placeholder will be replaced by run_with_log
if args.is_setup:
# Setup prefix: head (node_idx=0) shows no IP, workers show IP
if node_idx == 0:
prefix = (f'{colorama.Fore.CYAN}(setup pid={{pid}})'
f'{colorama.Style.RESET_ALL} ')
else:
prefix = (f'{colorama.Fore.CYAN}(setup pid={{pid}}, ip={ip_addr})'
f'{colorama.Style.RESET_ALL} ')
elif is_single_node_cluster:
# Single-node cluster: use task name
name_str = args.task_name if args.task_name else 'task'
prefix = (f'{colorama.Fore.CYAN}({name_str}, pid={{pid}})'
f'{colorama.Style.RESET_ALL} ')
else:
# Multi-node cluster: head (node_idx=0) shows no IP, workers show IP
if node_idx == 0:
prefix = (
f'{colorama.Fore.CYAN}({node_name}, rank={rank}, pid={{pid}})'
f'{colorama.Style.RESET_ALL} ')
else:
prefix = (f'{colorama.Fore.CYAN}'
f'({node_name}, rank={rank}, pid={{pid}}, ip={ip_addr})'
f'{colorama.Style.RESET_ALL} ')

returncode = run_bash_command_with_log(script,
log_path,
env_vars=env_vars,
stream_logs=True,
streaming_prefix=prefix)

sys.exit(returncode)


if __name__ == '__main__':
main()

+ 2
- 1
sky/skylet/job_lib.py View File

@@ -1273,4 +1273,5 @@ class JobLibCodeGen:
def _build(cls, code: List[str]) -> str:
code = cls._PREFIX + code
code = ';'.join(code)
return f'{constants.SKY_PYTHON_CMD} -u -c {shlex.quote(code)}'
return (f'{constants.ACTIVATE_SKY_REMOTE_PYTHON_ENV}; '
f'{constants.SKY_PYTHON_CMD} -u -c {shlex.quote(code)}')

+ 22
- 6
sky/skylet/log_lib.py View File

@@ -172,7 +172,7 @@ def run_with_log(
streaming_prefix: Optional[str] = None,
log_cmd: bool = False,
**kwargs,
) -> Union[int, Tuple[int, str, str]]:
) -> Union[int, Tuple[int, str, str], Tuple[int, int]]:
"""Runs a command and logs its output to a file.

Args:
@@ -183,6 +183,8 @@ def run_with_log(
process_stream: Whether to post-process the stdout/stderr of the
command, such as replacing or skipping lines on the fly. If
enabled, lines are printed only when '\r' or '\n' is found.
streaming_prefix: Optional prefix for each log line. Can contain {pid}
placeholder which will be replaced with the subprocess PID.

Returns the returncode or returncode, stdout and stderr of the command.
Note that the stdout and stderr is already decoded.
@@ -228,6 +230,13 @@ def run_with_log(
# For backward compatibility, do not specify use_kill_pg by
# default.
subprocess_utils.kill_process_daemon(proc.pid)

# Format streaming_prefix with subprocess PID if it contains {pid}
formatted_streaming_prefix = streaming_prefix
if streaming_prefix and '{pid}' in streaming_prefix:
formatted_streaming_prefix = streaming_prefix.format(
pid=proc.pid)

stdout = ''
stderr = ''
stdout_stream_handler = None
@@ -256,7 +265,7 @@ def run_with_log(
line_processor=line_processor,
# Replace CRLF when the output is logged to driver by ray.
replace_crlf=with_ray,
streaming_prefix=streaming_prefix,
streaming_prefix=formatted_streaming_prefix,
)
stdout_stream_handler = functools.partial(
_handle_io_stream,
@@ -349,7 +358,8 @@ def run_bash_command_with_log(bash_command: str,
log_path: str,
env_vars: Optional[Dict[str, str]] = None,
stream_logs: bool = False,
with_ray: bool = False):
with_ray: bool = False,
streaming_prefix: Optional[str] = None):
with tempfile.NamedTemporaryFile('w', prefix='sky_app_',
delete=False) as fp:
bash_command = make_task_bash_script(bash_command, env_vars=env_vars)
@@ -364,6 +374,7 @@ def run_bash_command_with_log(bash_command: str,
log_path,
stream_logs=stream_logs,
with_ray=with_ray,
streaming_prefix=streaming_prefix,
shell=True)


@@ -372,9 +383,14 @@ def run_bash_command_with_log_and_return_pid(
log_path: str,
env_vars: Optional[Dict[str, str]] = None,
stream_logs: bool = False,
with_ray: bool = False):
return_code = run_bash_command_with_log(bash_command, log_path, env_vars,
stream_logs, with_ray)
with_ray: bool = False,
streaming_prefix: Optional[str] = None):
return_code = run_bash_command_with_log(bash_command,
log_path,
env_vars,
stream_logs,
with_ray,
streaming_prefix=streaming_prefix)
return {'return_code': return_code, 'pid': os.getpid()}
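
A minimal usage sketch of the new `streaming_prefix` parameter; the command, log path, and prefix here are illustrative. The `{pid}` placeholder is replaced with the subprocess PID, matching how the Slurm executor builds its per-line prefixes.

```python
# Illustrative only: stream a command's output with a per-line prefix.
from sky.skylet import log_lib

returncode = log_lib.run_bash_command_with_log(
    'echo hello',
    log_path='/tmp/example-run.log',  # hypothetical path
    stream_logs=True,
    streaming_prefix='(worker1, rank=1, pid={pid}) ')
```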




+ 8
- 6
sky/skylet/log_lib.pyi View File

@@ -68,7 +68,7 @@ def run_with_log(cmd: Union[List[str], str],
process_stream: bool = ...,
line_processor: Optional[log_utils.LineProcessor] = ...,
streaming_prefix: Optional[str] = ...,
ray_job_id: Optional[str] = ...,
log_cmd: bool = ...,
**kwargs) -> int:
...

@@ -87,7 +87,7 @@ def run_with_log(cmd: Union[List[str], str],
process_stream: bool = ...,
line_processor: Optional[log_utils.LineProcessor] = ...,
streaming_prefix: Optional[str] = ...,
ray_job_id: Optional[str] = ...,
log_cmd: bool = ...,
**kwargs) -> Tuple[int, str, str]:
...

@@ -106,8 +106,8 @@ def run_with_log(cmd: Union[List[str], str],
process_stream: bool = ...,
line_processor: Optional[log_utils.LineProcessor] = ...,
streaming_prefix: Optional[str] = ...,
ray_job_id: Optional[str] = ...,
**kwargs) -> Union[int, Tuple[int, str, str]]:
log_cmd: bool = ...,
**kwargs) -> Tuple[int, int]:
...


@@ -125,7 +125,8 @@ def run_bash_command_with_log(bash_command: str,
log_path: str,
env_vars: Optional[Dict[str, str]] = ...,
stream_logs: bool = ...,
with_ray: bool = ...):
with_ray: bool = ...,
streaming_prefix: Optional[str] = ...) -> int:
...


@@ -134,7 +135,8 @@ def run_bash_command_with_log_and_return_pid(
log_path: str,
env_vars: Optional[Dict[str, str]] = ...,
stream_logs: bool = ...,
with_ray: bool = ...):
with_ray: bool = ...,
streaming_prefix: Optional[str] = ...) -> Dict[str, Union[int, str]]:
...




+ 5
- 1
sky/skylet/skylet.py View File

@@ -48,8 +48,12 @@ def start_grpc_server(port: int = constants.SKYLET_GRPC_PORT) -> grpc.Server:
# putting it here for visibility.
# TODO(kevin): Determine the optimal max number of threads.
max_workers = min(32, (os.cpu_count() or 1) + 4)
# There's only a single skylet process per cluster, so disable
# SO_REUSEPORT to raise an error if the port is already in use.
options = (('grpc.so_reuseport', 0),)
server = grpc.server(
concurrent.futures.ThreadPoolExecutor(max_workers=max_workers))
concurrent.futures.ThreadPoolExecutor(max_workers=max_workers),
options=options)

autostopv1_pb2_grpc.add_AutostopServiceServicer_to_server(
services.AutostopServiceImpl(), server)


+ 2
- 1
sky/skylet/subprocess_daemon.py View File

@@ -110,7 +110,8 @@ def kill_process_tree(process: psutil.Process,


def main():
# daemonize()
daemonize()

parser = argparse.ArgumentParser()
parser.add_argument('--parent-pid', type=int, required=True)
parser.add_argument('--proc-pid', type=int, required=True)


+ 12
- 0
sky/ssh_node_pools/constants.py View File

@@ -0,0 +1,12 @@
"""Constants for SSH Node Pools"""
# pylint: disable=line-too-long
import os

DEFAULT_KUBECONFIG_PATH = os.path.expanduser('~/.kube/config')
SSH_CONFIG_PATH = os.path.expanduser('~/.ssh/config')
NODE_POOLS_INFO_DIR = os.path.expanduser('~/.sky/ssh_node_pools_info')
NODE_POOLS_KEY_DIR = os.path.expanduser('~/.sky/ssh_keys')
DEFAULT_SSH_NODE_POOLS_PATH = os.path.expanduser('~/.sky/ssh_node_pools.yaml')

# TODO (kyuds): make this configurable?
K3S_TOKEN = 'mytoken' # Any string can be used as the token

+ 40
- 3
sky/ssh_node_pools/core.py View File

@@ -1,10 +1,15 @@
"""SSH Node Pool management core functionality."""
import os
from pathlib import Path
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional, Tuple

import yaml

from sky import clouds
from sky.ssh_node_pools import constants
from sky.ssh_node_pools import deploy
from sky.usage import usage_lib
from sky.utils import common_utils
from sky.utils import yaml_utils


@@ -12,8 +17,8 @@ class SSHNodePoolManager:
"""Manager for SSH Node Pool configurations."""

def __init__(self):
self.config_path = Path.home() / '.sky' / 'ssh_node_pools.yaml'
self.keys_dir = Path.home() / '.sky' / 'ssh_keys'
self.config_path = Path(constants.DEFAULT_SSH_NODE_POOLS_PATH)
self.keys_dir = Path(constants.NODE_POOLS_KEY_DIR)
self.keys_dir.mkdir(parents=True, exist_ok=True)

def get_all_pools(self) -> Dict[str, Any]:
@@ -133,3 +138,35 @@ def list_ssh_keys() -> List[str]:
"""List available SSH keys."""
manager = SSHNodePoolManager()
return manager.list_ssh_keys()


@usage_lib.entrypoint
def ssh_up(infra: Optional[str] = None, cleanup: bool = False) -> None:
"""Deploys or tears down a Kubernetes cluster on SSH targets.

Args:
infra: Name of the cluster configuration in ssh_node_pools.yaml.
If None, the first cluster in the file is used.
cleanup: If True, clean up the cluster instead of deploying.
"""
deploy.run(cleanup=cleanup, infra=infra)


@usage_lib.entrypoint
def ssh_status(context_name: str) -> Tuple[bool, str]:
"""Check the status of an SSH Node Pool context.

Args:
context_name: The SSH context name (e.g., 'ssh-my-cluster')

Returns:
Tuple[bool, str]: (is_ready, reason)
- is_ready: True if the SSH Node Pool is ready, False otherwise
- reason: Explanation of the status
"""
try:
is_ready, reason = clouds.SSH.check_single_context(context_name)
return is_ready, reason
except Exception as e: # pylint: disable=broad-except
return False, ('Failed to check SSH context: '
f'{common_utils.format_exception(e)}')

+ 4
- 0
sky/ssh_node_pools/deploy/__init__.py View File

@@ -0,0 +1,4 @@
"""Module for Deploying SSH Node Pools"""
from sky.ssh_node_pools.deploy.deploy import run

__all__ = ['run']

Some files were not shown because too many files changed in this diff
