| 
 | 1 | +from argparse import ArgumentParser  | 
 | 2 | +from datetime import datetime  | 
 | 3 | +from typing import Any  | 
 | 4 | + | 
 | 5 | +from django.core.management.base import BaseCommand, CommandError  | 
 | 6 | +from django.db import transaction  | 
 | 7 | + | 
 | 8 | +from shared.models import (  | 
 | 9 | +    AffectedProduct,  | 
 | 10 | +    Container,  | 
 | 11 | +    CveRecord,  | 
 | 12 | +    Description,  | 
 | 13 | +    Organization,  | 
 | 14 | +    Version,  | 
 | 15 | +)  | 
 | 16 | + | 
 | 17 | + | 
 | 18 | +class Command(BaseCommand):  | 
 | 19 | +    help = "Create a test CVE for a specific package"  | 
 | 20 | + | 
 | 21 | +    def add_arguments(self, parser: ArgumentParser) -> None:  | 
 | 22 | +        parser.add_argument(  | 
 | 23 | +            "package_name",  | 
 | 24 | +            type=str,  | 
 | 25 | +            help="Package name to create a CVE for",  | 
 | 26 | +        )  | 
 | 27 | +        parser.add_argument(  | 
 | 28 | +            "--cve-id",  | 
 | 29 | +            type=str,  | 
 | 30 | +            help="Custom CVE ID (default: auto-generated)",  | 
 | 31 | +        )  | 
 | 32 | + | 
 | 33 | +    def handle(self, *args: Any, **options: Any) -> None:  | 
 | 34 | +        package_name = options["package_name"]  | 
 | 35 | +        cve_id = options.get("cve_id")  | 
 | 36 | + | 
 | 37 | +        # Generate CVE ID if not provided  | 
 | 38 | +        if not cve_id:  | 
 | 39 | +            current_year = datetime.now().year  | 
 | 40 | +            existing_cves = CveRecord.objects.filter(  | 
 | 41 | +                cve_id__startswith=f"CVE-{current_year}-"  | 
 | 42 | +            ).count()  | 
 | 43 | +            cve_id = f"CVE-{current_year}-{(existing_cves + 1):04d}"  | 
 | 44 | + | 
 | 45 | +        # Check if CVE already exists  | 
 | 46 | +        if CveRecord.objects.filter(cve_id=cve_id).exists():  | 
 | 47 | +            raise CommandError(f"CVE {cve_id} already exists")  | 
 | 48 | + | 
 | 49 | +        with transaction.atomic():  | 
 | 50 | +            # Create organization  | 
 | 51 | +            org, _ = Organization.objects.get_or_create(  | 
 | 52 | +                short_name="TEST_ORG",  | 
 | 53 | +                defaults={"uuid": "12345678-1234-5678-9abc-123456789012"},  | 
 | 54 | +            )  | 
 | 55 | + | 
 | 56 | +            # Create CVE record  | 
 | 57 | +            cve_record = CveRecord.objects.create(  | 
 | 58 | +                cve_id=cve_id,  | 
 | 59 | +                state=CveRecord.RecordState.PUBLISHED,  | 
 | 60 | +                assigner=org,  | 
 | 61 | +                date_published=datetime.now(),  | 
 | 62 | +                date_updated=datetime.now(),  | 
 | 63 | +                triaged=False,  | 
 | 64 | +            )  | 
 | 65 | + | 
 | 66 | +            # Create description  | 
 | 67 | +            description = Description.objects.create(  | 
 | 68 | +                lang="en",  | 
 | 69 | +                value=f"Test vulnerability in {package_name} package.",  | 
 | 70 | +            )  | 
 | 71 | + | 
 | 72 | +            # Create container  | 
 | 73 | +            container = Container.objects.create(  | 
 | 74 | +                _type=Container.Type.CNA,  | 
 | 75 | +                cve=cve_record,  | 
 | 76 | +                provider=org,  | 
 | 77 | +                title=f"Vulnerability in {package_name}",  | 
 | 78 | +                date_public=datetime.now(),  | 
 | 79 | +            )  | 
 | 80 | +            container.descriptions.add(description)  | 
 | 81 | + | 
 | 82 | +            # Create affected product  | 
 | 83 | +            affected_product = AffectedProduct.objects.create(  | 
 | 84 | +                vendor="nixpkgs",  | 
 | 85 | +                product=package_name,  | 
 | 86 | +                package_name=package_name,  | 
 | 87 | +                default_status=AffectedProduct.Status.AFFECTED,  | 
 | 88 | +            )  | 
 | 89 | + | 
 | 90 | +            # Add version constraint  | 
 | 91 | +            version_affected = Version.objects.create(  | 
 | 92 | +                status=Version.Status.AFFECTED,  | 
 | 93 | +                version_type="semver",  | 
 | 94 | +                less_than="*",  | 
 | 95 | +            )  | 
 | 96 | +            affected_product.versions.add(version_affected)  | 
 | 97 | + | 
 | 98 | +            # Link to container  | 
 | 99 | +            container.affected.add(affected_product)  | 
 | 100 | + | 
 | 101 | +        self.stdout.write(  | 
 | 102 | +            self.style.SUCCESS(f"Created CVE {cve_id} for {package_name}")  | 
 | 103 | +        )  | 
0 commit comments