2 Commits

Author SHA1 Message Date
  Hanks d0cd5ac50d
Merge pull request #841 from hpcaitech/oom_fix 9 months ago
  hxwang 5730060f41
[ckpt] mitigate gpu mem peak when loading ckpt 9 months ago
1 changed files with 1 additions and 2 deletions
Split View
  1. +1
    -2
      opensora/utils/ckpt.py

+ 1
- 2
opensora/utils/ckpt.py View File

@@ -113,8 +113,7 @@ def load_checkpoint(

log_message(f"Loading checkpoint from {path}")
if path.endswith(".safetensors"):
# ckpt = load_file(path, device=str(device_map))
ckpt = load_file(path, device=torch.cuda.current_device())
ckpt = load_file(path, device='cpu')

if rename_keys is not None:
# rename keys in the loaded state_dict with old_key_prefix to with new_key_prefix.


Loading…
Cancel
Save
Baidu
map