diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 23f0d3b..e897dd4 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -20,11 +20,11 @@ env: jobs: unit-test: - runs-on: ubuntu-20.04 + runs-on: ubuntu-latest environment: integration-test-workflow strategy: matrix: - python-version: ["3.6.15", "3.7", "3.8", "3.9", "3.10", "3.11"] + python-version: ["3.6.15", "3.7", "3.8", "3.9", "3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} @@ -42,16 +42,16 @@ jobs: pytest unit_tests integration-test: timeout-minutes: 10 - runs-on: ubuntu-20.04 + runs-on: ubuntu-latest environment: integration-test-workflow strategy: matrix: - python-version: ["3.6.15", "3.11"] - aerie-version: ["3.0.1", "3.1.1", "3.2.0"] + python-version: ["3.12"] + aerie-version: ["3.0.1", "3.1.1", "3.2.0", "3.3.1", "3.4.0"] steps: - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install dependencies diff --git a/src/aerie_cli/aerie_client.py b/src/aerie_cli/aerie_client.py index 6c01fd7..68bdd74 100644 --- a/src/aerie_cli/aerie_client.py +++ b/src/aerie_cli/aerie_client.py @@ -14,7 +14,7 @@ from .schemas.api import ApiMissionModelRead from .schemas.api import ApiResourceSampleResults from .schemas.api import ApiParcelRead -from .schemas.client import Activity +from .schemas.client import Activity, SimulationDataset, UserSequence, Workspace from .schemas.client import ActivityPlanCreate from .schemas.client import ActivityPlanRead from .schemas.client import DictionaryMetadata @@ -576,6 +576,7 @@ def get_resource_samples(self, simulation_dataset_id: int, state_names: List=Non } def get_simulation_results(self, sim_dataset_id: int) -> str: + """Older implementation""" sim_result_query = """ query Simulation($sim_dataset_id: Int!) { @@ -597,6 +598,67 @@ def get_simulation_results(self, sim_dataset_id: int) -> str: sim_result_query, sim_dataset_id=sim_dataset_id) return resp + def get_model_effective_arguments(self, arguments: Dict, model_id: int) -> Dict: + query = """ + query GetModelEffectiveArguments($model_id: Int!, $arguments: ModelArguments!) { + getModelEffectiveArguments(missionModelId: $model_id, modelArguments: $arguments) { + arguments + success + } + } + """ + resp = self.aerie_host.post_to_graphql( + query, + model_id=model_id, + arguments=arguments + ) + + return resp["arguments"] + + def get_simulation_dataset(self, sim_dataset_id: int): + """Newer implementation""" + query = """ + query GetSimulationDataset($id: Int!) { + simulation_dataset_by_pk(id: $id) { + id + arguments + status + dataset_id + simulation_start_time + simulation_end_time + simulated_activities { + id + activity_type_name + attributes + parent_id + start_time + end_time + start_offset + duration + activity_directive { + id + name + type + start_offset + arguments + metadata + anchor_id + anchored_to_start + } + } + } + } + """ + resp = self.aerie_host.post_to_graphql(query, id=sim_dataset_id) + sim_dataset = SimulationDataset.from_api_dict(resp) + + plan_id = self.get_plan_id_by_sim_id(sim_dataset_id) + plans_metadata = self.list_all_activity_plans() + model_id = next(filter(lambda p: p.id == plan_id, plans_metadata)).model_id + full_arguments = self.get_model_effective_arguments(sim_dataset.arguments, model_id) + sim_dataset.arguments = full_arguments + return sim_dataset + def delete_plan(self, plan_id: int) -> str: delete_plan_mutation = """ @@ -2237,3 +2299,112 @@ def delete_plan_collaborator(self, plan_id: int, user: str): if resp is None: raise RuntimeError(f"Failed to delete plan collaborator") + + def create_workspace(self, name: str) -> int: + query = """ + mutation CreateWorkspace($workspace: workspace_insert_input!) { + createWorkspace: insert_workspace_one(object: $workspace) { + id + } + } + """ + workspace = { + "name": name + } + resp = self.aerie_host.post_to_graphql( + query, + workspace=workspace + ) + + return resp["id"] + + def get_workspaces(self) -> List[Workspace]: + query = """ + query GetWorkspaces { + workspace { + id + name + } + } + """ + resp = self.aerie_host.post_to_graphql(query) + return [Workspace.from_dict(w) for w in resp] + + def get_workspace_id_by_name(self, name: str) -> int: + """Get ID of workspace by name, creating if it doesn't exist + + Args: + name (str): Workspace name + + Returns: + int: Workspace ID + """ + workspaces = self.list_workspaces() + workspace = next(filter(workspaces, lambda w: w.name == name), None) + + if workspace is None: + return self.create_workspace(name) + + return workspace.id + + def create_user_sequence(self, sequence: UserSequence) -> int: + query = """ + mutation CreateUserSequence($sequence: user_sequence_insert_input!) { + createUserSequence: insert_user_sequence_one(object: $sequence) { + id + } + } + """ + resp = self.aerie_host.post_to_graphql( + query, + sequence=sequence.to_api_create().to_dict() + ) + + return resp["id"] + + def get_user_sequences(self) -> List[UserSequence]: + query = """ + query GetUserSequences { + user_sequence { + id + name + definition + workspace_id + parcel_id + } + } + """ + resp = self.aerie_host.post_to_graphql(query) + return [UserSequence.from_api_dict(s) for s in resp] + + def update_user_sequence(self, sequence: UserSequence): + query = """ + mutation UpdateUserSequence($id: Int!, $sequence: user_sequence_set_input!) { + updateUserSequence: update_user_sequence_by_pk( + pk_columns: { id: $id }, _set: $sequence + ) { + id + } + } + """ + if sequence.id is None: + raise ValueError("User sequence ID must be specified to update") + + self.aerie_host.post_to_graphql( + query, + sequence=sequence.to_api_create().to_dict(), + id=id + ) + + def delete_user_sequence(self, id: int): + query = """ + mutation DeleteUserSequence($id: Int!) { + deleteUserSequence: delete_user_sequence_by_pk(id: $id) { + id + } + } + """ + self.aerie_host.post_to_graphql( + query, + id=id + ) diff --git a/src/aerie_cli/aerie_host.py b/src/aerie_cli/aerie_host.py index c7d1c9e..927cd3d 100644 --- a/src/aerie_cli/aerie_host.py +++ b/src/aerie_cli/aerie_host.py @@ -13,6 +13,9 @@ "3.1.0", "3.1.1", "3.2.0", + "3.3.0", + "3.3.1", + "3.4.0" ] class AerieHostVersionError(RuntimeError): diff --git a/src/aerie_cli/schemas/api.py b/src/aerie_cli/schemas/api.py index 5315003..599bb20 100644 --- a/src/aerie_cli/schemas/api.py +++ b/src/aerie_cli/schemas/api.py @@ -147,17 +147,27 @@ class ApiActivityPlanRead(ApiActivityPlanBase): @define -class ApiAsSimulatedActivity(ApiSerialize): - type: str - parent_id: Optional[str] - start_timestamp: Arrow = field( - converter = arrow.get - ) - children: List[str] - duration: timedelta = field( - converter = lambda microseconds: timedelta(microseconds=microseconds) - ) +class ApiSimulatedActivity(ApiSerialize): + id: int + activity_type_name: str + attributes: Dict[str, Any] + parent_id: Optional[int] + start_time: Arrow = field(converter = arrow.get) + end_time: Arrow = field(converter = arrow.get) # TODO what does this look like for an unfinished activity? + start_offset: timedelta = field(converter = postgres_interval_to_timedelta) + duration: timedelta = field(converter = postgres_interval_to_timedelta) # TODO handle unfinished activties? + activity_directive: Optional[ApiActivityRead] = field(converter = converters.optional(lambda a: ApiActivityRead.from_dict(a))) + + +@define +class ApiSimulationDataset(ApiSerialize): + id: int arguments: Dict[str, Any] + status: str + dataset_id: int + simulation_start_time: Arrow = field(converter = arrow.get) + simulation_end_time: Arrow = field(converter = arrow.get) + simulated_activities: List[ApiSimulatedActivity] = field(converter = lambda acts: [ApiSimulatedActivity.from_dict(a) for a in acts]) @define @@ -168,19 +178,6 @@ class ApiSimulatedResourceSample(ApiSerialize): y: Any -@define -class ApiSimulationResults(ApiSerialize): - start: Arrow = field( - converter = arrow.get - ) - activities: Dict[str, ApiAsSimulatedActivity] - unfinishedActivities: Any - # TODO: implement constraints - constraints: Any - # TODO: implement events - events: Any - - @define class ApiResourceSampleResults(ApiSerialize): resourceSamples: Dict[str, List[ApiSimulatedResourceSample]] @@ -210,3 +207,14 @@ class ApiParcelCreate(ApiSerialize): class ApiParcelRead(ApiParcelCreate): id: int parameter_dictionaries: Dict[str, int] + +@define +class ApiUserSequenceCreate(ApiSerialize): + name: str + definition: str + parcel_id: str + workspace_id: str + +@define +class ApiUserSequenceRead(ApiUserSequenceCreate): + id: int diff --git a/src/aerie_cli/schemas/client.py b/src/aerie_cli/schemas/client.py index bb2dcb7..cf14358 100644 --- a/src/aerie_cli/schemas/client.py +++ b/src/aerie_cli/schemas/client.py @@ -23,13 +23,15 @@ from aerie_cli.schemas.api import ApiActivityPlanCreate from aerie_cli.schemas.api import ApiActivityPlanRead from aerie_cli.schemas.api import ApiActivityRead -from aerie_cli.schemas.api import ApiAsSimulatedActivity +from aerie_cli.schemas.api import ApiSimulatedActivity from aerie_cli.schemas.api import ApiResourceSampleResults from aerie_cli.schemas.api import ApiSimulatedResourceSample -from aerie_cli.schemas.api import ApiSimulationResults +from aerie_cli.schemas.api import ApiSimulationDataset from aerie_cli.schemas.api import ActivityBase from aerie_cli.schemas.api import ApiParcelRead from aerie_cli.schemas.api import ApiParcelCreate +from aerie_cli.schemas.api import ApiUserSequenceCreate +from aerie_cli.schemas.api import ApiUserSequenceRead def parse_timedelta_str_converter(t) -> timedelta: if isinstance(t, str): @@ -245,31 +247,38 @@ def from_api_read(cls, api_plan_read: ApiActivityPlanRead) -> "ActivityPlanRead" @define -class AsSimulatedActivity(ClientSerialize): - type: str +class SimulatedActivity(ClientSerialize): id: str + type: str + arguments: Dict[str, Any] + computed_attributes: Dict[str, Any] parent_id: Optional[str] - start_time: Arrow = field( - converter = arrow.get - ) - children: List[str] - duration: timedelta = field( - converter = parse_timedelta_str_converter - ) - parameters: Dict[str, Any] + start_time: Arrow = field(converter = arrow.get) + end_time: Arrow = field(converter = arrow.get) + start_offset: timedelta = field(converter = parse_timedelta_str_converter) + duration: timedelta = field(converter = parse_timedelta_str_converter) + directive: Optional[Activity] @classmethod def from_api_as_simulated_activity( - cls, api_as_simulated_activity: ApiAsSimulatedActivity, id: str + cls, api_as_simulated_activity: ApiSimulatedActivity ): - return AsSimulatedActivity( - type=api_as_simulated_activity.type, - id=id, + if api_as_simulated_activity.activity_directive is None: + directive = None + else: + directive = Activity.from_api_read(api_as_simulated_activity.activity_directive) + + return SimulatedActivity( + id=api_as_simulated_activity.id, + type=api_as_simulated_activity.activity_type_name, + arguments=api_as_simulated_activity.attributes["arguments"], + computed_attributes=api_as_simulated_activity.attributes["computedAttributes"], parent_id=api_as_simulated_activity.parent_id, - start_time=api_as_simulated_activity.start_timestamp, - children=api_as_simulated_activity.children, + start_time=api_as_simulated_activity.start_time, + end_time=api_as_simulated_activity.end_time, + start_offset=api_as_simulated_activity.start_offset, duration=api_as_simulated_activity.duration, - parameters=api_as_simulated_activity.arguments, + directive=directive ) @@ -304,32 +313,43 @@ def from_api_sim_res_timeline( @define -class SimulationResults(ClientSerialize): - start_time: Arrow = field( - converter = arrow.get - ) - activities: List[AsSimulatedActivity] - resources: List[SimulatedResourceTimeline] +class SimulationDataset(ClientSerialize): + id: int + arguments: Dict[str, Any] + status: str + dataset_id: int + start_time: Arrow = field(converter = arrow.get) + end_time: Arrow = field(converter = arrow.get) + activities: List[SimulatedActivity] + # resources: List[SimulatedResourceTimeline] + + @classmethod + def from_api_dict(cls, api_dict: Dict) -> "SimulationDataset": + return cls.from_api_results(ApiSimulationDataset.from_dict(api_dict)) @classmethod def from_api_results( cls, - api_sim_results: ApiSimulationResults, - api_resource_timeline: ApiResourceSampleResults, - ): - plan_start = api_sim_results.start - return SimulationResults( - start_time=plan_start, + api_sim_results: ApiSimulationDataset, + # api_resource_timeline: ApiResourceSampleResults, + ) -> "SimulationDataset": + return SimulationDataset( + id=api_sim_results.id, + arguments=api_sim_results.arguments, + status=api_sim_results.status, + dataset_id=api_sim_results.dataset_id, + start_time=api_sim_results.simulation_start_time, + end_time=api_sim_results.simulation_end_time, activities=[ - AsSimulatedActivity.from_api_as_simulated_activity(act, id) - for id, act in api_sim_results.activities.items() - ], - resources=[ - SimulatedResourceTimeline.from_api_sim_res_timeline( - name, api_timeline, plan_start - ) - for name, api_timeline in api_resource_timeline.resourceSamples.items() - ], + SimulatedActivity.from_api_as_simulated_activity(act) + for act in api_sim_results.simulated_activities + ] + # resources=[ + # SimulatedResourceTimeline.from_api_sim_res_timeline( + # name, api_timeline, plan_start + # ) + # for name, api_timeline in api_resource_timeline.resourceSamples.items() + # ], ) @@ -469,3 +489,40 @@ class SequenceAdaptationMetadata(ClientSerialize): updated_at: Arrow = field( converter=arrow.get ) + + +@define +class Workspace(ClientSerialize): + name: str + id: int + + +@define +class UserSequence(ClientSerialize): + name: str + definition: str + parcel_id: int + workspace_id: int + id: int = field(default=None) + + def to_api_create(self) -> ApiUserSequenceCreate: + return ApiUserSequenceCreate( + name=self.name, + definition=self.definition, + parcel_id=self.parcel_id, + workspace_id=self.workspace_id + ) + + @classmethod + def from_api_dict(cls, sequence: Dict) -> "UserSequence": + return cls.from_api_read(ApiUserSequenceRead.from_dict(sequence)) + + @classmethod + def from_api_read(cls, sequence: ApiUserSequenceRead) -> "UserSequence": + return cls( + name=sequence.name, + definition=sequence.definition, + parcel_id=sequence.parcel_id, + workspace_id=sequence.workspace_id, + id=sequence.id + )