# Source code for mxnet.numpy.stride_tricks

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""Utility functions for broadcasting arrays."""

from ..ndarray.ndarray import _get_broadcast_shape
from ..ndarray import numpy as _mx_nd_np
from .multiarray import asarray


# Names exported by ``from <this module> import *``. Only ``broadcast_arrays``
# is listed; the underscore-prefixed helper below is intentionally left out.
__all__ = ['broadcast_arrays']


def _broadcast_shape(*args):
    shape = ()
    for arr in args:
        shape = _get_broadcast_shape(shape, arr.shape)
    return shape


def broadcast_arrays(*args):
    """
    Broadcast any number of arrays against each other.

    Parameters
    ----------
    `*args` : a list of ndarrays
        The arrays to broadcast. Each argument is first passed through
        ``asarray``.

    Returns
    -------
    broadcasted : list of arrays
        One array per input, all sharing the common broadcast shape.
        If every input already has that shape, the (converted) inputs
        themselves are returned in a new list. Otherwise each entry is the
        result of ``broadcast_to`` on the corresponding input, i.e. a copy
        of the original array expanded to the common shape.

    Examples
    --------
    >>> x = np.array([[1,2,3]])
    >>> y = np.array([[4],[5]])
    >>> np.broadcast_arrays(x, y)
    [array([[1., 2., 3.],
           [1., 2., 3.]]), array([[4., 4., 4.],
           [5., 5., 5.]])]
    """
    arrays = [asarray(arr) for arr in args]
    shape = _broadcast_shape(*arrays)

    if all(arr.shape == shape for arr in arrays):
        # Common case where nothing needs to be broadcasted. ``arrays`` is
        # already a fresh list, so it can be handed back without copying.
        return arrays

    return [_mx_nd_np.broadcast_to(arr, shape) for arr in arrays]