2 Commits

Author SHA1 Message Date
  color f363913ce6 Merge pull request 'fix_load_master (#955)' (#38) from OpenI/MSAdapter:master into master 1 year ago
  hanjr 7c2fc8bf87 fix_load_master (#955) 1 year ago
3 changed files with 17 additions and 9 deletions
Split View
  1. +2
    -2
      mindtorch/torch/_utils.py
  2. +8
    -4
      mindtorch/torch/storage.py
  3. +7
    -3
      mindtorch/torch/tensor.py

+ 2
- 2
mindtorch/torch/_utils.py View File

@@ -56,7 +56,7 @@ def _rebuild_tensor(storage, storage_offset, size, stride):
unsupported_attr(stride)
from mindtorch.torch.tensor import tensor # pylint: disable=R0401, C0415
t = tensor([], dtype=storage.dtype, device=storage._untyped().device)
return t.set_(storage._untyped(), storage_offset, size)
return t._set_storage(storage._untyped(), storage_offset, size)


def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None):
@@ -82,7 +82,7 @@ def _rebuild_parameter(data, requires_grad, backward_hooks):
unsupported_attr(backward_hooks)
from mindtorch.torch.nn import Parameter # pylint: disable=R0401, C0415
param = Parameter(data, requires_grad)
param.set_(data.storage()._untyped(), 0, data.size())
param._set_storage(data.storage()._untyped(), 0, data.size())
return param

def _rebuild_parameter_with_state(data, requires_grad, backward_hooks, state):


+ 8
- 4
mindtorch/torch/storage.py View File

@@ -72,10 +72,12 @@ class _StorageBase():
self._update_referenced_tensor()
return self

def _update_referenced_tensor(self, strict=True, size=None):
def _update_referenced_tensor(self, strict=True, size=None, storage_offset=0):
if self.referenced_tensor is not None:
np_data = np.frombuffer(self.inner_data,
_TypeDict.get(self.referenced_tensor.dtype))
if storage_offset:
np_data = np_data[storage_offset:]
if size is not None:
np_data = np_data.reshape(size)
if strict:
@@ -177,10 +179,12 @@ class _StorageBase():
return False

def _set_from_file(self, f, offset, is_real_file, element_size):
if not is_real_file:
raise RuntimeError("Currently, in `storage._set_from_file` only is_real_file==True supported.")
nbytes = np.frombuffer(f.read(8), np.int64).item() * element_size
array = np.fromfile(f, dtype=np.uint8, count=nbytes, offset=offset)
if is_real_file:
array = np.fromfile(f, dtype=np.uint8, count=nbytes, offset=offset)
else:
f.seek(offset)
array = np.frombuffer(f.read(nbytes), np.uint8)
self.inner_data[:] = array
self._update_referenced_tensor()
return self


+ 7
- 3
mindtorch/torch/tensor.py View File

@@ -1659,7 +1659,6 @@ class Tensor(StubTensor, metaclass=_TensorMeta):
if graph_mode_condition():
warning('`Tensor.set_` is an in-place operation and "x.set_()" is not supported to use '
'in MindSpore static graph mode.')

if isinstance(source, Tensor):
if source.dtype != self.dtype:
raise RuntimeError("In `tensor.set_`, sourse.dtype must equal to self.dtype.")
@@ -1669,19 +1668,24 @@ class Tensor(StubTensor, metaclass=_TensorMeta):
self.assign_value(source)
return self

return self._set_storage(source=source, storage_offset=storage_offset, size=size, stride=stride)

def _set_storage(self,source=None, storage_offset=0, size=None, stride=None):
unsupported_attr(stride)
if isinstance(source, _TypedStorage):
# handle source is a _TypedStorage
if source.dtype != self.dtype:
raise RuntimeError("In `tensor.set_`, _TypedStorage.dtype must equal to self.dtype.")
source._storage.referenced_tensor = self
source._storage._update_referenced_tensor(strict=False, size=size)
source._storage._update_referenced_tensor(strict=False, size=size, storage_offset=storage_offset)
return self

# handle source is a _UntypedStorage
source.referenced_tensor = self
source._update_referenced_tensor(strict=False, size=size)
source._update_referenced_tensor(strict=False, size=size,storage_offset=storage_offset)
return self


def to(self, *args, **kwargs):
# TODO:
# Note that this API requires the user to ensure the correctness of the input currently,

