3 Commits

Author SHA1 Message Date
  Nicki Skafte Detlefsen 04baf7ff27
Sanitize profile filename (#21395) 12 hours ago
  dependabot[bot] 716c2c61bb
build(deps): update jsonargparse[jsonnet,signatures] requirement from <4.44.0,>=4.39.0 to >=4.39.0,<4.45.0 in /requirements (#21392) 12 hours ago
  Nicki Skafte Detlefsen ad7a958237
Deprecate method `to_torchscript` (#21397) 12 hours ago
7 changed files with 98 additions and 18 deletions
Split View
  1. +1
    -1
      requirements/pytorch/extra.txt
  2. +7
    -0
      src/lightning/pytorch/CHANGELOG.md
  3. +11
    -1
      src/lightning/pytorch/core/module.py
  4. +9
    -2
      src/lightning/pytorch/profilers/profiler.py
  5. +2
    -1
      tests/tests_pytorch/helpers/test_models.py
  6. +41
    -13
      tests/tests_pytorch/models/test_torchscript.py
  7. +27
    -0
      tests/tests_pytorch/profilers/test_profiler.py

+ 1
- 1
requirements/pytorch/extra.txt View File

@@ -5,7 +5,7 @@
matplotlib>3.1, <3.11.0
omegaconf >=2.2.3, <2.4.0
hydra-core >=1.2.0, <1.4.0
jsonargparse[signatures,jsonnet] >=4.39.0, <4.44.0
jsonargparse[signatures,jsonnet] >=4.39.0, <4.45.0
rich >=12.3.0, <14.3.0
tensorboardX >=2.2, <2.7.0 # min version is set by torch.onnx missing attribute
bitsandbytes >=0.45.2,<0.47.0; platform_system != "Darwin"

+ 7
- 0
src/lightning/pytorch/CHANGELOG.md View File

@@ -16,12 +16,19 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

-

### Deprecated

- Deprecated `to_torchscript` method due to deprecation of TorchScript in PyTorch ([#21397](https://github.com/Lightning-AI/pytorch-lightning/pull/21397))

### Removed

- Removed support for Python 3.9 due to end-of-life status ([#21398](https://github.com/Lightning-AI/pytorch-lightning/pull/21398))

### Fixed

- Sanitize profiler filenames when saving to avoid crashes due to invalid characters ([#21395](https://github.com/Lightning-AI/pytorch-lightning/pull/21395))


- Fix `StochasticWeightAveraging` with infinite epochs ([#21396](https://github.com/Lightning-AI/pytorch-lightning/pull/21396))




+ 11
- 1
src/lightning/pytorch/core/module.py View File

@@ -64,7 +64,7 @@ from lightning.pytorch.utilities import GradClipAlgorithmType
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.imports import _TORCH_GREATER_EQUAL_2_6, _TORCHMETRICS_GREATER_EQUAL_0_9_1
from lightning.pytorch.utilities.model_helpers import _restricted_classmethod
from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_warn
from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_deprecation, rank_zero_warn
from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature
from lightning.pytorch.utilities.types import (
_METRIC,
@@ -1498,6 +1498,11 @@ class LightningModule(
scripted you should override this method. In case you want to return multiple modules, we recommend using a
dictionary.

.. deprecated::
``LightningModule.to_torchscript`` has been deprecated in v2.7 and will be removed in v2.8.
TorchScript is deprecated in PyTorch. Use ``torch.export.export()`` for model exporting instead.
See https://pytorch.org/docs/stable/export.html for more information.

Args:
file_path: Path where to save the torchscript. Default: None (no file saved).
method: Whether to use TorchScript's script or trace method. Default: 'script'
@@ -1536,6 +1541,11 @@ class LightningModule(
defined or not.

"""
rank_zero_deprecation(
"`LightningModule.to_torchscript` has been deprecated in v2.7 and will be removed in v2.8. "
"TorchScript is deprecated in PyTorch. Use `torch.export.export()` for model exporting instead. "
"See https://pytorch.org/docs/stable/export.html for more information."
)
mode = self.training

if method == "script":


+ 9
- 2
src/lightning/pytorch/profilers/profiler.py View File

@@ -15,6 +15,7 @@

import logging
import os
import re
from abc import ABC, abstractmethod
from collections.abc import Generator
from contextlib import contextmanager
@@ -80,7 +81,6 @@ class Profiler(ABC):
self,
action_name: Optional[str] = None,
extension: str = ".txt",
split_token: str = "-", # noqa: S107
) -> str:
args = []
if self._stage is not None:
@@ -91,7 +91,14 @@ class Profiler(ABC):
args.append(str(self._local_rank))
if action_name is not None:
args.append(action_name)
return split_token.join(args) + extension
base = "-".join(args)
# Replace a set of path-unsafe characters across platforms with '_'
base = re.sub(r"[\\/:*?\"<>|\n\r\t]", "_", base)
base = re.sub(r"_+", "_", base)
base = base.strip()
if not base:
base = "profile"
return base + extension

def _prepare_streams(self) -> None:
if self._write_stream is not None:


+ 2
- 1
tests/tests_pytorch/helpers/test_models.py View File

@@ -46,7 +46,8 @@ def test_models(tmp_path, data_class, model_class):
if dm is not None:
trainer.test(model, datamodule=dm)

model.to_torchscript()
with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"):
model.to_torchscript()
if data_class:
model.to_onnx(os.path.join(tmp_path, "my-model.onnx"), input_sample=dm.sample)



+ 41
- 13
tests/tests_pytorch/models/test_torchscript.py View File

@@ -21,6 +21,7 @@ from fsspec.implementations.local import LocalFileSystem

from lightning.fabric.utilities.cloud_io import get_filesystem
from lightning.fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_2_4
from lightning.fabric.utilities.rank_zero import LightningDeprecationWarning
from lightning.pytorch.core.module import LightningModule
from lightning.pytorch.demos.boring_classes import BoringModel
from tests_pytorch.helpers.advanced_models import BasicGAN, ParityModuleRNN
@@ -36,7 +37,8 @@ def test_torchscript_input_output(modelclass):
if isinstance(model, BoringModel):
model.example_input_array = torch.randn(5, 32)

script = model.to_torchscript()
with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"):
script = model.to_torchscript()
assert isinstance(script, torch.jit.ScriptModule)

model.eval()
@@ -59,7 +61,8 @@ def test_torchscript_example_input_output_trace(modelclass):
if isinstance(model, BoringModel):
model.example_input_array = torch.randn(5, 32)

script = model.to_torchscript(method="trace")
with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"):
script = model.to_torchscript(method="trace")
assert isinstance(script, torch.jit.ScriptModule)

model.eval()
@@ -74,7 +77,8 @@ def test_torchscript_input_output_trace():
"""Test that traced LightningModule forward works with example_inputs."""
model = BoringModel()
example_inputs = torch.randn(1, 32)
script = model.to_torchscript(example_inputs=example_inputs, method="trace")
with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"):
script = model.to_torchscript(example_inputs=example_inputs, method="trace")
assert isinstance(script, torch.jit.ScriptModule)

model.eval()
@@ -99,7 +103,8 @@ def test_torchscript_device(device_str):
model = BoringModel().to(device)
model.example_input_array = torch.randn(5, 32)

script = model.to_torchscript()
with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"):
script = model.to_torchscript()
assert next(script.parameters()).device == device
script_output = script(model.example_input_array.to(device))
assert script_output.device == device
@@ -121,7 +126,8 @@ def test_torchscript_device_with_check_inputs(device_str):

check_inputs = torch.rand(5, 32)

script = model.to_torchscript(method="trace", check_inputs=check_inputs)
with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"):
script = model.to_torchscript(method="trace", check_inputs=check_inputs)
assert isinstance(script, torch.jit.ScriptModule)


@@ -129,11 +135,13 @@ def test_torchscript_retain_training_state():
"""Test that torchscript export does not alter the training mode of original model."""
model = BoringModel()
model.train(True)
script = model.to_torchscript()
with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"):
script = model.to_torchscript()
assert model.training
assert not script.training
model.train(False)
_ = model.to_torchscript()
with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"):
_ = model.to_torchscript()
assert not model.training
assert not script.training

@@ -142,7 +150,8 @@ def test_torchscript_retain_training_state():
def test_torchscript_properties(modelclass):
"""Test that scripted LightningModule has unnecessary methods removed."""
model = modelclass()
script = model.to_torchscript()
with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"):
script = model.to_torchscript()
assert not hasattr(model, "batch_size") or hasattr(script, "batch_size")
assert not hasattr(model, "learning_rate") or hasattr(script, "learning_rate")
assert not callable(getattr(script, "training_step", None))
@@ -153,7 +162,8 @@ def test_torchscript_save_load(tmp_path, modelclass):
"""Test that scripted LightningModule is correctly saved and can be loaded."""
model = modelclass()
output_file = str(tmp_path / "model.pt")
script = model.to_torchscript(file_path=output_file)
with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"):
script = model.to_torchscript(file_path=output_file)
loaded_script = torch.jit.load(output_file)
assert torch.allclose(next(script.parameters()), next(loaded_script.parameters()))

@@ -170,7 +180,8 @@ def test_torchscript_save_load_custom_filesystem(tmp_path, modelclass):

model = modelclass()
output_file = os.path.join(_DUMMY_PRFEIX, _PREFIX_SEPARATOR, tmp_path, "model.pt")
script = model.to_torchscript(file_path=output_file)
with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"):
script = model.to_torchscript(file_path=output_file)

fs = get_filesystem(output_file)
with fs.open(output_file, "rb") as f:
@@ -184,7 +195,10 @@ def test_torchcript_invalid_method():
model = BoringModel()
model.train(True)

with pytest.raises(ValueError, match="only supports 'script' or 'trace'"):
with (
pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"),
pytest.raises(ValueError, match="only supports 'script' or 'trace'"),
):
model.to_torchscript(method="temp")


@@ -193,7 +207,10 @@ def test_torchscript_with_no_input():
model = BoringModel()
model.example_input_array = None

with pytest.raises(ValueError, match="requires either `example_inputs` or `model.example_input_array`"):
with (
pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"),
pytest.raises(ValueError, match="requires either `example_inputs` or `model.example_input_array`"),
):
model.to_torchscript(method="trace")


@@ -224,6 +241,17 @@ def test_torchscript_script_recursively():

lm = Parent()
assert not lm._jit_is_scripting
script = lm.to_torchscript(method="script")
with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"):
script = lm.to_torchscript(method="script")
assert not lm._jit_is_scripting
assert isinstance(script, torch.jit.RecursiveScriptModule)


def test_to_torchscript_deprecation():
"""Test that to_torchscript raises a deprecation warning."""
model = BoringModel()
model.example_input_array = torch.randn(5, 32)

with pytest.warns(LightningDeprecationWarning, match="has been deprecated in v2.7 and will be removed in v2.8"):
script = model.to_torchscript()
assert isinstance(script, torch.jit.ScriptModule)

+ 27
- 0
tests/tests_pytorch/profilers/test_profiler.py View File

@@ -322,6 +322,33 @@ def test_advanced_profiler_dump_states(tmp_path):
assert len(data) > 0


@pytest.mark.parametrize("char", ["/", "\\", ":", "*", "?", '"', "<", ">", "|", "\n", "\r", "\t"])
def test_advanced_profiler_dump_states_sanitizes_filename(tmp_path, char):
"""Profiler should sanitize action names to produce filesystem-safe .prof filenames.

This guards against errors when callbacks or actions include path-unsafe characters (e.g., metric names with '/').

"""
profiler = AdvancedProfiler(dirpath=tmp_path, dump_stats=True)
action_name = f"before{char}after"
with profiler.profile(action_name):
pass

profiler.describe()

prof_files = [f for f in os.listdir(tmp_path) if f.endswith(".prof")]
assert len(prof_files) == 1
prof_name = prof_files[0]

# Ensure none of the path-unsafe characters are present in the produced filename
forbidden = ["/", "\\", ":", "*", "?", '"', "<", ">", "|", "\n", "\r", "\t"]
for bad in forbidden:
assert bad not in prof_name

# File should be non-empty
assert (tmp_path / prof_name).read_bytes()


def test_advanced_profiler_value_errors(advanced_profiler):
"""Ensure errors are raised where expected."""
action = "test"


Loading…
Cancel
Save
Baidu
map