@@ -21,6 +21,7 @@ from fsspec.implementations.local import LocalFileSystem
from lightning.fabric.utilities.cloud_io import get_filesystem
from lightning.fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_2_4
from lightning.fabric.utilities.rank_zero import LightningDeprecationWarning
from lightning.pytorch.core.module import LightningModule
from lightning.pytorch.demos.boring_classes import BoringModel
from tests_pytorch.helpers.advanced_models import BasicGAN, ParityModuleRNN
@@ -36,7 +37,8 @@ def test_torchscript_input_output(modelclass):
if isinstance(model, BoringModel):
model.example_input_array = torch.randn(5, 32)
script = model.to_torchscript()
with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"):
script = model.to_torchscript()
assert isinstance(script, torch.jit.ScriptModule)
model.eval()
@@ -59,7 +61,8 @@ def test_torchscript_example_input_output_trace(modelclass):
if isinstance(model, BoringModel):
model.example_input_array = torch.randn(5, 32)
script = model.to_torchscript(method="trace")
with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"):
script = model.to_torchscript(method="trace")
assert isinstance(script, torch.jit.ScriptModule)
model.eval()
@@ -74,7 +77,8 @@ def test_torchscript_input_output_trace():
"""Test that traced LightningModule forward works with example_inputs."""
model = BoringModel()
example_inputs = torch.randn(1, 32)
script = model.to_torchscript(example_inputs=example_inputs, method="trace")
with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"):
script = model.to_torchscript(example_inputs=example_inputs, method="trace")
assert isinstance(script, torch.jit.ScriptModule)
model.eval()
@@ -99,7 +103,8 @@ def test_torchscript_device(device_str):
model = BoringModel().to(device)
model.example_input_array = torch.randn(5, 32)
script = model.to_torchscript()
with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"):
script = model.to_torchscript()
assert next(script.parameters()).device == device
script_output = script(model.example_input_array.to(device))
assert script_output.device == device
@@ -121,7 +126,8 @@ def test_torchscript_device_with_check_inputs(device_str):
check_inputs = torch.rand(5, 32)
script = model.to_torchscript(method="trace", check_inputs=check_inputs)
with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"):
script = model.to_torchscript(method="trace", check_inputs=check_inputs)
assert isinstance(script, torch.jit.ScriptModule)
@@ -129,11 +135,13 @@ def test_torchscript_retain_training_state():
"""Test that torchscript export does not alter the training mode of original model."""
model = BoringModel()
model.train(True)
script = model.to_torchscript()
with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"):
script = model.to_torchscript()
assert model.training
assert not script.training
model.train(False)
_ = model.to_torchscript()
with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"):
_ = model.to_torchscript()
assert not model.training
assert not script.training
@@ -142,7 +150,8 @@ def test_torchscript_retain_training_state():
def test_torchscript_properties(modelclass):
"""Test that scripted LightningModule has unnecessary methods removed."""
model = modelclass()
script = model.to_torchscript()
with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"):
script = model.to_torchscript()
assert not hasattr(model, "batch_size") or hasattr(script, "batch_size")
assert not hasattr(model, "learning_rate") or hasattr(script, "learning_rate")
assert not callable(getattr(script, "training_step", None))
@@ -153,7 +162,8 @@ def test_torchscript_save_load(tmp_path, modelclass):
"""Test that scripted LightningModule is correctly saved and can be loaded."""
model = modelclass()
output_file = str(tmp_path / "model.pt")
script = model.to_torchscript(file_path=output_file)
with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"):
script = model.to_torchscript(file_path=output_file)
loaded_script = torch.jit.load(output_file)
assert torch.allclose(next(script.parameters()), next(loaded_script.parameters()))
@@ -170,7 +180,8 @@ def test_torchscript_save_load_custom_filesystem(tmp_path, modelclass):
model = modelclass()
output_file = os.path.join(_DUMMY_PRFEIX, _PREFIX_SEPARATOR, tmp_path, "model.pt")
script = model.to_torchscript(file_path=output_file)
with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"):
script = model.to_torchscript(file_path=output_file)
fs = get_filesystem(output_file)
with fs.open(output_file, "rb") as f:
@@ -184,7 +195,10 @@ def test_torchcript_invalid_method():
model = BoringModel()
model.train(True)
with pytest.raises(ValueError, match="only supports 'script' or 'trace'"):
with (
pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"),
pytest.raises(ValueError, match="only supports 'script' or 'trace'"),
):
model.to_torchscript(method="temp")
@@ -193,7 +207,10 @@ def test_torchscript_with_no_input():
model = BoringModel()
model.example_input_array = None
with pytest.raises(ValueError, match="requires either `example_inputs` or `model.example_input_array`"):
with (
pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"),
pytest.raises(ValueError, match="requires either `example_inputs` or `model.example_input_array`"),
):
model.to_torchscript(method="trace")
@@ -224,6 +241,17 @@ def test_torchscript_script_recursively():
lm = Parent()
assert not lm._jit_is_scripting
script = lm.to_torchscript(method="script")
with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"):
script = lm.to_torchscript(method="script")
assert not lm._jit_is_scripting
assert isinstance(script, torch.jit.RecursiveScriptModule)
def test_to_torchscript_deprecation():
"""Test that to_torchscript raises a deprecation warning."""
model = BoringModel()
model.example_input_array = torch.randn(5, 32)
with pytest.warns(LightningDeprecationWarning, match="has been deprecated in v2.7 and will be removed in v2.8"):
script = model.to_torchscript()
assert isinstance(script, torch.jit.ScriptModule)