from csdl_alpha.src.operations.operation_subclasses import ElementwiseOperation, ComposedOperation
from csdl_alpha.src.graph.operation import Operation, set_properties
from csdl_alpha.src.graph.variable import Variable
from csdl_alpha.utils.inputs import variablize, validate_and_variablize
import csdl_alpha.utils.testing_utils as csdl_tests
import pytest
from csdl_alpha.utils.typing import VariableLike
class VectorDot(Operation):
def __init__(self,x,y):
super().__init__(x,y)
self.name = 'vdot'
self.set_dense_outputs(((1,),))
def compute_inline(self, x, y):
import numpy as np
return np.vdot(x, y)
def compute_jax(self, x, y):
import jax.numpy as jnp
return jnp.vdot(x, y)
def evaluate_vjp(self, cotangents, x, y, z):
import csdl_alpha as csdl
if cotangents.check(x):
cotangents.accumulate(x, (cotangents[z]*y).reshape(x.shape))
if cotangents.check(y):
cotangents.accumulate(y, (cotangents[z]*x).reshape(y.shape))
[docs]def vdot(x:VariableLike,y:VariableLike)->Variable:
"""
Dot product of two vectors x and y. The result is a scalar of shape (1,).
Parameters
----------
x : Variable
1D vector.
y : Variable
1D vector.
Returns
-------
out: Variable
Scalar dot product of x and y.
Examples
--------
>>> recorder = csdl.Recorder(inline = True)
>>> recorder.start()
>>> x = csdl.Variable(value = np.array([1, 2, 3]))
>>> y = csdl.Variable(value = np.array([4, 5, 6]))
>>> csdl.vdot(x, y).value
array([32.])
"""
x = validate_and_variablize(x)
y = validate_and_variablize(y)
# checks:
# - x must be 1D (I guess not)
# - y must be 1D (I guess not)
# - x and y must have the same size
# For now, allow tensor for vdot
# if len(x.shape) != 1:
# raise ValueError(f"Vector x must be 1D, but has shape {x.shape}")
# if len(y.shape) != 1:
# raise ValueError(f"Vector y must be 1D, but has shape {y.shape}")
if x.size != y.size:
raise ValueError(f"Vectors x and y must have the same size. {x.size} != {y.size}")
return VectorDot(x, y).finalize_and_return_outputs()
class TestVDot(csdl_tests.CSDLTest):
def test_functionality(self,):
self.prep()
import csdl_alpha as csdl
import numpy as np
x_val = np.arange(10)
y_val = np.arange(10)+2.0
x = csdl.Variable(value = x_val)
y = csdl.Variable(value = y_val)
compare_values = []
compare_values += [csdl_tests.TestingPair(csdl.vdot(x,y), np.vdot(x_val, y_val).flatten())]
compare_values += [csdl_tests.TestingPair(csdl.vdot(x_val,y), np.vdot(x_val, y_val).flatten())]
compare_values += [csdl_tests.TestingPair(csdl.vdot(x,y_val), np.vdot(x_val, y_val).flatten())]
compare_values += [csdl_tests.TestingPair(csdl.vdot(x.reshape((2,5)),y_val.reshape((2,5))), np.vdot(x_val, y_val).flatten())]
compare_values += [csdl_tests.TestingPair(csdl.vdot(x,y.reshape((10,1))), np.vdot(x_val, y_val).flatten())]
self.run_tests(compare_values = compare_values, verify_derivatives=True)
def test_errors(self,):
self.prep()
import csdl_alpha as csdl
import numpy as np
x_val = np.arange(10)
y_val = np.arange(9)+2.0
x = csdl.Variable(value = x_val)
y = csdl.Variable(value = y_val)
with pytest.raises(ValueError):
csdl.vdot(x,y)
with pytest.raises(ValueError):
csdl.vdot(x_val,y)
with pytest.raises(ValueError):
csdl.vdot(x,y_val)
with pytest.raises(ValueError):
csdl.vdot(y,x)
with pytest.raises(ValueError):
csdl.vdot(y_val,x)
with pytest.raises(ValueError):
csdl.vdot(y,x_val)
x_val = (np.arange(10)).reshape(2,5)
y_val = (np.arange(2)+2.0).reshape(2)
x = csdl.Variable(value = x_val)
y = csdl.Variable(value = y_val)
with pytest.raises(ValueError):
csdl.vdot(x,y)
with pytest.raises(ValueError):
csdl.vdot(x_val,y)
with pytest.raises(ValueError):
csdl.vdot(x,y_val)
with pytest.raises(ValueError):
csdl.vdot(y,x)
with pytest.raises(ValueError):
csdl.vdot(y_val,x)
with pytest.raises(ValueError):
csdl.vdot(x_val,y_val)
def test_docstring(self):
self.docstest(vdot)
if __name__ == '__main__':
test = TestVDot()
test.test_functionality()
test.test_errors()
test.test_docstring()