4 Commits

23 changed files with 474 additions and 4283 deletions
  1. README.md (+22 -1)
  2. examples/diffusers/janus/demo/app.py (+17 -10)
  3. examples/diffusers/janus/demo/app_janusflow.py (+40 -21)
  4. examples/diffusers/janus/demo/app_januspro.py (+37 -16)
  5. examples/diffusers/janus/generation_inference.py (+1 -1)
  6. examples/diffusers/janus/inference.py (+1 -1)
  7. mindnlp/core/_prims/ascend.py (+6 -0)
  8. mindnlp/core/_prims/numpy.py (+192 -4)
  9. mindnlp/core/_tensor.py (+26 -9)
  10. mindnlp/core/distributions/__init__.py (+1 -0)
  11. mindnlp/core/distributions/gamma.py (+44 -33)
  12. mindnlp/core/distributions/studentT.py (+32 -24)
  13. mindnlp/core/distributions/utils.py (+6 -6)
  14. mindnlp/core/linalg/__init__.py (+4 -1)
  15. mindnlp/core/nn/functional.py (+3 -1)
  16. mindnlp/core/npu/__init__.py (+3 -0)
  17. mindnlp/core/ops/array.py (+12 -6)
  18. mindnlp/core/ops/comparison.py (+18 -7)
  19. mindnlp/core/ops/random.py (+6 -4)
  20. mindnlp/transformers/__init__.py (+1 -4100)
  21. mindnlp/transformers/masking_utils.py (+1 -1)
  22. mindnlp/transformers/models/gemma3/__init__.py (+0 -36)
  23. setup.py (+1 -1)

README.md (+22 -1)

@@ -39,7 +39,28 @@

## News 📢

* 🔥 **Fully compatible with 🤗HuggingFace**, it enables seamless execution of any Transformers/Diffusers models on MindSpore across all hardware platforms (GPU/Ascend/CPU).
* ⚡ **MindNLP Core supports PyTorch compatibility:** To meet ecosystem compatibility requirements, we provide the `mindnlp.core` module, which supports compatibility with PyTorch interfaces. This module is built upon MindSpore's foundational APIs and operators, enabling model development with PyTorch-like syntax. It can also take over torch interfaces through a proxy, allowing MindSpore acceleration on Ascend hardware without any code modifications. Usage is as follows:

```python
import mindnlp # importing mindnlp enables the proxy automatically
import torch
from torch import nn

# all torch.xx APIs are mapped to mindnlp.core.xx
net = nn.Linear(10, 5)
x = torch.randn(3, 10)
out = net(x)
print(out.shape)
# core.Size([3, 5])
```

It is particularly noteworthy that MindNLP supports several features not yet available in MindSpore, which enable better support for model serialization, heterogeneous computing, and other scenarios (a short sketch follows the list):
1. Dispatch Mechanism Support: operators are dispatched to the appropriate backend based on `Tensor.device`.
2. Meta Device Support: allows for shape inference without performing actual computations.
3. NumPy as CPU Backend: supports using NumPy as a CPU backend for acceleration.
4. `Tensor.to` for Heterogeneous Data Movement: facilitates the movement of data across different devices using `Tensor.to`.
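For example, features 2–4 can be combined as below. This is a minimal sketch, not verified output: it assumes the proxy is active, that the `'meta'` and `'npu'` device strings are accepted as in PyTorch, and that `torch.npu` resolves to `mindnlp.core.npu` through the proxy:

```python
import mindnlp  # importing mindnlp enables the proxy
import torch

# 2. Meta device: infer shapes without allocating memory or computing
m = torch.empty(4, 1024, device='meta')
print(m.shape, m.device)  # core.Size([4, 1024]) meta

# 3. NumPy as CPU backend: this matmul is dispatched by Tensor.device
#    to the CPU backend and runs via NumPy
a = torch.randn(4, 1024)
b = torch.randn(1024, 256)
c = a @ b

# 4. Tensor.to for heterogeneous data movement: move to Ascend when available
if torch.npu.is_available():
    c = c.to('npu')
```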

* 🔥 **Fully compatible with 🤗HuggingFace:** It enables seamless execution of any Transformers/Diffusers models on MindSpore across all hardware platforms (GPU/Ascend/CPU).
You may still invoke models through MindNLP as shown in the example code below:



examples/diffusers/janus/demo/app.py (+17 -10)

@@ -1,13 +1,16 @@
import gradio as gr
import mindspore
import mindnlp
from mindnlp import core
import gradio as gr
from transformers import AutoConfig, AutoModelForCausalLM
from janus.models import MultiModalityCausalLM, VLChatProcessor
from janus.models import VLChatProcessor
from PIL import Image

import numpy as np

device = 'cpu'
if core.npu.is_available():
device = 'npu'
elif core.cuda.is_available():
device = 'cuda'

# Load model and processor
model_path = "deepseek-ai/Janus-1.3B"
@@ -16,7 +19,8 @@ language_config = config.language_config
language_config._attn_implementation = 'eager'
vl_gpt = AutoModelForCausalLM.from_pretrained(model_path,
language_config=language_config,
trust_remote_code=True, ms_dtype=mindspore.float16)
trust_remote_code=True)
vl_gpt = vl_gpt.to(core.bfloat16).to(device)

vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer
@@ -26,9 +30,12 @@ tokenizer = vl_chat_processor.tokenizer
# Multimodal Understanding function
def multimodal_understanding(image, question, seed, top_p, temperature):
# Clear CUDA cache before generating
core.cuda.empty_cache()
# set seed
mindspore.manual_seed(seed)
core.manual_seed(seed)
np.random.seed(seed)
core.cuda.manual_seed(seed)
conversation = [
{
@@ -42,9 +49,9 @@ def multimodal_understanding(image, question, seed, top_p, temperature):
pil_images = [Image.fromarray(image)]
prepare_inputs = vl_chat_processor(
conversations=conversation, images=pil_images, force_batchify=True
).to(core.get_default_device(), dtype=mindspore.float16)
).to(device, dtype=core.bfloat16 if core.cuda.is_available() else core.float16)
print(prepare_inputs)
inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
outputs = vl_gpt.language_model.generate(
@@ -75,13 +82,13 @@ def generate(input_ids,
# Clear CUDA cache before generating
core.cuda.empty_cache()
tokens = core.zeros((parallel_size * 2, len(input_ids)), dtype=core.int)
tokens = core.zeros((parallel_size * 2, len(input_ids)), dtype=core.int).to(device)
for i in range(parallel_size * 2):
tokens[i, :] = input_ids
if i % 2 != 0:
tokens[i, 1:-1] = vl_chat_processor.pad_id
inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
generated_tokens = core.zeros((parallel_size, image_token_num_per_image), dtype=core.int)
generated_tokens = core.zeros((parallel_size, image_token_num_per_image), dtype=core.int).to(device)

pkv = None
for i in range(image_token_num_per_image):


examples/diffusers/janus/demo/app_janusflow.py (+40 -21)

@@ -1,29 +1,40 @@
import gradio as gr
import mindspore
import mindnlp
from mindnlp import core
import gradio as gr
from janus.janusflow.models import MultiModalityCausalLM, VLChatProcessor
from PIL import Image
from transformers import DynamicCache
from diffusers.models import AutoencoderKL
import numpy as np

device = 'cpu'
if core.npu.is_available():
device = 'npu'
elif core.cuda.is_available():
device = 'cuda'

# Load model and processor
model_path = "deepseek-ai/JanusFlow-1.3B"
vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer

vl_gpt = MultiModalityCausalLM.from_pretrained(model_path, ms_dtype=mindspore.float16)
vl_gpt = vl_gpt.eval()
vl_gpt = MultiModalityCausalLM.from_pretrained(model_path)
vl_gpt = vl_gpt.to(core.bfloat16).to(device).eval()

# remember to use bfloat16 dtype, this vae doesn't work with fp16
vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae", ms_dtype=mindspore.float16)
vae = vae.eval()
vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae")
vae = vae.to(core.bfloat16).to(device).eval()

# Multimodal Understanding function
@core.inference_mode()
# Multimodal Understanding function
def multimodal_understanding(image, question, seed, top_p, temperature):
# Clear CUDA cache before generating
core.cuda.empty_cache()
# set seed
mindspore.manual_seed(seed)
core.manual_seed(seed)
np.random.seed(seed)
core.cuda.manual_seed(seed)
conversation = [
{
@@ -37,9 +48,9 @@ def multimodal_understanding(image, question, seed, top_p, temperature):
pil_images = [Image.fromarray(image)]
prepare_inputs = vl_chat_processor(
conversations=conversation, images=pil_images, force_batchify=True
).to(core.get_default_device(), mindspore.float16)
).to(device, dtype=core.bfloat16 if core.cuda.is_available() else core.float16)
inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
outputs = vl_gpt.language_model.generate(
@@ -60,13 +71,14 @@ def multimodal_understanding(image, question, seed, top_p, temperature):
return answer


@core.inference_mode()
def generate(
input_ids,
cfg_weight: float = 2.0,
num_inference_steps: int = 30
):
# we generate 5 images at a time, *2 for CFG
tokens = core.stack([input_ids] * 10)
tokens = core.stack([input_ids] * 10).cuda()
tokens[5:, 1:] = vl_chat_processor.pad_id
inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
print(inputs_embeds.shape)
@@ -76,10 +88,10 @@ def generate(
# generate with rectified flow ode
# step 1: encode with vision_gen_enc
z = core.randn((5, 4, 48, 48), dtype=mindspore.float16)
z = core.randn((5, 4, 48, 48), dtype=core.bfloat16).cuda()
dt = 1.0 / num_inference_steps
dt = core.zeros_like(z).to(mindspore.float16) + dt
dt = core.zeros_like(z).cuda().to(core.bfloat16) + dt
# step 2: run ode
attention_mask = core.ones((10, inputs_embeds.shape[1]+577)).to(vl_gpt.device)
@@ -103,18 +115,21 @@ def generate(
use_cache=True,
attention_mask=attention_mask,
past_key_values=None)
past_key_values = []
for kv_cache in past_key_values:
k, v = kv_cache[0], kv_cache[1]
past_key_values.append((k[:, :, :inputs_embeds.shape[1], :], v[:, :, :inputs_embeds.shape[1], :]))
past_key_values = tuple(past_key_values)
past_key_values = DynamicCache.from_legacy_cache(outputs.past_key_values)

else:
outputs = vl_gpt.language_model.model(inputs_embeds=llm_emb,
use_cache=True,
attention_mask=attention_mask,
past_key_values=past_key_values)
past_key_values = []
for kv_cache in outputs.past_key_values:
k, v = kv_cache[0], kv_cache[1]
past_key_values.append((k[:, :, :inputs_embeds.shape[1], :], v[:, :, :inputs_embeds.shape[1], :]))
past_key_values = DynamicCache.from_legacy_cache(past_key_values)

hidden_states = outputs.last_hidden_state
# transform hidden_states back to v
hidden_states = vl_gpt.vision_gen_dec_aligner(vl_gpt.vision_gen_dec_aligner_norm(hidden_states[:, -576:, :]))
hidden_states = hidden_states.reshape(z_emb.shape[0], 24, 24, 768).permute(0, 3, 1, 2)
@@ -141,13 +156,17 @@ def unpack(dec, width, height, parallel_size=5):
return visual_img


@core.inference_mode()
def generate_image(prompt,
seed=None,
guidance=5,
num_inference_steps=30):
# Clear CUDA cache and avoid tracking gradients
core.cuda.empty_cache()
# Set the seed for reproducible results
if seed is not None:
mindspore.manual_seed(seed)
core.manual_seed(seed)
core.cuda.manual_seed(seed)
np.random.seed(seed)
with core.no_grad():


examples/diffusers/janus/demo/app_januspro.py (+37 -16)

@@ -1,13 +1,18 @@
import gradio as gr
import mindnlp
import mindspore
from mindnlp import core
import gradio as gr
from transformers import AutoConfig, AutoModelForCausalLM
from janus.models import MultiModalityCausalLM, VLChatProcessor
from janus.utils.io import load_pil_images
from janus.models import VLChatProcessor
from PIL import Image

import numpy as np
# import spaces # Import spaces for ZeroGPU compatibility

device = 'cpu'
if core.npu.is_available():
device = 'npu'
elif core.cuda.is_available():
device = 'cuda'


# Load model and processor
model_path = "deepseek-ai/Janus-Pro-7B"
@@ -15,17 +20,24 @@ config = AutoConfig.from_pretrained(model_path)
language_config = config.language_config
language_config._attn_implementation = 'eager'
vl_gpt = AutoModelForCausalLM.from_pretrained(model_path,
language_config=language_config,
trust_remote_code=True, ms_dtype=mindspore.float16)
language_config=language_config,
trust_remote_code=True)
vl_gpt = vl_gpt.to(core.bfloat16).to(device)

vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer

@core.inference_mode()
# @spaces.GPU(duration=120)
# Multimodal Understanding function
def multimodal_understanding(image, question, seed, top_p, temperature):
# Clear CUDA cache before generating
core.cuda.empty_cache()
# set seed
mindspore.manual_seed(seed)
core.manual_seed(seed)
np.random.seed(seed)
core.cuda.manual_seed(seed)
conversation = [
{
@@ -39,8 +51,9 @@ def multimodal_understanding(image, question, seed, top_p, temperature):
pil_images = [Image.fromarray(image)]
prepare_inputs = vl_chat_processor(
conversations=conversation, images=pil_images, force_batchify=True
).to(core.get_default_device(), mindspore.float16)

).to(device, dtype=core.bfloat16 if core.cuda.is_available() else core.float16)
inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
outputs = vl_gpt.language_model.generate(
@@ -68,14 +81,16 @@ def generate(input_ids,
cfg_weight: float = 5,
image_token_num_per_image: int = 576,
patch_size: int = 16):
# Clear CUDA cache before generating
core.cuda.empty_cache()
tokens = core.zeros((parallel_size * 2, len(input_ids)), dtype=mindspore.int32)
tokens = core.zeros((parallel_size * 2, len(input_ids)), dtype=core.int).to(device)
for i in range(parallel_size * 2):
tokens[i, :] = input_ids
if i % 2 != 0:
tokens[i, 1:-1] = vl_chat_processor.pad_id
inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
generated_tokens = core.zeros((parallel_size, image_token_num_per_image), dtype=mindspore.int32)
generated_tokens = core.zeros((parallel_size, image_token_num_per_image), dtype=core.int).to(device)

pkv = None
for i in range(image_token_num_per_image):
@@ -99,13 +114,13 @@ def generate(input_ids,


patches = vl_gpt.gen_vision_model.decode_code(generated_tokens.to(dtype=mindspore.int32),
patches = vl_gpt.gen_vision_model.decode_code(generated_tokens.to(dtype=core.int),
shape=[parallel_size, 8, width // patch_size, height // patch_size])

return generated_tokens.to(dtype=mindspore.int32), patches
return generated_tokens.to(dtype=core.int), patches

def unpack(dec, width, height, parallel_size=5):
dec = dec.to(mindspore.float32).cpu().numpy().transpose(0, 2, 3, 1)
dec = dec.to(core.float32).cpu().numpy().transpose(0, 2, 3, 1)
dec = np.clip((dec + 1) / 2 * 255, 0, 255)

visual_img = np.zeros((parallel_size, width, height, 3), dtype=np.uint8)
@@ -114,13 +129,19 @@ def unpack(dec, width, height, parallel_size=5):
return visual_img



@core.inference_mode()
# @spaces.GPU(duration=120) # Specify a duration to avoid timeout
def generate_image(prompt,
seed=None,
guidance=5,
t2i_temperature=1.0):
# Clear CUDA cache and avoid tracking gradients
core.cuda.empty_cache()
# Set the seed for reproducible results
if seed is not None:
mindspore.manual_seed(seed)
core.manual_seed(seed)
core.cuda.manual_seed(seed)
np.random.seed(seed)
width = 384
height = 384


examples/diffusers/janus/generation_inference.py (+1 -1)

@@ -36,7 +36,7 @@ language_config._attn_implementation = 'eager'
vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
model_path, language_config=language_config, trust_remote_code=True, ms_dtype=torch.float16
)
vl_gpt = vl_gpt.eval()
vl_gpt = vl_gpt.eval().cuda()

conversation = [
{


examples/diffusers/janus/inference.py (+1 -1)

@@ -35,7 +35,7 @@ language_config._attn_implementation = 'eager'
vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
model_path, language_config=language_config, trust_remote_code=True, torch_dtype=torch.float16
)
vl_gpt = vl_gpt.eval()
vl_gpt = vl_gpt.eval().npu()

conversation = [
{


mindnlp/core/_prims/ascend.py (+6 -0)

@@ -421,3 +421,9 @@ def dropout2d(input, p):
return dropout_2d_op(input)

__all__.append('dropout2d')

def bernoulli_ext(input, generator):
seed, offset = generator._step(12) # pylint: disable=protected-access
return gen_ops_prim.bernoulli_ext_op(input, seed, offset)

__all__.append('bernoulli_ext')

mindnlp/core/_prims/numpy.py (+192 -4)

@@ -103,6 +103,8 @@ def dyn_shape(self):
__all__.append('dyn_shape')

def cast(input, dtype):
if input.dtype == dtype:
return input
out = input.asnumpy().astype(core.dtype2np[dtype])
return core.Tensor.from_numpy(out)

@@ -150,6 +152,13 @@ def bitwise_and_scalar(input, other):

__all__.append('bitwise_and_scalar')


def bitwise_or_tensor(input, other):
out = np.bitwise_or(input.numpy(), other.numpy())
return core.Tensor.from_numpy(out)

__all__.append('bitwise_or_tensor')

def right_shift(input, other):
out = np.right_shift(input.numpy(), other)
return core.Tensor.from_numpy(out)
@@ -240,9 +249,17 @@ def normal(shape):

__all__.append('normal')

def pad_v3(input, pad, mode, value):
out = np.pad(input.asnumpy(), pad, mode, constant_values=value)
return core.Tensor.from_numpy(out)
def pad_v3(input_x, padding, mode='constant', value=None):
pad_op = ops.PadV3(mode=mode, paddings_contiguous=True).set_device('CPU')
if input_x.dtype == core.bool:
input_x = input_x.to(core.int32)
value = int(value)
out = pad_op(input_x, padding, value)
return cast(out, core.bool)

if isinstance(value, (float, int)):
value = core.tensor(value, dtype=input_x.dtype)
return pad_op(input_x, padding, value)

__all__.append('pad_v3')

@@ -332,6 +349,8 @@ def tile(input, dims):
__all__.append('tile')

def squeeze(input, dim):
if isinstance(dim, int) and input.shape[dim] != 1:
return input
out = np.squeeze(input.numpy(), dim)
return core.Tensor.from_numpy(out)

@@ -378,6 +397,13 @@ def inplace_fill_scalar(input, value):

__all__.append('inplace_fill_scalar')

def inplace_fill_tensor(input, value):
out = np.full_like(input.numpy(), value)
numpy_to_tensor_overwrite(out, input)
return input

__all__.append('inplace_fill_tensor')

def inplace_normal(input, mean, std, generator_):
out = np.random.normal(mean, std, input.shape).astype(core.dtype2np[input.dtype])
numpy_to_tensor_overwrite(out, input)
@@ -590,6 +616,8 @@ __all__.append('randn')

def erfinv(input):
out = scipy.special.erfinv(input)
if not isinstance(out, np.ndarray):
out = np.array(out)
return core.Tensor.from_numpy(out)

__all__.append('erfinv')
@@ -661,6 +689,8 @@ __all__.append('argmax_ext')

def log(input):
out = np.log(input.numpy())
if not isinstance(out, np.ndarray):
out = np.array(out)
return core.Tensor.from_numpy(out)

__all__.append('log')
@@ -842,6 +872,8 @@ def exp(input):
out = np.exp(input.numpy())
if input.dtype == np.int64:
out = out.astype(np.float32)
if not isinstance(out, np.ndarray):
out = np.array(out)
return core.Tensor.from_numpy(out)

__all__.append('exp')
@@ -896,6 +928,8 @@ __all__.append('maximum')

def prod_ext(input, dim, keepdim, dtype):
out = np.prod(input.numpy(), axis=dim, keepdims=keepdim)
if not isinstance(out, np.ndarray):
out = np.array(out)
return core.Tensor.from_numpy(out)

__all__.append('prod_ext')
@@ -973,4 +1007,158 @@ def one_hot_ext(tensor, num_classes=-1):
out = np.eye(num_classes)[tensor.numpy()]
return core.Tensor.from_numpy(out)

__all__.append('one_hot_ext')
__all__.append('one_hot_ext')

def log1p(input):
out = np.log1p(input.numpy())
return core.Tensor.from_numpy(out)

__all__.append('log1p')

def gather(input, indices, _dimension):
out = np.take(input.numpy(), indices.numpy(), _dimension)
return core.Tensor.from_numpy(out)

__all__.append('gather')


def layer_norm_ext(input, normalized_shape, weight=None, bias=None, eps=1e-5):
# Determine the axes over which mean and variance are computed.
# Leading dimensions before those covered by normalized_shape are kept (i.e. the batch dims);
# mean and variance are taken over the last len(normalized_shape) dimensions.
input = input.numpy()
if weight is not None:
weight = weight.numpy()
if bias is not None:
bias = bias.numpy()

start_axis = input.ndim - len(normalized_shape)
axes = tuple(range(start_axis, input.ndim))
# Compute mean and variance, keeping dims for broadcasting
mean = np.mean(input, axis=axes, keepdims=True)
var = np.var(input, axis=axes, keepdims=True)
# Normalize: (x - mean) / sqrt(var + eps)
normalized = (input - mean) / np.sqrt(var + eps)
# Apply the learnable scale and shift parameters (gamma and beta)
if weight is not None:
normalized = normalized * weight
if bias is not None:
normalized = normalized + bias
return (core.Tensor.from_numpy(normalized),)

__all__.append('layer_norm_ext')

def erf(input):
out = scipy.special.erf(input.numpy())
return core.Tensor.from_numpy(out)

__all__.append('erf')

def mse_loss_ext(input, target, reduction='mean'):
if input.shape != target.shape:
raise ValueError(f"Input and target must have the same shape. Got input: {input.shape}, target: {target.shape}")

squared_errors = np.square(input - target)

if reduction == 'mean':
loss = np.mean(squared_errors)
elif reduction == 'sum':
loss = np.sum(squared_errors)
elif reduction == 'none':
loss = squared_errors
else:
raise ValueError("Reduction must be 'mean', 'sum', or 'none'.")

if not isinstance(loss, np.ndarray):
loss = np.array(loss)
return core.Tensor.from_numpy(loss)

__all__.append('mse_loss_ext')

def square(input):
out = np.square(input.numpy())
return core.Tensor.from_numpy(out)

__all__.append('square')

def lgamma(input):
out = scipy.special.gammaln(input.numpy())
return core.Tensor.from_numpy(out)

__all__.append('lgamma')

def gamma(shape, alpha, beta):
out = np.random.gamma(alpha, 1/beta, shape)
return core.Tensor.from_numpy(out)

__all__.append('gamma')

def gather_d(input, dim, index):
indices = []
for axis in range(input.ndim):
if axis == dim:
indices.append(index)
else:
shape = [1] * index.ndim
shape[axis] = input.shape[axis]
indices.append(np.arange(input.shape[axis]).reshape(shape))
out = input[tuple(indices)]
if not isinstance(out, np.ndarray):
out = np.array(out)
return core.Tensor.from_numpy(out)

__all__.append('gather_d')


def log_softmax(x, axis=-1):
x = x.numpy()
x_max = np.max(x, axis=axis, keepdims=True)
x_shifted = x - x_max
exp_x = np.exp(x_shifted)
sum_exp_x = np.sum(exp_x, axis=axis, keepdims=True)
log_sum_exp_x = np.log(sum_exp_x)
out = x_shifted - log_sum_exp_x
return core.Tensor.from_numpy(out)

__all__.append('log_softmax')

def nllloss(input, target, weight=None, reduction='mean', ignore_index=-100):
op = ops.NLLLoss(reduction, ignore_index).set_device('CPU')
return op(input, target, weight)

__all__.append('nllloss')

def linalg_qr(A, mode):
# out = np.linalg.qr(A.numpy(), mode)
# return [core.Tensor.from_numpy(o) for o in out]
if mode not in ('reduced', 'complete'):
raise TypeError(f"For qr, the arg mode must be 'reduced' or 'complete', but got {mode}.")
qr_ = _get_cache_prim(ops.Qr)(mode == 'complete').set_device('CPU')
return qr_(A)

__all__.append('linalg_qr')

def diag_ext(input, diagonal):
out = np.diag(input.numpy(), diagonal)
return core.Tensor.from_numpy(out)

__all__.append('diag_ext')

def sign(input):
out = np.sign(input.numpy())
return core.Tensor.from_numpy(out)

__all__.append('sign')

def log2(input):
out = np.log2(input.numpy())
return core.Tensor.from_numpy(out)

__all__.append('log2')

mindnlp/core/_tensor.py (+26 -9)

@@ -175,6 +175,12 @@ class TensorPlaceHolder:
self.requires_grad = requires_grad
return self
def __array_wrap__(self, array):
if array.dtype == bool:
# Workaround, torch has no built-in bool tensor
array = array.astype("uint8")
return ops.from_numpy(array)
def __reduce_ex__(self, protocol):
if isinstance(self, StubTensor):
data = Tensor_(self.stub_sync())
@@ -342,6 +348,8 @@ class TensorPlaceHolder:
return ops.sub(other, self)
def __eq__(self, other):
if other is None:
return False
return ops.eq(self, other)
def __gt__(self, other):
@@ -912,7 +920,8 @@ class TensorPlaceHolder:
if isinstance(self, StubTensor) and isinstance(new_value, StubTensor):
self.stub = new_value.stub
else:
if self.device.type == 'cpu' and new_value.device.type == 'cpu':
if self.device.type == 'cpu' and new_value.device.type == 'cpu' \
and self.shape == new_value.shape and self.dtype == new_value.dtype:
src_ct = ctypes.c_void_p(new_value.data_ptr())
dst_ct = ctypes.c_void_p(self.data_ptr())
ctypes.memmove(dst_ct, src_ct, self.nbytes)
@@ -1594,7 +1603,8 @@ class TensorPlaceHolder:
# Tensor.max
def max(self, dim=None, keepdim=False):
def max(self, dim=None, keepdim=False, **kwargs):
dim = kwargs.pop('axis', dim)
return ops.max(self, dim, keepdim)
# Tensor.maximum
@@ -1884,7 +1894,8 @@ class TensorPlaceHolder:
return ops.repeat_interleave(self, repeats, dim, output_size=output_size)
# Tensor.reshape
def reshape(self, *shape):
def reshape(self, *shape, **kwargs):
shape = kwargs.pop('shape', shape)
return ops.reshape(self, *shape)
# Tensor.reshape_as
@@ -1949,13 +1960,13 @@ class TensorPlaceHolder:
# Tensor.scatter_reduce_
def scatter_reduce_(self, dim, index, src):
def scatter_reduce_(self, dim, index, src, reduce, *, include_self=True):
return self.copy_(ops.scatter_reduce(self, dim, index, src, reduce))
# Tensor.scatter_reduce
def scatter_reduce(self, dim, index, src):
return ops.scatter_reduce(self, dim, index, src)
def scatter_reduce(self, dim, index, src, reduce, *, include_self=True):
return ops.scatter_reduce(self, dim, index, src, reduce)
# Tensor.select
@@ -2096,8 +2107,11 @@ class TensorPlaceHolder:
return self.copy_(ops.square(self))
# Tensor.squeeze
def squeeze(self, *args, **kwargs):
return ops.squeeze(self, *args, **kwargs)
def squeeze(self, *dim, **kwargs):
dim = kwargs.pop('dim', dim)
if isinstance(dim, tuple) and len(dim) == 1:
dim = dim[0]
return ops.squeeze(self, dim)
# Tensor.squeeze_
def squeeze_(self, dim=None):
@@ -2150,6 +2164,8 @@ class TensorPlaceHolder:
def sum(self, dim=None, keepdim=False, dtype=None, **kwargs):
dim = kwargs.pop('axis', dim)
keepdim = kwargs.pop('keepdims', keepdim)
if isinstance(dim, list):
dim = tuple(dim)
return ops.sum(self, dim, keepdim, dtype=dtype)
# Tensor.sum_to_size
@@ -2424,7 +2440,8 @@ class TensorPlaceHolder:
# Tensor.var
def var(self, dim=None, *, correction=1, keepdim=False):
def var(self, dim=None, *, correction=1, keepdim=False, **kwargs):
correction = int(kwargs.pop('unbiased', correction))
return ops.var(self, dim, correction=correction, keepdim=keepdim)
# Tensor.vdot


mindnlp/core/distributions/__init__.py (+1 -0)

@@ -1,6 +1,7 @@
"""distributions"""
from .bernoulli import Bernoulli
from .categorical import Categorical
from .chi2 import *
from .distribution import Distribution
from .independent import Independent
from .negative_binomial import NegativeBinomial


mindnlp/core/distributions/gamma.py (+44 -33)

@@ -1,16 +1,21 @@
"""gamma"""
# mypy: allow-untyped-defs
from numbers import Number
from typing import Optional, Union
from .. import ops
from . import constraints
from .exp_family import ExponentialFamily
from .utils import broadcast_all
from mindnlp import core
from mindnlp.core import Tensor
from mindnlp.core.distributions import constraints
from mindnlp.core.distributions.exp_family import ExponentialFamily
from mindnlp.core.distributions.utils import broadcast_all
from mindnlp.core.types import _Number, _size
__all__ = ["Gamma"]
def _standard_gamma(concentration):
return core._standard_gamma(concentration)
class Gamma(ExponentialFamily):
r"""
Creates a Gamma distribution parameterized by shape :attr:`concentration` and :attr:`rate`.
@@ -25,9 +30,10 @@ class Gamma(ExponentialFamily):
Args:
concentration (float or Tensor): shape parameter of the distribution
(often referred to as alpha)
rate (float or Tensor): rate = 1 / scale of the distribution
(often referred to as beta)
rate (float or Tensor): rate parameter of the distribution
(often referred to as beta), rate = 1 / scale
"""
arg_constraints = {
"concentration": constraints.positive,
"rate": constraints.positive,
@@ -37,77 +43,82 @@ class Gamma(ExponentialFamily):
_mean_carrier_measure = 0
@property
def mean(self):
def mean(self) -> Tensor:
return self.concentration / self.rate
@property
def mode(self):
def mode(self) -> Tensor:
return ((self.concentration - 1) / self.rate).clamp(min=0)
@property
def variance(self):
def variance(self) -> Tensor:
return self.concentration / self.rate.pow(2)
def __init__(self, concentration, rate, validate_args=None):
def __init__(
self,
concentration: Union[Tensor, float],
rate: Union[Tensor, float],
validate_args: Optional[bool] = None,
) -> None:
self.concentration, self.rate = broadcast_all(concentration, rate)
if isinstance(concentration, Number) and isinstance(rate, Number):
batch_shape = ()
if isinstance(concentration, _Number) and isinstance(rate, _Number):
batch_shape = core.Size()
else:
batch_shape = self.concentration.shape
batch_shape = self.concentration.size()
super().__init__(batch_shape, validate_args=validate_args)
def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(Gamma, _instance)
new.concentration = self.concentration.broadcast_to(batch_shape)
new.rate = self.rate.broadcast_to(batch_shape)
batch_shape = core.Size(batch_shape)
new.concentration = self.concentration.expand(batch_shape)
new.rate = self.rate.expand(batch_shape)
super(Gamma, new).__init__(batch_shape, validate_args=False)
new._validate_args = self._validate_args
return new
def rsample(self, sample_shape=()):
def rsample(self, sample_shape: _size = core.Size()) -> Tensor:
shape = self._extended_shape(sample_shape)
if shape == (): # pylint: disable=use-implicit-booleaness-not-comparison
sample_shape = (1,)
else:
sample_shape = shape
value = ops.gamma(sample_shape, self.concentration, self.rate)
value = core.gamma(sample_shape, self.concentration, self.rate)
if shape == (): # pylint: disable=use-implicit-booleaness-not-comparison
value = ops.squeeze(value)
value = core.squeeze(value)
value = value.clamp(
min=float(ops.finfo(value.dtype).tiny)
value.detach().clamp_(
min=core.finfo(value.dtype).tiny
) # do not record in autograd graph
return value
def log_prob(self, value):
value = ops.as_tensor(value, dtype=self.rate.dtype)
value = core.as_tensor(value, dtype=self.rate.dtype, device=self.rate.device)
if self._validate_args:
self._validate_sample(value)
return (
ops.xlogy(self.concentration, self.rate)
+ ops.xlogy(self.concentration - 1, value)
core.xlogy(self.concentration, self.rate)
+ core.xlogy(self.concentration - 1, value)
- self.rate * value
- ops.lgamma(self.concentration)
- core.lgamma(self.concentration)
)
def entropy(self):
return (
self.concentration
- ops.log(self.rate)
+ ops.lgamma(self.concentration)
+ (1.0 - self.concentration) * ops.digamma(self.concentration)
- core.log(self.rate)
+ core.lgamma(self.concentration)
+ (1.0 - self.concentration) * core.digamma(self.concentration)
)
@property
def _natural_params(self):
def _natural_params(self) -> tuple[Tensor, Tensor]:
return (self.concentration - 1, -self.rate)
def _log_normalizer(self, x, y):
return ops.lgamma(x + 1) + (x + 1) * ops.log(-y.reciprocal())
return core.lgamma(x + 1) + (x + 1) * core.log(-y.reciprocal())
def cdf(self, value):
if self._validate_args:
self._validate_sample(value)
return ops.igamma(self.concentration, self.rate * value)
return core.special.gammainc(self.concentration, self.rate * value)

mindnlp/core/distributions/studentT.py (+32 -24)

@@ -1,13 +1,13 @@
"""studentT"""
# mypy: allow-untyped-defs
import math
from math import inf, nan
from typing import Optional, Union
from .. import ops
from . import constraints
from .chi2 import Chi2
from .distribution import Distribution
from .utils import _standard_normal, broadcast_all
from mindnlp import core
from mindnlp.core import inf, nan, Tensor
from mindnlp.core.distributions import Chi2, constraints
from mindnlp.core.distributions.distribution import Distribution
from mindnlp.core.distributions.utils import _standard_normal, broadcast_all
from mindnlp.core.types import _size
__all__ = ["StudentT"]
@@ -30,6 +30,7 @@ class StudentT(Distribution):
loc (float or Tensor): mean of the distribution
scale (float or Tensor): scale of the distribution
"""
arg_constraints = {
"df": constraints.positive,
"loc": constraints.real,
@@ -39,18 +40,18 @@ class StudentT(Distribution):
has_rsample = True
@property
def mean(self):
m = self.loc.copy()
def mean(self) -> Tensor:
m = self.loc.clone(memory_format=core.contiguous_format)
m[self.df <= 1] = nan
return m
@property
def mode(self):
def mode(self) -> Tensor:
return self.loc
@property
def variance(self):
m = self.df.copy()
def variance(self) -> Tensor:
m = self.df.clone(memory_format=core.contiguous_format)
m[self.df > 2] = (
self.scale[self.df > 2].pow(2)
* self.df[self.df > 2]
@@ -60,14 +61,21 @@ class StudentT(Distribution):
m[self.df <= 1] = nan
return m
def __init__(self, df, loc=0.0, scale=1.0, validate_args=None):
def __init__(
self,
df: Union[Tensor, float],
loc: Union[Tensor, float] = 0.0,
scale: Union[Tensor, float] = 1.0,
validate_args: Optional[bool] = None,
) -> None:
self.df, self.loc, self.scale = broadcast_all(df, loc, scale)
self._chi2 = Chi2(self.df)
batch_shape = self.df.shape
batch_shape = self.df.size()
super().__init__(batch_shape, validate_args=validate_args)
def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(StudentT, _instance)
batch_shape = core.Size(batch_shape)
new.df = self.df.expand(batch_shape)
new.loc = self.loc.expand(batch_shape)
new.scale = self.scale.expand(batch_shape)
@@ -76,7 +84,7 @@ class StudentT(Distribution):
new._validate_args = self._validate_args
return new
def rsample(self, sample_shape=()):
def rsample(self, sample_shape: _size = core.Size()) -> Tensor:
# NOTE: This does not agree with scipy implementation as much as other distributions.
# (see https://github.com/fritzo/notebooks/blob/master/debug-student-t.ipynb). Using DoubleTensor
# parameters seems to help.
@@ -85,9 +93,9 @@ class StudentT(Distribution):
# Z ~ Chi2(df)
# Y = X / sqrt(Z / df) ~ StudentT(df)
shape = self._extended_shape(sample_shape)
X = _standard_normal(shape, dtype=self.df.dtype)
X = _standard_normal(shape, dtype=self.df.dtype, device=self.df.device)
Z = self._chi2.rsample(sample_shape)
Y = X * ops.rsqrt(Z / self.df)
Y = X * core.rsqrt(Z / self.df)
return self.loc + self.scale * Y
def log_prob(self, value):
@@ -98,22 +106,22 @@ class StudentT(Distribution):
self.scale.log()
+ 0.5 * self.df.log()
+ 0.5 * math.log(math.pi)
+ ops.lgamma(0.5 * self.df)
- ops.lgamma(0.5 * (self.df + 1.0))
+ core.lgamma(0.5 * self.df)
- core.lgamma(0.5 * (self.df + 1.0))
)
return -0.5 * (self.df + 1.0) * ops.log1p(y**2.0 / self.df) - Z
return -0.5 * (self.df + 1.0) * core.log1p(y**2.0 / self.df) - Z
def entropy(self):
lbeta = (
ops.lgamma(0.5 * self.df)
core.lgamma(0.5 * self.df)
+ math.lgamma(0.5)
- ops.lgamma(0.5 * (self.df + 1))
- core.lgamma(0.5 * (self.df + 1))
)
return (
self.scale.log()
+ 0.5
* (self.df + 1)
* (ops.digamma(0.5 * (self.df + 1)) - ops.digamma(0.5 * self.df))
* (core.digamma(0.5 * (self.df + 1)) - core.digamma(0.5 * self.df))
+ 0.5 * self.df.log()
+ lbeta
)
)

mindnlp/core/distributions/utils.py (+6 -6)

@@ -63,12 +63,12 @@ def _standard_normal(
dtype: Optional[_dtype],
device: Optional[Device],
) -> Tensor:
if core._C._get_tracing_state():
# [JIT WORKAROUND] lack of support for .normal_()
return core.normal(
core.zeros(shape, dtype=dtype, device=device),
core.ones(shape, dtype=dtype, device=device),
)
# if core._C._get_tracing_state():
# # [JIT WORKAROUND] lack of support for .normal_()
# return core.normal(
# core.zeros(shape, dtype=dtype, device=device),
# core.ones(shape, dtype=dtype, device=device),
# )
return core.empty(shape, dtype=dtype, device=device).normal_()


mindnlp/core/linalg/__init__.py (+4 -1)

@@ -1,7 +1,7 @@
from collections import namedtuple
import numpy as np

from mindspore import ops, mint
from mindspore import ops
from mindspore.ops._primitive_cache import _get_cache_prim

from mindnlp import core
@@ -32,3 +32,6 @@ def vector_norm(x, ord=2, dim=None, keepdim=False, *, dtype=None, out=None):

def solve(A, B, *, left=True, out=None):
return core.tensor(np.linalg.solve(A.numpy(), B.numpy()))

def qr(A, mode='reduced'):
return execute('linalg_qr', A, mode)

mindnlp/core/nn/functional.py (+3 -1)

@@ -56,7 +56,7 @@ def softplus(input, beta=1, threshold=20):
return execute('softplus_ext', input, beta, threshold)
def logsigmoid(input):
return execute('logsigmoid', input)
return execute('logsigmoid', input)[0]
def leaky_relu(input, alpha=0.2):
return execute('leaky_relu_ext', input, alpha)
@@ -291,6 +291,8 @@ def pad(input, pad, mode='constant', value=None):
if input.device.type in ['cpu', 'meta'] or ON_A1:
new_pad = ()
for idx, pad_v in enumerate(pad):
if not isinstance(pad_v, int):
pad_v = pad_v.item()
if pad_v < 0:
dim = input.ndim - 1 - idx // 2
input = input.narrow(dim, 0, input.shape[dim] + pad_v)


mindnlp/core/npu/__init__.py (+3 -0)

@@ -27,6 +27,9 @@ class DefaultGenerators:
def __getitem__(self, idx):
return core.default_generator
def __len__(self):
return 1
default_generators = DefaultGenerators()
def set_compile_mode(*args, **kwargs):


mindnlp/core/ops/array.py (+12 -6)

@@ -48,7 +48,8 @@ def concat(tensors, dim=0, **kwargs):
# concatenate
def concatenate(tensors, dim=0):
def concatenate(tensors, dim=0, **kwargs):
dim = kwargs.pop('axis', dim)
return cat(tensors, dim)
@@ -227,7 +228,7 @@ def reshape(input, *shape):
shape = shape[0]
new_shape = ()
for s in shape:
if not isinstance(s, numbers.Number):
if not isinstance(s, numbers.Number) or isinstance(s, np.int64):
s = s.item()
new_shape += (s,)
return execute("reshape", input, new_shape)
@@ -270,6 +271,11 @@ def scatter_add(input, dim, index, src):
# scatter_reduce
def scatter_reduce(input, dim, index, src, reduce, *, include_self=True):
if reduce == 'sum':
return scatter_add(input, dim, index, src)
else:
raise ValueError(f'do not support reduce: {reduce}')
# split
@@ -290,11 +296,11 @@ def split_with_sizes(input, split_sizes, dim=0):
return execute("split_with_size", input, split_sizes, dim)
# squeeze
def squeeze(input, *dim, **kwargs):
dim = kwargs.get('dim', dim)
def squeeze(input, dim=None):
if dim is None:
dim = ()
return execute("squeeze", input, dim)
# stack
@@ -1004,7 +1010,7 @@ __all__ = [
# select_scatter
# slice_scatter
"scatter_add",
# scatter_reduce
"scatter_reduce",
"split",
"squeeze",
"stack",


mindnlp/core/ops/comparison.py (+18 -7)

@@ -51,12 +51,20 @@ def isfinite(input):
return execute('isfinite', input)
# isin
def isin(elements, test_elements):
def in1d(ar1, ar2, invert=False):
ar1 = core.unsqueeze(ar1.ravel(), -1)
ar2 = ar2.ravel()
included = core.eq(ar1, ar2)
# ops.reduce_sum only supports float
res = core.sum(included.to(core.float32), -1).to(core.bool_)
if invert:
res = core.logical_not(res)
return res
def isin(elements, test_elements, invert=False):
if elements.device.type != 'cpu':
test_elements = core.tensor(test_elements)
if test_elements.ndim == 0:
test_elements = test_elements.unsqueeze(0)
return elements.tile(test_elements.shape[0], 1).eq(test_elements.unsqueeze(1)).sum(dim=0).bool().squeeze()
res = in1d(elements, test_elements, invert=invert)
return core.reshape(res, elements.shape)
return execute('isin', elements, test_elements)
@@ -94,6 +102,8 @@ def less(input, other):
# maximum
def maximum(input, other):
if isinstance(other, core.Tensor) and other.device != input.device:
other = other.to(input.device)
return execute('maximum', input, other)
# minimum
@@ -128,8 +138,9 @@ def topk(input, k, dim=-1, largest=True, sorted=True):
if not largest:
res = execute('topk', input, k, sorted)
values, indices = -res[0], res[1]
return values, indices
return execute('topk', input, k, sorted)
return topk_out(values=values, indices=indices)
out = execute('topk', input, k, sorted)
return topk_out(values=out[0], indices=out[1])
input = input.swapaxes(dim, input.ndim - 1)
output = execute('topk', input, k, sorted)
values = output[0].swapaxes(dim, input.ndim - 1)


mindnlp/core/ops/random.py (+6 -4)

@@ -9,11 +9,10 @@ generator_step_ = 12
# bernoulli
def bernoulli(input, *, generator=None, out=None):
def bernoulli(input, *, generator=None, out=None, **kwargs):
if generator is None:
generator = default_generator
seed, offset = generator._step(generator_step_) # pylint: disable=protected-access
output = execute("bernoulli_ext", input, seed, offset)
output = execute("bernoulli_ext", input, generator)
if out is None:
return output
out.data = output
@@ -239,7 +238,7 @@ def randint_like(
high,
seed,
offset,
dtype_to_type_id("RandIntLike", "dtype", dtype),
dtype,
device=device,
requires_grad=requires_grad,
)
@@ -349,6 +348,8 @@ def randperm(
out.data = output
return out
def gamma(shape, alpha, beta):
return execute('gamma', shape, alpha, beta)
__all__ = [
"bernoulli",
@@ -361,4 +362,5 @@ __all__ = [
"randn_like",
"randperm",
"randint_like",
"gamma"
]

mindnlp/transformers/__init__.py (+1 -4100)
File diff suppressed because it is too large


mindnlp/transformers/masking_utils.py (+1 -1)

@@ -410,7 +410,7 @@ def sdpa_mask_older_torch(
# However, in more recent version of Pytorch, a trick was introduced to handle it - which is the reason we have
# `sdpa_mask_recent_torch`, as it allows more general `mask_function`
# causal_mask = mask_function(None, None, cache_position, kv_arange)
causal_mask = mask_function(None, None, cache_position.reshape(cache_position.shape[0], 1), kv_arange.reshape(1, kv_arange.shape[0]))
causal_mask = mask_function(slice(None), None, cache_position.reshape(cache_position.shape[0], 1), kv_arange.reshape(1, kv_arange.shape[0]))
# causal_mask = _vmap_for_bhqkv(mask_function, bh_indices=False)(None, None, cache_position, kv_arange)
if causal_mask.ndim == 2:
causal_mask = causal_mask[None, None, :, :].expand(batch_size, -1, -1, -1)


mindnlp/transformers/models/gemma3/__init__.py (+0 -36)

@@ -1,36 +0,0 @@
from typing import Optional, Callable
import mindspore
from mindspore import ops, nn, mint
from mindnlp import core

def token_type_ids_mask_function(
token_type_ids: Optional[core.Tensor],
image_group_ids: Optional[core.Tensor],
tokens_per_image: int,
) -> Optional[Callable]:
"""
This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
not start and end indices.
"""
# Do not return an additional mask in this case
if token_type_ids is None:
return None

def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
# If it's 1 for both query and key/value, we are in an image block
# NOTE: static cache shape goes beyond input seq length, while token_type_ids.shape[1] == input seq length
# Since vmap doesn't support `if statement` we workaround it with `torch.where`
safe_idx = core.where(kv_idx < token_type_ids.shape[1], kv_idx, 0)
token_type_ids_at_kv_idx = token_type_ids[:, safe_idx]
token_type_ids_at_kv_idx = core.where(kv_idx < token_type_ids.shape[1], token_type_ids_at_kv_idx, 0)

image_group_ids_at_kv_idx = image_group_ids[:, safe_idx]
image_group_ids_at_kv_idx = core.where(kv_idx < image_group_ids.shape[1], image_group_ids_at_kv_idx, -1)

is_image_block = (token_type_ids[:, q_idx] == 1) & (token_type_ids_at_kv_idx == 1)
same_image_block = image_group_ids[:, q_idx] == image_group_ids_at_kv_idx

# This is bidirectional attention whenever we are dealing with image tokens
return is_image_block & same_image_block

return inner_mask

setup.py (+1 -1)

@@ -64,7 +64,7 @@ class CustomInstall(install):
_create_namespace_links() # create namespace links after installation


version = '0.5.0'
version = '0.5.0rc1'
cur_dir = os.path.dirname(os.path.realpath(__file__))
pkg_dir = os.path.join(cur_dir, 'build')


