from csdl_alpha.src.graph.operation import Operation, set_properties
import csdl_alpha.utils.testing_utils as csdl_tests
from csdl_alpha.src.graph.variable import Variable
from csdl_alpha.utils.inputs import variablize, validate_and_variablize
from csdl_alpha.src.operations.linalg.utils import process_matA_vecb
from csdl_alpha.utils.typing import VariableLike
import pytest
@set_properties()
class MatVec(Operation):
    """Graph operation computing the matrix-vector product ``b = A @ x``.

    The single output always has shape ``(A.shape[0], 1)``. The public
    ``matvec`` function passes its inputs through ``process_matA_vecb``
    before building this operation; presumably that leaves ``x`` as an
    ``(n, 1)`` column (NOTE(review): confirm against that helper).
    """

    def __init__(self, A: Variable, x: Variable) -> None:
        super().__init__(A, x)
        self.name = 'matvec'
        # One dense output: a column with as many rows as A.
        self.set_dense_outputs(((A.shape[0], 1),))

    def compute_inline(self, A, x):
        """Return ``A @ x`` for eager (inline) evaluation."""
        return A @ x

    def compute_jax(self, A, x):
        """Return ``A @ x`` on the JAX evaluation path."""
        return A @ x

    def evaluate_vjp(self, cotangents, A, x, b):
        """Accumulate reverse-mode cotangents of ``b = A @ x``.

        Only inputs for which ``cotangents.check`` is true receive a
        contribution:

        * ``x_bar += A^T @ b_bar``
        * ``A_bar += outer(b_bar, x)``, reshaped to ``A.shape``
        """
        # Imported locally, presumably to avoid a circular import between
        # this module and the top-level package.
        import csdl_alpha as csdl
        if cotangents.check(x):
            cotangents.accumulate(x, csdl.matvec(A.T(), cotangents[b]))
        if cotangents.check(A):
            cotangents.accumulate(A, csdl.outer(cotangents[b], x).reshape(A.shape))
def matvec(A: VariableLike, x: VariableLike) -> Variable:
    """Matrix-vector multiplication ``A @ x``.

    The number of columns of ``A`` must equal the number of rows of ``x``.
    A 1D ``x`` is treated internally as a column vector, and the result is
    returned with the same dimensionality as ``x``.

    Parameters
    ----------
    A : Variable
        2D matrix. Unlike ``x``, a sparse matrix is not rejected.
    x : Variable
        1D or 2D (column) vector.

    Returns
    -------
    y: Variable
        1D vector of shape ``(m,)`` if ``x`` is 1D, otherwise a 2D column
        vector of shape ``(m, 1)``.

    Raises
    ------
    ValueError
        If the shapes of ``A`` and ``x`` are incompatible (raised during
        input validation).

    Examples
    --------
    >>> recorder = csdl.Recorder(inline = True)
    >>> recorder.start()
    >>> A = csdl.Variable(value = np.array([[1, 2], [3, 4], [5, 6]]))
    >>> x = csdl.Variable(value = np.array([1, 2]))
    >>> csdl.matvec(A, x).value
    array([ 5., 11., 17.])
    """
    A_mat = validate_and_variablize(A, raise_on_sparse=False)
    x_vec = validate_and_variablize(x)
    output = MatVec(*process_matA_vecb(A_mat, x_vec)).finalize_and_return_outputs()
    # MatVec always yields an (m, 1) column, so flatten it back when the
    # caller supplied a 1D vector. Inspect the validated x_vec rather than
    # the raw input, because a VariableLike need not have a ``.shape``.
    if len(x_vec.shape) == 1:
        return output.reshape((output.size,))
    return output
class TestMatVec(csdl_tests.CSDLTest):
    """Tests for ``matvec``: values and derivatives, shape errors, doctests."""

    def test_functionality(self):
        self.prep()
        import csdl_alpha as csdl
        import numpy as np

        A_shape = (3, 4)
        A_val = np.arange(np.prod(A_shape)).reshape(A_shape)
        A = csdl.Variable(value=A_val)

        compare_values = []

        # x as a 2D column vector: Variable/Variable, then each mixed
        # numpy/Variable combination. The result keeps shape (3, 1).
        B_shape = (4, 1)
        B_val = np.arange(np.prod(B_shape)).reshape(B_shape)
        B = csdl.Variable(value=B_val)
        C = csdl.matvec(A, B)
        compare_values += [csdl_tests.TestingPair(C, A_val @ B_val)]
        C = csdl.matvec(A_val, B)
        compare_values += [csdl_tests.TestingPair(C, A_val @ B_val)]
        C = csdl.matvec(A, B_val)
        compare_values += [csdl_tests.TestingPair(C, A_val @ B_val)]

        # x as a 1D vector: the result is flattened back to shape (3,).
        B_shape = (4,)
        B_val = np.arange(np.prod(B_shape)).reshape(B_shape)
        B = csdl.Variable(value=B_val)
        C = csdl.matvec(A, B)
        compare_values += [csdl_tests.TestingPair(C, A_val @ B_val)]
        C = csdl.matvec(A_val, B)
        compare_values += [csdl_tests.TestingPair(C, A_val @ B_val)]
        C = csdl.matvec(A, B_val)
        compare_values += [csdl_tests.TestingPair(C, A_val @ B_val)]

        self.run_tests(compare_values=compare_values, verify_derivatives=True)

    def test_errors(self):
        self.prep()
        import csdl_alpha as csdl
        import numpy as np

        # x is a 2x2 matrix, not a vector.
        A = csdl.Variable(value=np.ones((2, 2)))
        B = csdl.Variable(value=np.ones((2, 2)))
        with pytest.raises(ValueError):
            csdl.matvec(A, B)

        # A is 3D.
        A = csdl.Variable(value=np.ones((2, 2, 3)))
        B = csdl.Variable(value=np.ones((2, 1)))
        with pytest.raises(ValueError):
            csdl.matvec(A, B)

        # Inner dimensions disagree (2D x).
        A = csdl.Variable(value=np.ones((2, 3)))
        B = csdl.Variable(value=np.ones((2, 1)))
        with pytest.raises(ValueError):
            csdl.matvec(A, B)

        # Inner dimensions disagree (1D x).
        A = csdl.Variable(value=np.ones((2, 3)))
        B = csdl.Variable(value=np.ones((2,)))
        with pytest.raises(ValueError):
            csdl.matvec(A, B)

        # x is 3D.
        A = csdl.Variable(value=np.ones((2, 3)))
        B = csdl.Variable(value=np.ones((3, 4, 4)))
        with pytest.raises(ValueError):
            csdl.matvec(A, B)

    def test_docsstrings(self):
        self.docstest(matvec)
if __name__ == '__main__':
    # Run the suite by hand, in a fixed order: values/derivatives,
    # then doctests, then the error cases.
    suite = TestMatVec()
    for run_test in (
        suite.test_functionality,
        suite.test_docsstrings,
        suite.test_errors,
    ):
        run_test()