Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ env:

jobs:
unit-test:
runs-on: ubuntu-20.04
runs-on: ubuntu-latest
environment: integration-test-workflow
strategy:
matrix:
python-version: ["3.6.15", "3.7", "3.8", "3.9", "3.10", "3.11"]
python-version: ["3.6.15", "3.7", "3.8", "3.9", "3.10", "3.11", "3.12"]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
Expand All @@ -42,16 +42,16 @@ jobs:
pytest unit_tests
integration-test:
timeout-minutes: 10
runs-on: ubuntu-20.04
runs-on: ubuntu-latest
environment: integration-test-workflow
strategy:
matrix:
python-version: ["3.6.15", "3.11"]
aerie-version: ["3.0.1", "3.1.1", "3.2.0"]
python-version: ["3.12"]
aerie-version: ["3.0.1", "3.1.1", "3.2.0", "3.3.1", "3.4.0"]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
Expand Down
173 changes: 172 additions & 1 deletion src/aerie_cli/aerie_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from .schemas.api import ApiMissionModelRead
from .schemas.api import ApiResourceSampleResults
from .schemas.api import ApiParcelRead
from .schemas.client import Activity
from .schemas.client import Activity, SimulationDataset, UserSequence, Workspace
from .schemas.client import ActivityPlanCreate
from .schemas.client import ActivityPlanRead
from .schemas.client import DictionaryMetadata
Expand Down Expand Up @@ -576,6 +576,7 @@ def get_resource_samples(self, simulation_dataset_id: int, state_names: List=Non
}

def get_simulation_results(self, sim_dataset_id: int) -> str:
"""Older implementation"""

sim_result_query = """
query Simulation($sim_dataset_id: Int!) {
Expand All @@ -597,6 +598,67 @@ def get_simulation_results(self, sim_dataset_id: int) -> str:
sim_result_query, sim_dataset_id=sim_dataset_id)
return resp

def get_model_effective_arguments(self, arguments: Dict, model_id: int) -> Dict:
query = """
query GetModelEffectiveArguments($model_id: Int!, $arguments: ModelArguments!) {
getModelEffectiveArguments(missionModelId: $model_id, modelArguments: $arguments) {
arguments
success
}
}
"""
resp = self.aerie_host.post_to_graphql(
query,
model_id=model_id,
arguments=arguments
)

return resp["arguments"]

def get_simulation_dataset(self, sim_dataset_id: int):
"""Newer implementation"""
query = """
query GetSimulationDataset($id: Int!) {
simulation_dataset_by_pk(id: $id) {
id
arguments
status
dataset_id
simulation_start_time
simulation_end_time
simulated_activities {
id
activity_type_name
attributes
parent_id
start_time
end_time
start_offset
duration
activity_directive {
id
name
type
start_offset
arguments
metadata
anchor_id
anchored_to_start
}
}
}
}
"""
resp = self.aerie_host.post_to_graphql(query, id=sim_dataset_id)
sim_dataset = SimulationDataset.from_api_dict(resp)

plan_id = self.get_plan_id_by_sim_id(sim_dataset_id)
plans_metadata = self.list_all_activity_plans()
model_id = next(filter(lambda p: p.id == plan_id, plans_metadata)).model_id
full_arguments = self.get_model_effective_arguments(sim_dataset.arguments, model_id)
sim_dataset.arguments = full_arguments
return sim_dataset

def delete_plan(self, plan_id: int) -> str:

delete_plan_mutation = """
Expand Down Expand Up @@ -2237,3 +2299,112 @@ def delete_plan_collaborator(self, plan_id: int, user: str):

if resp is None:
raise RuntimeError(f"Failed to delete plan collaborator")

def create_workspace(self, name: str) -> int:
query = """
mutation CreateWorkspace($workspace: workspace_insert_input!) {
createWorkspace: insert_workspace_one(object: $workspace) {
id
}
}
"""
workspace = {
"name": name
}
resp = self.aerie_host.post_to_graphql(
query,
workspace=workspace
)

return resp["id"]

def get_workspaces(self) -> List[Workspace]:
query = """
query GetWorkspaces {
workspace {
id
name
}
}
"""
resp = self.aerie_host.post_to_graphql(query)
return [Workspace.from_dict(w) for w in resp]

def get_workspace_id_by_name(self, name: str) -> int:
"""Get ID of workspace by name, creating if it doesn't exist

Args:
name (str): Workspace name

Returns:
int: Workspace ID
"""
workspaces = self.list_workspaces()
workspace = next(filter(workspaces, lambda w: w.name == name), None)

if workspace is None:
return self.create_workspace(name)

return workspace.id

def create_user_sequence(self, sequence: UserSequence) -> int:
query = """
mutation CreateUserSequence($sequence: user_sequence_insert_input!) {
createUserSequence: insert_user_sequence_one(object: $sequence) {
id
}
}
"""
resp = self.aerie_host.post_to_graphql(
query,
sequence=sequence.to_api_create().to_dict()
)

return resp["id"]

def get_user_sequences(self) -> List[UserSequence]:
query = """
query GetUserSequences {
user_sequence {
id
name
definition
workspace_id
parcel_id
}
}
"""
resp = self.aerie_host.post_to_graphql(query)
return [UserSequence.from_api_dict(s) for s in resp]

def update_user_sequence(self, sequence: UserSequence):
query = """
mutation UpdateUserSequence($id: Int!, $sequence: user_sequence_set_input!) {
updateUserSequence: update_user_sequence_by_pk(
pk_columns: { id: $id }, _set: $sequence
) {
id
}
}
"""
if sequence.id is None:
raise ValueError("User sequence ID must be specified to update")

self.aerie_host.post_to_graphql(
query,
sequence=sequence.to_api_create().to_dict(),
id=id
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
id=id
id=sequence.id

)

def delete_user_sequence(self, id: int):
query = """
mutation DeleteUserSequence($id: Int!) {
deleteUserSequence: delete_user_sequence_by_pk(id: $id) {
id
}
}
"""
self.aerie_host.post_to_graphql(
query,
id=id
)
3 changes: 3 additions & 0 deletions src/aerie_cli/aerie_host.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
"3.1.0",
"3.1.1",
"3.2.0",
"3.3.0",
"3.3.1",
"3.4.0"
]

class AerieHostVersionError(RuntimeError):
Expand Down
54 changes: 31 additions & 23 deletions src/aerie_cli/schemas/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,17 +147,27 @@ class ApiActivityPlanRead(ApiActivityPlanBase):


@define
class ApiAsSimulatedActivity(ApiSerialize):
type: str
parent_id: Optional[str]
start_timestamp: Arrow = field(
converter = arrow.get
)
children: List[str]
duration: timedelta = field(
converter = lambda microseconds: timedelta(microseconds=microseconds)
)
class ApiSimulatedActivity(ApiSerialize):
id: int
activity_type_name: str
attributes: Dict[str, Any]
parent_id: Optional[int]
start_time: Arrow = field(converter = arrow.get)
end_time: Arrow = field(converter = arrow.get) # TODO what does this look like for an unfinished activity?
start_offset: timedelta = field(converter = postgres_interval_to_timedelta)
duration: timedelta = field(converter = postgres_interval_to_timedelta) # TODO handle unfinished activties?
activity_directive: Optional[ApiActivityRead] = field(converter = converters.optional(lambda a: ApiActivityRead.from_dict(a)))


@define
class ApiSimulationDataset(ApiSerialize):
id: int
arguments: Dict[str, Any]
status: str
dataset_id: int
simulation_start_time: Arrow = field(converter = arrow.get)
simulation_end_time: Arrow = field(converter = arrow.get)
simulated_activities: List[ApiSimulatedActivity] = field(converter = lambda acts: [ApiSimulatedActivity.from_dict(a) for a in acts])


@define
Expand All @@ -168,19 +178,6 @@ class ApiSimulatedResourceSample(ApiSerialize):
y: Any


@define
class ApiSimulationResults(ApiSerialize):
start: Arrow = field(
converter = arrow.get
)
activities: Dict[str, ApiAsSimulatedActivity]
unfinishedActivities: Any
# TODO: implement constraints
constraints: Any
# TODO: implement events
events: Any


@define
class ApiResourceSampleResults(ApiSerialize):
resourceSamples: Dict[str, List[ApiSimulatedResourceSample]]
Expand Down Expand Up @@ -210,3 +207,14 @@ class ApiParcelCreate(ApiSerialize):
class ApiParcelRead(ApiParcelCreate):
id: int
parameter_dictionaries: Dict[str, int]

@define
class ApiUserSequenceCreate(ApiSerialize):
name: str
definition: str
parcel_id: str
workspace_id: str

@define
class ApiUserSequenceRead(ApiUserSequenceCreate):
id: int
Loading
Loading