From 9bc34ca29d9fd7f3ad0537116706b404da9de177 Mon Sep 17 00:00:00 2001 From: Quan Pham Date: Thu, 24 Apr 2025 11:36:23 -0400 Subject: [PATCH] Allow creating and updating allocations with API versioning To create an allocation, a JSON payload can be uploaded to `/api/allocations`: { "attributes": [ {"attribute_type": "OpenShift Limit on CPU Quota", "value": 8}, {"attribute_type": "OpenShift Limit on RAM Quota (MiB)", "value": 16}, ], "project": {"id": project.id}, "resources": [{"id": self.resource.id}], "status": "New", } Updating allocation status is done via a PATCH request to `/api/allocations/{id}` with a JSON payload: { "status": "Active" } Certain status transitions trigger signals: - New -> Active: allocation_activate - Active -> Denied: allocation_deactivate --- src/coldfront_plugin_api/serializers.py | 115 +++++++++++++++++- .../tests/unit/test_allocations.py | 98 +++++++++++++++ src/coldfront_plugin_api/urls.py | 8 +- src/local_settings.py | 1 + 4 files changed, 218 insertions(+), 4 deletions(-) diff --git a/src/coldfront_plugin_api/serializers.py b/src/coldfront_plugin_api/serializers.py index 8a29059..cfd0e4d 100644 --- a/src/coldfront_plugin_api/serializers.py +++ b/src/coldfront_plugin_api/serializers.py @@ -1,14 +1,29 @@ +import logging +from datetime import datetime, timedelta + from rest_framework import serializers -from coldfront.core.allocation.models import Allocation, AllocationAttribute -from coldfront.core.allocation.models import Project +from coldfront.core.allocation.models import ( + Allocation, + AllocationAttribute, + AllocationStatusChoice, + AllocationAttributeType, +) +from coldfront.core.allocation.models import Project, Resource +from coldfront.core.allocation import signals + + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) class ProjectSerializer(serializers.ModelSerializer): class Meta: model = Project fields = ["id", "title", "pi", "description", "field_of_science", "status"] + read_only_fields = ["title", "pi", "description", "field_of_science", "status"] + id = serializers.IntegerField() pi = serializers.SerializerMethodField() field_of_science = serializers.SerializerMethodField() status = serializers.SerializerMethodField() @@ -23,6 +38,35 @@ def get_status(self, obj: Project) -> str: return obj.status.name +class AllocationAttributeSerializer(serializers.ModelSerializer): + class Meta: + model = AllocationAttribute + fields = ["attribute_type", "value"] + + attribute_type = ( + serializers.SlugRelatedField( # Peforms validation to ensure attribute exists + read_only=False, + slug_field="name", + queryset=AllocationAttributeType.objects.all(), + source="allocation_attribute_type", + ) + ) + value = serializers.CharField(read_only=False) + + +class ResourceSerializer(serializers.ModelSerializer): + class Meta: + model = Resource + fields = ["id", "name", "resource_type"] + + id = serializers.IntegerField() + name = serializers.CharField(required=False) + resource_type = serializers.SerializerMethodField(required=False) + + def get_resource_type(self, obj: Resource): + return obj.resource_type.name + + class AllocationSerializer(serializers.ModelSerializer): class Meta: model = Allocation @@ -48,3 +92,70 @@ def get_attributes(self, obj: Allocation): def get_status(self, obj: Allocation) -> str: return obj.status.name + + +class AllocationSerializerV2(serializers.ModelSerializer): + class Meta: + model = Allocation + fields = ["id", "project", "description", "resources", "status", "attributes"] + + resources = ResourceSerializer(many=True) + project = ProjectSerializer() + attributes = AllocationAttributeSerializer( + many=True, source="allocationattribute_set", required=False + ) + status = serializers.SlugRelatedField( + slug_field="name", queryset=AllocationStatusChoice.objects.all() + ) + + def create(self, validated_data): + project_obj = Project.objects.get(id=validated_data["project"]["id"]) + resource_obj = Resource.objects.get(id=validated_data["resources"][0]["id"]) + allocation = Allocation.objects.create( + project=project_obj, + status=validated_data["status"], + justification="", + start_date=datetime.now(), + end_date=datetime.now() + timedelta(days=365), + ) + allocation.resources.add(resource_obj) + allocation.save() + + for attribute in validated_data.pop("allocationattribute_set", []): + AllocationAttribute.objects.create( + allocation=allocation, + allocation_attribute_type=attribute["allocation_attribute_type"], + value=attribute["value"], + ) + + logger.info( + f"Created allocation {allocation.id} for project {project_obj.title}" + ) + return allocation + + def update(self, allocation: Allocation, validated_data): + """ + Only allow updating allocation status for now + + Certain status transitions will have side effects (activating/deactivating allocations) + """ + + old_status = allocation.status.name + new_status = validated_data.get("status", old_status).name + + allocation.status = validated_data.get("status", allocation.status) + allocation.save() + + if old_status == "New" and new_status == "Active": + signals.allocation_activate.send( + sender=self.__class__, allocation_pk=allocation.pk + ) + elif old_status == "Active" and new_status in ["Denied", "Revoked"]: + signals.allocation_disable.send( + sender=self.__class__, allocation_pk=allocation.pk + ) + + logger.info( + f"Updated allocation {allocation.id} for project {allocation.project.title}" + ) + return allocation diff --git a/src/coldfront_plugin_api/tests/unit/test_allocations.py b/src/coldfront_plugin_api/tests/unit/test_allocations.py index af71803..5e20c57 100644 --- a/src/coldfront_plugin_api/tests/unit/test_allocations.py +++ b/src/coldfront_plugin_api/tests/unit/test_allocations.py @@ -1,5 +1,7 @@ from os import devnull +from datetime import datetime, timedelta import sys +from unittest.mock import patch, ANY from coldfront.core.allocation import models as allocation_models from django.core.management import call_command @@ -38,6 +40,12 @@ def admin_client(self): client.login(username="admin", password="test1234") return client + @property + def admin_v2_client(self): + client = APIClient(headers={"Accept": "application/json; version=2.0"}) + client.login(username="admin", password="test1234") + return client + def test_list_allocations(self): user = self.new_user() project = self.new_project(pi=user) @@ -146,3 +154,93 @@ def test_filter_allocations(self): "/api/allocations?fake_model_attribute=fake" ).json() self.assertEqual(r_json, []) + + def test_create_allocation(self): + user = self.new_user() + project = self.new_project(pi=user) + + payload = { + "attributes": [ + {"attribute_type": "OpenShift Limit on CPU Quota", "value": 8}, + {"attribute_type": "OpenShift Limit on RAM Quota (MiB)", "value": 16}, + ], + "project": {"id": project.id}, + "resources": [{"id": self.resource.id}], + "status": "New", + } + + self.admin_v2_client.post("/api/allocations", payload, format="json") + + created_allocation = allocation_models.Allocation.objects.get( + project=project, + resources__in=[self.resource], + ) + self.assertEqual(created_allocation.status.name, "New") + self.assertEqual(created_allocation.justification, "") + self.assertEqual(created_allocation.start_date, datetime.now().date()) + self.assertEqual( + created_allocation.end_date, (datetime.now() + timedelta(days=365)).date() + ) + + allocation_models.AllocationAttribute.objects.get( + allocation=created_allocation, + allocation_attribute_type=allocation_models.AllocationAttributeType.objects.get( + name="OpenShift Limit on CPU Quota" + ), + value=8, + ) + allocation_models.AllocationAttribute.objects.get( + allocation=created_allocation, + allocation_attribute_type=allocation_models.AllocationAttributeType.objects.get( + name="OpenShift Limit on RAM Quota (MiB)" + ), + value=16, + ) + + def test_update_allocation_status_new_to_active(self): + user = self.new_user() + project = self.new_project(pi=user) + allocation = self.new_allocation(project, self.resource, 1) + allocation.status = allocation_models.AllocationStatusChoice.objects.get( + name="New" + ) + allocation.save() + + payload = {"status": "Active"} + + with patch( + "coldfront.core.allocation.signals.allocation_activate.send" + ) as mock_activate: + response = self.admin_v2_client.patch( + f"/api/allocations/{allocation.id}?all=true", payload, format="json" + ) + self.assertEqual(response.status_code, 200) + allocation.refresh_from_db() + self.assertEqual(allocation.status.name, "Active") + mock_activate.assert_called_once_with( + sender=ANY, allocation_pk=allocation.pk + ) + + def test_update_allocation_status_active_to_denied(self): + user = self.new_user() + project = self.new_project(pi=user) + allocation = self.new_allocation(project, self.resource, 1) + allocation.status = allocation_models.AllocationStatusChoice.objects.get( + name="Active" + ) + allocation.save() + + payload = {"status": "Denied"} + + with patch( + "coldfront.core.allocation.signals.allocation_disable.send" + ) as mock_disable: + response = self.admin_v2_client.patch( + f"/api/allocations/{allocation.id}", payload, format="json" + ) + self.assertEqual(response.status_code, 200) + allocation.refresh_from_db() + self.assertEqual(allocation.status.name, "Denied") + mock_disable.assert_called_once_with( + sender=ANY, allocation_pk=allocation.pk + ) diff --git a/src/coldfront_plugin_api/urls.py b/src/coldfront_plugin_api/urls.py index e5ac5a4..d9c729e 100644 --- a/src/coldfront_plugin_api/urls.py +++ b/src/coldfront_plugin_api/urls.py @@ -10,7 +10,7 @@ from coldfront_plugin_api import auth, serializers -class AllocationViewSet(viewsets.ReadOnlyModelViewSet): +class AllocationViewSet(viewsets.ModelViewSet): """ This viewset implements the API to Coldfront's allocation object The API allows filtering allocations by any of Coldfront's allocation model attributes, @@ -41,10 +41,14 @@ class AllocationViewSet(viewsets.ReadOnlyModelViewSet): In cases where an invalid model attribute or AA is queried, an empty list is returned """ - serializer_class = serializers.AllocationSerializer authentication_classes = auth.AUTHENTICATION_CLASSES permission_classes = [IsAdminUser] + def get_serializer_class(self): + if self.request.version == "2.0": + return serializers.AllocationSerializerV2 + return serializers.AllocationSerializer + def get_queryset(self): queryset = Allocation.objects.filter(status__name="Active") query_params = self.request.query_params diff --git a/src/local_settings.py b/src/local_settings.py index a353ae8..7e460b4 100644 --- a/src/local_settings.py +++ b/src/local_settings.py @@ -20,6 +20,7 @@ "rest_framework.authentication.BasicAuthentication", "rest_framework.authentication.SessionAuthentication", ], + "DEFAULT_VERSIONING_CLASS": "rest_framework.versioning.AcceptHeaderVersioning", } if os.getenv("PLUGIN_AUTH_OIDC") == "True":