from csdl_alpha.src.graph.operation import Operation, set_properties
from csdl_alpha.src.operations.operation_subclasses import ComposedOperation
from csdl_alpha.src.graph.variable import Variable
from csdl_alpha.utils.inputs import variablize, validate_and_variablize
import csdl_alpha.utils.testing_utils as csdl_tests
import csdl_alpha as csdl
from csdl_alpha.src.operations.derivatives.utils import get_uncontract_action
import numpy as np
class Average(Operation):
    '''
    Operation that averages the entries of one input tensor.

    With ``axes=None`` every entry of the input is averaged; otherwise the
    average is taken along the given axes only.
    '''

    def __init__(self, x, axes=None, out_shape=None):
        '''
        Parameters
        ----------
        x : input passed on to the ``Operation`` constructor.
        axes : axes to average along, or None to average every entry.
        out_shape : shape registered for the single dense output.
        '''
        super().__init__(x)
        self.name = 'average'
        self.axes = axes
        # One output only, so its shape is wrapped in a 1-tuple.
        self.set_dense_outputs((out_shape,))

    def compute_inline(self, x):
        '''Evaluate the average with NumPy.'''
        # axis=None is np.average's default, so a single call covers both the
        # "all entries" case and the "selected axes" case.
        return np.average(x, axis=self.axes)

    def compute_jax(self, x):
        '''Evaluate the average with JAX (imported lazily).'''
        import jax.numpy as jnp
        # Same reasoning as compute_inline: axis=None is the default.
        return jnp.average(x, axis=self.axes)

    def evaluate_vjp(self, cotangents, x, y):
        '''
        Accumulate the vector-Jacobian product for ``x``.

        The cotangent of ``y`` is divided by the number of averaged entries
        and expanded back to the shape of ``x``.
        '''
        if not cotangents.check(x):
            return

        import csdl_alpha as csdl

        if self.axes is None:
            # Every entry of x contributes with weight 1/x.size.
            scaled = cotangents[y] / x.size
            cotangents.accumulate(x, csdl.expand(scaled, out_shape=x.shape))
            return

        action = get_uncontract_action(x.shape, self.axes)
        # Number of entries folded into each output entry.
        n_averaged = 1
        for ax in self.axes:
            n_averaged *= x.shape[ax]
        scaled = cotangents[y] / n_averaged
        cotangents.accumulate(x, csdl.expand(scaled, action = action, out_shape=x.shape))
class ElementwiseAverage(ComposedOperation):
    '''
    Composed operation computing the elementwise average of all the
    Variables given as arguments.
    '''

    def __init__(self, *operands):
        '''Forward every operand to ``ComposedOperation`` and name the op.'''
        super().__init__(*operands)
        self.name = 'elementwise_average'

    def evaluate_composed(self, *operands):
        '''Evaluate the op by delegating to ``evaluate_elementwise_average``.'''
        return evaluate_elementwise_average(*operands)
def evaluate_elementwise_average(*args):
    '''Return ``csdl.sum`` of all arguments divided by their count.'''
    return csdl.sum(*args) / len(args)
def average(*args, axes=None):
    '''
    Computes the average of all entries in the input tensor if a single argument is provided.
    Computes the average of all entries along the specified axes if `axes` argument is given.
    Computes the elementwise average of multiple variables of the same shape,
    if multiple arguments are provided. Axes argument is not allowed in this case.

    Parameters
    ----------
    *args : tuple of Variable or np.ndarray objects
        Input tensor/s whose average/s needs to be computed.
    axes : int or tuple of int, default=None
        Axes along which to compute the average of the input tensor,
        if there's only one input tensor. A single int is treated as a
        one-element tuple.

    Returns
    -------
    Variable
        Average of all entries in the input tensor if a single argument is provided.
        Average of entries along the specified axes if `axes` argument is given.
        Elementwise average of multiple variables of the same shape,
        if multiple arguments are provided.

    Raises
    ------
    ValueError
        If no argument is given, if `axes` is combined with multiple
        arguments, if the arguments do not all share one shape, if an axis
        is negative or not smaller than the rank of the input, or if `axes`
        covers every axis of the input.

    Examples
    --------
    >>> recorder = csdl.Recorder(inline = True)
    >>> recorder.start()
    >>> x = csdl.Variable(value = np.array([1.0, 2.0, 3.0]))
    >>> y1 = csdl.average(x)
    >>> y1.value
    array([2.])

    Average of a single tensor variable along a specified axis

    >>> x_val = np.arange(6).reshape(2,3)
    >>> x = csdl.Variable(value = x_val)
    >>> y2 = csdl.average(x, axes=(1,))
    >>> y2.value
    array([1., 4.])

    Elementwise average of multiple tensor variables

    >>> y3 = csdl.average(x, 2 * np.ones((2,3)), np.ones((2,3)))
    >>> y3.value
    array([[1.        , 1.33333333, 1.66666667],
           [2.        , 2.33333333, 2.66666667]])
    '''
    # Fail early with a clear message instead of an IndexError or a
    # division by zero further down.
    if len(args) == 0:
        raise ValueError('At least one Variable must be provided to average.')

    # Multiple Variables to average
    if axes is not None and len(args) > 1:
        raise ValueError(
            'Cannot average multiple Variables along specified axes. '
            'Use X = average(A,B,...) followed by out=average(X, axes=(...)) instead.'
        )

    first_shape = args[0].shape
    if any(arg.shape != first_shape for arg in args[1:]):
        raise ValueError('All Variables must have the same shape.')

    # Single Variable to average
    if axes is not None:
        # Normalize to a tuple so a bare int or a list behaves like the
        # documented tuple form everywhere below and inside the operation.
        if isinstance(axes, (int, np.integer)):
            axes = (axes,)
        else:
            axes = tuple(axes)

        rank = len(first_shape)
        if any(axis > rank - 1 for axis in axes):
            raise ValueError('Specified axes cannot be more than the rank of the Variable averaged.')
        if any(axis < 0 for axis in axes):
            raise ValueError('Axes cannot have negative entries.')

    if len(args) == 1:
        if axes is None:
            out_shape = (1,)
        else:
            # Keep only the dimensions that are not averaged away.
            out_shape = tuple(dim for i, dim in enumerate(first_shape) if i not in axes)
            if len(out_shape) == 0:
                raise ValueError(
                    'It is inefficient to average a tensor Variable along all axes. '
                    'Use average(A) to find the average of all tensor entries.'
                )
        op = Average(validate_and_variablize(args[0]), axes=axes, out_shape=out_shape)
    else:
        # axes is None for multiple variables
        operands = [validate_and_variablize(arg) for arg in args]
        op = ElementwiseAverage(*operands)

    return op.finalize_and_return_outputs()
class TestAverage(csdl_tests.CSDLTest):
    '''Checks `average` values against NumPy references and runs its docstring examples.'''

    def test_functionality(self,):
        '''Compare csdl averages with NumPy results and verify derivatives.'''
        self.prep()

        import csdl_alpha as csdl
        import numpy as np

        recorder = csdl.build_new_recorder(inline = True)
        recorder.start()

        x_val = 3.0*np.arange(6).reshape(2,3)
        y_val = 2.0*np.ones((2,3))
        z_val = np.ones((2,3))

        x = csdl.Variable(name = 'x', value = x_val)
        y = csdl.Variable(name = 'y', value = y_val)
        z = csdl.Variable(name = 'z', value = z_val)

        pairs = []

        # average of a single tensor variable
        expected_total = np.array([7.5])
        pairs.append(csdl_tests.TestingPair(csdl.average(x), expected_total, tag = 's1'))

        # average of a single tensor constant
        pairs.append(csdl_tests.TestingPair(csdl.average(x_val), expected_total, tag = 's2'))

        # average of a single tensor variable along specified axes
        along_axis = csdl.average(x, axes=(1,))
        expected_along_axis = np.average(x_val, axis=1)
        pairs.append(csdl_tests.TestingPair(along_axis, expected_along_axis, tag = 's3'))

        # elementwise average of multiple tensor variables
        elementwise = csdl.average(x, y, z)
        expected_elementwise = (x_val + y_val + z_val)/3
        pairs.append(csdl_tests.TestingPair(elementwise, expected_elementwise, tag = 's4'))

        # elementwise average of multiple tensor constants
        elementwise_const = csdl.average(x_val, y_val, z_val)
        pairs.append(csdl_tests.TestingPair(elementwise_const, expected_elementwise, tag = 's5'))

        # Try more complicated tensors and averages: a fresh rank-3 variable
        # is built for each axes combination.
        for axes in ((1,2), (1,0), (1,)):
            x_val = 3.0*np.arange(24).reshape(2,3,4)
            x = csdl.Variable(value = x_val)
            reduced = csdl.average(x, axes=axes)
            expected_reduced = np.average(x_val, axis=axes)
            pairs.append(csdl_tests.TestingPair(reduced, expected_reduced, tag = 's3'))

        self.run_tests(compare_values = pairs, verify_derivatives=True)

    def test_example(self,):
        '''Run the examples in the docstring of `average`.'''
        self.docstest(average)
# Script entry point: run both TestAverage checks in order.
if __name__ == '__main__':
    suite = TestAverage()
    for run_check in (suite.test_functionality, suite.test_example):
        run_check()