from typing import Union
import warnings
from csdl_alpha.src.graph.variable import Variable
from csdl_alpha.utils.inputs import variablize
class VariableGroup:
    """
    Represents a group of variables.

    This class provides a way to organize and manage a group of variables. It allows for defining checks
    on the variables, adding tags to the variables, and saving the variables.

    Subclasses are expected to be decorated with ``@dataclass`` and to override
    :meth:`define_checks`, calling :meth:`add_check` for each annotated field that
    should be validated. Checks are applied once when the group is constructed
    (through ``__post_init__``) and again on every later attribute assignment.
    """

    def __init__(self):
        # Only the bare base class may use this constructor. A subclass that ends up
        # here did not get its own __init__, which @dataclass would have generated.
        if type(self) is not VariableGroup:
            raise TypeError("Subclasses of VariableGroup should be decorated with @dataclass.")
        self.__post_init__()

    def __post_init__(self):
        """Create the check registry, collect the declared checks and apply them."""
        # Fields assigned before this point are stored unchecked because
        # __setattr__ only validates once `_metadata` exists; check() covers them.
        self._metadata = {}
        self.define_checks()
        self.check()

    def __setattr__(self, name, value):
        """Validate (and possibly convert) ``value`` before storing it on the group."""
        if hasattr(self, '_metadata'):
            value = self._check_parameters(name, value)
        super().__setattr__(name, value)

    def _declared_fields(self):
        """Return the names annotated on this class or on any of its base classes."""
        names = set()
        # A single `__annotations__` lookup only yields the annotations of one class,
        # so walk the whole MRO to include inherited fields as well.
        for klass in type(self).__mro__:
            names.update(getattr(klass, '__annotations__', None) or ())
        return names

    def _check_parameters(self, name, value):
        """Apply the checks declared for ``name`` to ``value``.

        Parameters
        ----------
        name : str
            The name of the attribute being validated.
        value : Any
            The candidate value.

        Returns
        -------
        Any
            ``value`` itself, or the result of ``variablize(value)`` when the check
            for ``name`` was declared with ``variablize=True``. Names without a
            declared check are returned unchanged.

        Raises
        ------
        ValueError
            If the value is not an instance of the declared type, or if it does not
            have the declared shape (including values without a ``shape`` attribute).
        """
        checks = self._metadata.get(name)
        if checks is None:
            return value

        if checks['variablize']:
            # NOTE: variablize now turns things into Constant objects, idk if this is the desired behavior
            value = variablize(value)

        expected_type = checks['type']
        # isinstance (rather than an exact type comparison) so that subclasses of the
        # declared type -- presumably including what variablize returns -- are accepted.
        if expected_type is not None and not isinstance(value, expected_type):
            raise ValueError(f"Variable {name} must be of type {expected_type}.")

        expected_shape = checks['shape']
        if expected_shape is not None:
            # A value without a `shape` attribute is reported as a shape mismatch.
            if getattr(value, 'shape', None) != expected_shape:
                raise ValueError(f"Variable {name} must have shape {expected_shape}.")

        return value

    def check(self):
        """Applies all checks to the variables in the group.

        Raises
        ------
        ValueError
            If a variable with a declared check is missing from the group, or if a
            variable fails one of its checks.
        """
        for key in self._metadata:
            if not hasattr(self, key):
                raise ValueError(f"Variable {key} not found in the group.")
            # Re-assigning routes the value through __setattr__, which validates it
            # exactly once and stores the (possibly variablized) result.
            setattr(self, key, getattr(self, key))

    def define_checks(self):
        """Hook for subclasses: declare checks here via :meth:`add_check`."""
        pass

    def add_check(self, name:str, type=None, shape:tuple=None, variablize:bool=False):
        """Declare parameters to be checked for a variable in the group.

        This method is used to define checks for a variable in the group. The parameters that can be checked
        include the name, type, shape, and whether the variable should be variablized.

        Parameters
        ----------
        name : str
            The name of the variable. It must be an annotated field of the group's
            class or of one of its base classes.
        type : type, optional
            The type of the variable, by default None (no type check).
        shape : tuple, optional
            The shape of the variable, by default None (no shape check).
        variablize : bool, optional
            Whether the variable should be turned into a CSDL variable, by default False.

        Raises
        ------
        ValueError
            If the variable with the given name is not found in the group.
        ValueError
            If parameters for the variable with the given name are already declared.
        """
        if name not in self._declared_fields():
            raise ValueError(f"Variable {name} not found in the group.")
        if name in self._metadata:
            raise ValueError(f"Checks for variable {name} already declared.")
        self._metadata[name] = {'type': type, 'shape': shape, 'variablize': variablize}

    def add_tag(self, tag:str):
        """Adds a tag to all Variables in the group or subgroups.

        Parameters
        ----------
        tag : str
            Tag to add to the Variables.
        """
        for val in self.__dict__.values():
            if isinstance(val, (Variable, VariableGroup)):
                val.add_tag(tag)

    def save(self):
        """Saves any Variables in the group or subgroups."""
        for val in self.__dict__.values():
            if isinstance(val, (Variable, VariableGroup)):
                val.save()
# def print_all(self):
# if __name__ == '__main__':
#     import csdl_alpha as csdl
#     recorder = csdl.Recorder()
#     recorder.start()
#     vg = csdl.VariableGroup()
#     vg.a = 1
#     vg.b = csdl.Variable(shape=(1,), value=1)
#     @dataclass
#     class VG(VariableGroup):
#         a : Union[Variable, int, float]
#         b : Variable
#         def define_checks(self):
#             self.add_check('a', shape=(1,), variablize=True)
#             self.add_check('b', type=Variable, shape=(1,))
#     vg = VG(a=1, b=csdl.Variable(shape=(1,), value=1))