From 7785d22c2bb48d7fd665a2dcfdfcc41f5a37d8a8 Mon Sep 17 00:00:00 2001 From: Tristan Cartledge Date: Thu, 16 Oct 2025 15:37:32 +1000 Subject: [PATCH 1/5] fix: snip TUI select-all/deselect bug; clean reachability seeded from /paths and top-level security; remove unused top-level tags; update docs --- README.md | 2 +- cmd/openapi/commands/openapi/README.md | 16 +- cmd/openapi/commands/openapi/clean.go | 19 +- cmd/openapi/commands/openapi/snip.go | 158 +++++++++- cmd/openapi/internal/explore/tui/model.go | 4 +- .../internal/explore/tui/model_test.go | 119 +++++++ openapi/clean.go | 298 ++++++++++++++---- openapi/clean_paths_reachability_test.go | 220 +++++++++++++ openapi/clean_tags_test.go | 133 ++++++++ openapi/reference.go | 2 +- 10 files changed, 882 insertions(+), 89 deletions(-) create mode 100644 cmd/openapi/internal/explore/tui/model_test.go create mode 100644 openapi/clean_paths_reachability_test.go create mode 100644 openapi/clean_tags_test.go diff --git a/README.md b/README.md index 6a147e5..73a61ea 100644 --- a/README.md +++ b/README.md @@ -98,7 +98,7 @@ The CLI provides three main command groups: - **`openapi spec`** - Commands for working with OpenAPI specifications ([documentation](./cmd/openapi/commands/openapi/README.md)) - `bootstrap` - Create a new OpenAPI document with best practice examples - `bundle` - Bundle external references into components section - - `clean` - Remove unused components from an OpenAPI specification + - `clean` - Remove unused components and unused top-level tags from an OpenAPI specification - `explore` - Interactively explore an OpenAPI specification in the terminal - `inline` - Inline all references in an OpenAPI specification - `join` - Join multiple OpenAPI documents into a single document diff --git a/cmd/openapi/commands/openapi/README.md b/cmd/openapi/commands/openapi/README.md index 6475ff5..3fa2e95 100644 --- a/cmd/openapi/commands/openapi/README.md +++ b/cmd/openapi/commands/openapi/README.md @@ -135,7 +135,7 @@ paths: ### `clean` -Remove unused components from an OpenAPI specification to create a cleaner, more maintainable document. +Remove unused components and unused top‑level tags from an OpenAPI specification using reachability seeded from /paths and top‑level security. ```bash # Clean to stdout (pipe-friendly) @@ -150,10 +150,12 @@ openapi spec clean -w ./spec.yaml What cleaning does: -- Removes unused components from all component types (schemas, responses, parameters, etc.) 
-- Tracks all references throughout the document including `$ref` and security scheme name references -- Preserves all components that are actually used in the specification -- Handles complex reference patterns including circular references and nested components +- Performs reachability-based cleanup seeded only from API surface areas (/paths and top‑level security) +- Expands through $ref links to components until a fixed point is reached +- Preserves security schemes referenced by name in security requirement objects (global and operation‑level) +- Removes unused components from all component types (schemas, responses, parameters, examples, request bodies, headers, links, callbacks, path items) +- Removes unused top‑level tags that are not referenced by any operation +- Handles complex reference patterns; only components reachable from the API surface are kept **Before cleaning:** @@ -823,6 +825,7 @@ openapi spec snip -w --operation /internal/debug:GET ./spec.yaml **Two Operation Modes:** **Interactive Mode** (no operation flags): + - Launch a terminal UI to browse all operations - Select operations with Space key - Press 'a' to select all, 'A' to deselect all @@ -830,6 +833,7 @@ openapi spec snip -w --operation /internal/debug:GET ./spec.yaml - Press 'q' or Esc to cancel **Command-Line Mode** (operation flags specified): + - Remove operations specified via flags without UI - Supports `--operationId` for operation IDs - Supports `--operation` for path:method pairs @@ -951,7 +955,7 @@ openapi spec clean | \ openapi spec upgrade | \ openapi spec validate -# Alternative: Clean after bundling to remove unused components +# Alternative: Clean after bundling to remove unused components and unused top-level tags openapi spec bundle ./spec.yaml ./bundled.yaml openapi spec clean ./bundled.yaml ./clean-bundled.yaml openapi spec validate ./clean-bundled.yaml diff --git a/cmd/openapi/commands/openapi/clean.go b/cmd/openapi/commands/openapi/clean.go index 48609cb..74a5ede 100644 --- a/cmd/openapi/commands/openapi/clean.go +++ b/cmd/openapi/commands/openapi/clean.go @@ -12,15 +12,19 @@ import ( var cleanCmd = &cobra.Command{ Use: "clean [output-file]", - Short: "Remove unused components from an OpenAPI specification", - Long: `Remove unused components from an OpenAPI specification to create a cleaner, more focused document. + Short: "Remove unused components and unused top-level tags from an OpenAPI specification", + Long: `Remove unused components and unused top-level tags from an OpenAPI specification to create a cleaner, more focused document. -This command analyzes an OpenAPI document to identify which components are actually referenced -and removes any unused components, reducing document size and improving clarity. 
+This command uses reachability-based analysis to keep only what is actually used by the API surface: +- Seeds reachability exclusively from API surface areas: entries under /paths and the top-level security section +- Expands through $ref links across component sections until a fixed point is reached +- Preserves security schemes referenced by name in security requirement objects (global or operation-level) +- Prunes any components that are not reachable from the API surface +- Removes unused top-level tags that are not referenced by any operation What gets cleaned: - Unused schemas in components/schemas -- Unused responses in components/responses +- Unused responses in components/responses - Unused parameters in components/parameters - Unused examples in components/examples - Unused request bodies in components/requestBodies @@ -29,14 +33,13 @@ What gets cleaned: - Unused links in components/links - Unused callbacks in components/callbacks - Unused path items in components/pathItems +- Unused top-level tags (global tags not referenced by any operation) Special handling for security schemes: Security schemes can be referenced in two ways: 1. By $ref (like other components) 2. By name in security requirement objects (global or operation-level) - -The clean command correctly handles both cases and preserves security schemes -that are referenced by name in security blocks. +The clean command correctly handles both cases and preserves security schemes that are referenced by name in security blocks. Benefits of cleaning: - Reduce document size by removing dead code diff --git a/cmd/openapi/commands/openapi/snip.go b/cmd/openapi/commands/openapi/snip.go index b6d9bf3..8ccf1a5 100644 --- a/cmd/openapi/commands/openapi/snip.go +++ b/cmd/openapi/commands/openapi/snip.go @@ -13,9 +13,11 @@ import ( ) var ( - snipWriteInPlace bool - snipOperationIDs []string - snipOperations []string + snipWriteInPlace bool + snipOperationIDs []string + snipOperations []string + snipKeepOperationIDs []string + snipKeepOperations []string ) var snipCmd = &cobra.Command{ @@ -73,6 +75,9 @@ func init() { snipCmd.Flags().BoolVarP(&snipWriteInPlace, "write", "w", false, "write result in-place to input file") snipCmd.Flags().StringSliceVar(&snipOperationIDs, "operationId", nil, "operation ID to remove (can be comma-separated or repeated)") snipCmd.Flags().StringSliceVar(&snipOperations, "operation", nil, "operation as path:method to remove (can be comma-separated or repeated)") + // Keep-mode flags (mutually exclusive with remove-mode flags) + snipCmd.Flags().StringSliceVar(&snipKeepOperationIDs, "keepOperationId", nil, "operation ID to keep (can be comma-separated or repeated)") + snipCmd.Flags().StringSliceVar(&snipKeepOperations, "keepOperation", nil, "operation as path:method to keep (can be comma-separated or repeated)") } func runSnip(cmd *cobra.Command, args []string) error { @@ -84,20 +89,29 @@ func runSnip(cmd *cobra.Command, args []string) error { outputFile = args[1] } - // Check if any operations were specified via flags - hasOperationFlags := len(snipOperationIDs) > 0 || len(snipOperations) > 0 + // Check which flag sets were specified + hasRemoveFlags := len(snipOperationIDs) > 0 || len(snipOperations) > 0 + hasKeepFlags := len(snipKeepOperationIDs) > 0 || len(snipKeepOperations) > 0 - // If -w is specified without operation flags, error - if snipWriteInPlace && !hasOperationFlags { - return fmt.Errorf("--write flag requires specifying operations via --operationId or --operation flags") + // If -w is 
specified without any operation selection flags, error + if snipWriteInPlace && !(hasRemoveFlags || hasKeepFlags) { + return fmt.Errorf("--write flag requires specifying operations via --operationId/--operation or --keepOperationId/--keepOperation") } - if !hasOperationFlags { - // No flags - interactive mode + // Interactive mode when no flags provided + if !hasRemoveFlags && !hasKeepFlags { return runSnipInteractive(ctx, inputFile, outputFile) } - // Flags specified - CLI mode + // Disallow mixing keep + remove flags; ambiguous intent + if hasRemoveFlags && hasKeepFlags { + return fmt.Errorf("cannot combine keep and remove flags; use either --operationId/--operation or --keepOperationId/--keepOperation") + } + + // CLI mode + if hasKeepFlags { + return runSnipCLIKeep(ctx, inputFile, outputFile) + } return runSnipCLI(ctx, inputFile, outputFile) } @@ -139,6 +153,87 @@ func runSnipCLI(ctx context.Context, inputFile, outputFile string) error { return processor.WriteDocument(ctx, doc) } +func runSnipCLIKeep(ctx context.Context, inputFile, outputFile string) error { + // Create processor + processor, err := NewOpenAPIProcessor(inputFile, outputFile, snipWriteInPlace) + if err != nil { + return err + } + + // Load document + doc, validationErrors, err := processor.LoadDocument(ctx) + if err != nil { + return err + } + + // Report validation errors (if any) + processor.ReportValidationErrors(validationErrors) + + // Parse keep flags + keepOps, err := parseKeepOperationFlags() + if err != nil { + return err + } + if len(keepOps) == 0 { + return fmt.Errorf("no operations specified to keep") + } + + // Collect all operations from the document + allOps, err := explore.CollectOperations(ctx, doc) + if err != nil { + return fmt.Errorf("failed to collect operations: %w", err) + } + if len(allOps) == 0 { + return fmt.Errorf("no operations found in the OpenAPI document") + } + + // Build lookup sets for keep filters + keepByID := map[string]bool{} + keepByPathMethod := map[string]bool{} + for _, k := range keepOps { + if k.OperationID != "" { + keepByID[k.OperationID] = true + } + if k.Path != "" && k.Method != "" { + key := strings.ToUpper(k.Method) + " " + k.Path + keepByPathMethod[key] = true + } + } + + // Compute removal list = all - keep + var operationsToRemove []openapi.OperationIdentifier + for _, op := range allOps { + if op.OperationID != "" && keepByID[op.OperationID] { + continue + } + key := strings.ToUpper(op.Method) + " " + op.Path + if keepByPathMethod[key] { + continue + } + operationsToRemove = append(operationsToRemove, openapi.OperationIdentifier{ + Path: op.Path, + Method: strings.ToUpper(op.Method), + }) + } + + // If nothing to remove, write as-is + if len(operationsToRemove) == 0 { + processor.PrintSuccess("No operations to remove based on keep filters; writing document unchanged") + return processor.WriteDocument(ctx, doc) + } + + // Perform the snip + removed, err := openapi.Snip(ctx, doc, operationsToRemove) + if err != nil { + return fmt.Errorf("failed to snip operations: %w", err) + } + + processor.PrintSuccess(fmt.Sprintf("Successfully kept %d operation(s) and removed %d operation(s) with cleanup", len(allOps)-removed, removed)) + + // Write the snipped document + return processor.WriteDocument(ctx, doc) +} + func runSnipInteractive(ctx context.Context, inputFile, outputFile string) error { // Load the OpenAPI document doc, err := loadOpenAPIDocument(ctx, inputFile) @@ -306,6 +401,47 @@ func parseOperationFlags() ([]openapi.OperationIdentifier, error) { return operations, 
nil } +// parseKeepOperationFlags parses the keep flags into operation identifiers +// Handles both repeated flags and comma-separated values +func parseKeepOperationFlags() ([]openapi.OperationIdentifier, error) { + var operations []openapi.OperationIdentifier + + // Parse keep operation IDs + for _, opID := range snipKeepOperationIDs { + if opID != "" { + operations = append(operations, openapi.OperationIdentifier{ + OperationID: opID, + }) + } + } + + // Parse keep path:method operations + for _, op := range snipKeepOperations { + if op == "" { + continue + } + + parts := strings.SplitN(op, ":", 2) + if len(parts) != 2 { + return nil, fmt.Errorf("invalid keep operation format: %s (expected path:METHOD format, e.g., /users:GET)", op) + } + + path := parts[0] + method := strings.ToUpper(parts[1]) + + if path == "" || method == "" { + return nil, fmt.Errorf("invalid keep operation format: %s (path and method cannot be empty)", op) + } + + operations = append(operations, openapi.OperationIdentifier{ + Path: path, + Method: method, + }) + } + + return operations, nil +} + // GetSnipCommand returns the snip command for external use func GetSnipCommand() *cobra.Command { return snipCmd diff --git a/cmd/openapi/internal/explore/tui/model.go b/cmd/openapi/internal/explore/tui/model.go index 9819f08..2008291 100644 --- a/cmd/openapi/internal/explore/tui/model.go +++ b/cmd/openapi/internal/explore/tui/model.go @@ -140,8 +140,8 @@ func NewModelWithConfig(operations []explore.OperationInfo, docTitle, docVersion // Only relevant when selectionConfig.Enabled is true func (m Model) GetSelectedOperations() []explore.OperationInfo { var selected []explore.OperationInfo - for idx := range m.selected { - if idx < len(m.operations) { + for idx, isSelected := range m.selected { + if isSelected && idx < len(m.operations) { selected = append(selected, m.operations[idx]) } } diff --git a/cmd/openapi/internal/explore/tui/model_test.go b/cmd/openapi/internal/explore/tui/model_test.go new file mode 100644 index 0000000..7bcbdf9 --- /dev/null +++ b/cmd/openapi/internal/explore/tui/model_test.go @@ -0,0 +1,119 @@ +package tui_test + +import ( + "testing" + + "github.com/speakeasy-api/openapi/cmd/openapi/internal/explore" + "github.com/stretchr/testify/assert" +) + +// TestGetSelectedOperations_SelectAllThenDeselect reproduces GEN-2003 +// where selecting all operations then deselecting some would still return all operations +func TestGetSelectedOperations_SelectAllThenDeselect(t *testing.T) { + t.Parallel() + + // Create test operations + operations := []explore.OperationInfo{ + {Path: "/users", Method: "GET", OperationID: "getUsers"}, + {Path: "/users", Method: "POST", OperationID: "createUser"}, + {Path: "/users/{id}", Method: "GET", OperationID: "getUser"}, + {Path: "/users/{id}", Method: "DELETE", OperationID: "deleteUser"}, + {Path: "/posts", Method: "GET", OperationID: "getPosts"}, + } + + // Simulate selecting all operations (like pressing 'a') + // Then deselecting some (like pressing Space on specific operations) + + // Create a model with all operations initially selected + selectedMap := make(map[int]bool) + for i := range operations { + selectedMap[i] = true + } + + // Now deselect operations at indices 0 and 2 (simulating Space key on those) + selectedMap[0] = false + selectedMap[2] = false + + // Since we can't directly manipulate the model's internal state in tests, + // we're testing the core logic that was fixed in GetSelectedOperations() + + // Expected: Only operations at indices 1, 3, 4 should be 
returned + // (indices 0 and 2 were deselected) + expectedSelected := []explore.OperationInfo{ + operations[1], // POST /users + operations[3], // DELETE /users/{id} + operations[4], // GET /posts + } + + // Manual test of the fixed logic + var actualSelected []explore.OperationInfo + for idx, isSelected := range selectedMap { + if isSelected && idx < len(operations) { + actualSelected = append(actualSelected, operations[idx]) + } + } + + assert.ElementsMatch(t, expectedSelected, actualSelected, + "should only return operations that are marked as selected (true)") + assert.Len(t, actualSelected, 3, "should return 3 selected operations") + assert.NotContains(t, actualSelected, operations[0], "should not include deselected operation at index 0") + assert.NotContains(t, actualSelected, operations[2], "should not include deselected operation at index 2") +} + +// TestGetSelectedOperations_EmptySelection tests that no operations are returned when none are selected +func TestGetSelectedOperations_EmptySelection(t *testing.T) { + t.Parallel() + + operations := []explore.OperationInfo{ + {Path: "/users", Method: "GET", OperationID: "getUsers"}, + {Path: "/posts", Method: "GET", OperationID: "getPosts"}, + } + + selectedMap := make(map[int]bool) + // All entries are false or don't exist + + selectedMap[0] = false + selectedMap[1] = false + + var actualSelected []explore.OperationInfo + for idx, isSelected := range selectedMap { + if isSelected && idx < len(operations) { + actualSelected = append(actualSelected, operations[idx]) + } + } + + assert.Empty(t, actualSelected, "should return no operations when all are deselected") +} + +// TestGetSelectedOperations_PartialSelection tests mixed selection state +func TestGetSelectedOperations_PartialSelection(t *testing.T) { + t.Parallel() + + operations := []explore.OperationInfo{ + {Path: "/users", Method: "GET", OperationID: "getUsers"}, + {Path: "/users", Method: "POST", OperationID: "createUser"}, + {Path: "/posts", Method: "GET", OperationID: "getPosts"}, + } + + selectedMap := map[int]bool{ + 0: true, // Selected + 1: false, // Deselected + 2: true, // Selected + } + + expectedSelected := []explore.OperationInfo{ + operations[0], + operations[2], + } + + var actualSelected []explore.OperationInfo + for idx, isSelected := range selectedMap { + if isSelected && idx < len(operations) { + actualSelected = append(actualSelected, operations[idx]) + } + } + + assert.ElementsMatch(t, expectedSelected, actualSelected, + "should only return operations marked as selected") + assert.Len(t, actualSelected, 2, "should return 2 selected operations") +} diff --git a/openapi/clean.go b/openapi/clean.go index 3ac23d2..7cadfa9 100644 --- a/openapi/clean.go +++ b/openapi/clean.go @@ -9,18 +9,24 @@ import ( "github.com/speakeasy-api/openapi/sequencedmap" ) -// Clean removes unused components from the OpenAPI document. -// It walks through the document to track all referenced components and removes -// any components that are not referenced. Security schemes are handled specially -// as they can be referenced by name in security blocks rather than by $ref. +// Clean removes unused, unreachable elements from the OpenAPI document using reachability from paths and security. +// +// How it works (high level): +// +// - Seed reachability only from: +// - Operations under /paths (responses, request bodies, parameters, schemas, etc.) 
+// - Security requirements (top-level and operation-level), referenced by $ref and by name +// - Expand reachability only through components already marked as used until a fixed point +// - Remove anything not reachable from those seeds +// - Also remove top-level tags that are not referenced by any operation // // This function modifies the document in place. // // Why use Clean? // -// - **Reduce document size**: Remove unused component definitions that bloat the specification -// - **Improve clarity**: Keep only the components that are actually used in the API -// - **Optimize tooling performance**: Smaller documents with fewer unused components process faster +// - **Reduce document size**: Remove unused component definitions and tags that bloat the specification +// - **Improve clarity**: Keep only the elements that are actually used by operations/security +// - **Optimize tooling performance**: Smaller documents with fewer unused elements process faster // - **Maintain clean specifications**: Prevent accumulation of dead code in API definitions // - **Prepare for distribution**: Clean up specifications before sharing or publishing // @@ -36,6 +42,7 @@ import ( // - Unused links in components/links // - Unused callbacks in components/callbacks // - Unused path items in components/pathItems +// - Unused tags in the top-level tags array // // Special handling for security schemes: // @@ -47,16 +54,13 @@ import ( // // Example usage: // -// // Load an OpenAPI document with potentially unused components -// doc := &OpenAPI{...} -// -// // Clean up unused components (modifies doc in place) +// // Load an OpenAPI document and prune unused elements (modifies doc in place) // err := Clean(ctx, doc) // if err != nil { // return fmt.Errorf("failed to clean document: %w", err) // } // -// // doc now has only the components that are actually referenced +// // doc now contains only elements reachable from /paths and security, with unused tags removed // // Parameters: // - ctx: Context for the operation @@ -65,12 +69,12 @@ import ( // Returns: // - error: Any error that occurred during cleaning func Clean(ctx context.Context, doc *OpenAPI) error { - if doc == nil || doc.Components == nil { + if doc == nil { return nil } // Track referenced components by type and name - referencedComponents := &referencedComponentTracker{ + referenced := &referencedComponentTracker{ schemas: make(map[string]bool), responses: make(map[string]bool), parameters: make(map[string]bool), @@ -83,58 +87,66 @@ func Clean(ctx context.Context, doc *OpenAPI) error { pathItems: make(map[string]bool), } - // Walk through the document and track all references - for item := range Walk(ctx, doc) { - err := item.Match(Matcher{ - // Track schema references - Schema: func(schema *oas3.JSONSchema[oas3.Referenceable]) error { - return trackSchemaReferences(schema, referencedComponents) - }, - // Track component references - ReferencedPathItem: func(ref *ReferencedPathItem) error { - return trackPathItemReference(ref, referencedComponents.pathItems) - }, - ReferencedParameter: func(ref *ReferencedParameter) error { - return trackParameterReference(ref, referencedComponents.parameters) - }, - ReferencedExample: func(ref *ReferencedExample) error { - return trackExampleReference(ref, referencedComponents.examples) - }, - ReferencedRequestBody: func(ref *ReferencedRequestBody) error { - return trackRequestBodyReference(ref, referencedComponents.requestBodies) - }, - ReferencedResponse: func(ref *ReferencedResponse) error { - return 
trackResponseReference(ref, referencedComponents.responses) - }, - ReferencedHeader: func(ref *ReferencedHeader) error { - return trackHeaderReference(ref, referencedComponents.headers) - }, - ReferencedCallback: func(ref *ReferencedCallback) error { - return trackCallbackReference(ref, referencedComponents.callbacks) - }, - ReferencedLink: func(ref *ReferencedLink) error { - return trackLinkReference(ref, referencedComponents.links) - }, - ReferencedSecurityScheme: func(ref *ReferencedSecurityScheme) error { - return trackSecuritySchemeReference(ref, referencedComponents.securitySchemes) - }, - // Track security requirements (special case for security schemes) - Security: func(req *SecurityRequirement) error { - if req != nil { - for schemeName := range req.All() { - referencedComponents.securitySchemes[schemeName] = true - } - } - return nil - }, + // Phase 1: Seed references only from within /paths + err := walkAndTrackWithFilter(ctx, doc, referenced, func(ptr string) bool { + // Only allow references originating under paths + return strings.HasPrefix(ptr, "/paths/") || strings.HasPrefix(ptr, "/security") + }) + if err != nil { + return fmt.Errorf("failed to track references from paths: %w", err) + } + + // Phase 2: Expand closure of references reachable from used components. + // We repeatedly walk the document but only allow visiting content under components + // that are already marked as referenced, until no new references are discovered. + for { + before := countTracked(referenced) + + err := walkAndTrackWithFilter(ctx, doc, referenced, func(ptr string) bool { + typ, name, ok := extractComponentTypeAndName(ptr) + if !ok { + return false + } + switch typ { + case "schemas": + return referenced.schemas[name] + case "responses": + return referenced.responses[name] + case "parameters": + return referenced.parameters[name] + case "examples": + return referenced.examples[name] + case "requestBodies": + return referenced.requestBodies[name] + case "headers": + return referenced.headers[name] + case "securitySchemes": + return referenced.securitySchemes[name] + case "links": + return referenced.links[name] + case "callbacks": + return referenced.callbacks[name] + case "pathItems": + return referenced.pathItems[name] + default: + return false + } }) if err != nil { - return fmt.Errorf("failed to track references: %w", err) + return fmt.Errorf("failed to expand reachable references: %w", err) + } + + after := countTracked(referenced) + if after == before { + break // fixed point reached } } // Remove unused components - removeUnusedComponentsFromDocument(doc, referencedComponents) + removeUnusedComponentsFromDocument(doc, referenced) + + // Remove unused top-level tags + removeUnusedTagsFromDocument(doc, referenced) return nil } @@ -151,6 +163,8 @@ type referencedComponentTracker struct { links map[string]bool callbacks map[string]bool pathItems map[string]bool + // tags used by operations (referenced by name) + tags map[string]bool } // trackSchemaReferences tracks references within JSON schemas @@ -478,3 +492,167 @@ func removeUnusedComponentsFromDocument(doc *OpenAPI, tracker *referencedCompone doc.Components = nil } } + +// removeUnusedTagsFromDocument prunes tags declared in the top-level tags array +// when they are not referenced by any operation's tags. 
+func removeUnusedTagsFromDocument(doc *OpenAPI, tracker *referencedComponentTracker) { + if doc == nil { + return + } + + // If no tags are declared, nothing to do + if len(doc.Tags) == 0 { + return + } + + // If there were no tags referenced, drop tags entirely + if tracker == nil || len(tracker.tags) == 0 { + doc.Tags = nil + return + } + + // Keep only tags with names referenced by operations + kept := make([]*Tag, 0, len(doc.Tags)) + for _, tg := range doc.Tags { + if tg == nil { + continue + } + if tracker.tags[tg.GetName()] { + kept = append(kept, tg) + } + } + + if len(kept) > 0 { + doc.Tags = kept + } else { + doc.Tags = nil + } +} + +// walkAndTrackWithFilter walks the OpenAPI document and tracks referenced components, +// but only for WalkItems whose JSON pointer location satisfies the allow predicate. +func walkAndTrackWithFilter(ctx context.Context, doc *OpenAPI, tracker *referencedComponentTracker, allow func(ptr string) bool) error { + for item := range Walk(ctx, doc) { + loc := string(item.Location.ToJSONPointer()) + + if !allow(loc) { + // Skip tracking for this location + continue + } + + err := item.Match(Matcher{ + // Track schema references only when allowed by location + Schema: func(schema *oas3.JSONSchema[oas3.Referenceable]) error { + return trackSchemaReferences(schema, tracker) + }, + // Track component references only when allowed by location + ReferencedPathItem: func(ref *ReferencedPathItem) error { + return trackPathItemReference(ref, tracker.pathItems) + }, + ReferencedParameter: func(ref *ReferencedParameter) error { + return trackParameterReference(ref, tracker.parameters) + }, + ReferencedExample: func(ref *ReferencedExample) error { + return trackExampleReference(ref, tracker.examples) + }, + ReferencedRequestBody: func(ref *ReferencedRequestBody) error { + return trackRequestBodyReference(ref, tracker.requestBodies) + }, + ReferencedResponse: func(ref *ReferencedResponse) error { + return trackResponseReference(ref, tracker.responses) + }, + ReferencedHeader: func(ref *ReferencedHeader) error { + return trackHeaderReference(ref, tracker.headers) + }, + ReferencedCallback: func(ref *ReferencedCallback) error { + return trackCallbackReference(ref, tracker.callbacks) + }, + ReferencedLink: func(ref *ReferencedLink) error { + return trackLinkReference(ref, tracker.links) + }, + ReferencedSecurityScheme: func(ref *ReferencedSecurityScheme) error { + return trackSecuritySchemeReference(ref, tracker.securitySchemes) + }, + // Track operation tags (only under allowed locations) + Operation: func(op *Operation) error { + if op == nil { + return nil + } + for _, tag := range op.GetTags() { + if tracker.tags == nil { + tracker.tags = make(map[string]bool) + } + tracker.tags[tag] = true + } + return nil + }, + // Track security requirements (special case for security schemes) + Security: func(req *SecurityRequirement) error { + if req != nil { + for schemeName := range req.All() { + tracker.securitySchemes[schemeName] = true + } + } + return nil + }, + }) + if err != nil { + return fmt.Errorf("failed to track references: %w", err) + } + } + + return nil +} + +// extractComponentTypeAndName returns component type and name from a JSON pointer location like: +// +// /components/schemas/User/... 
-> ("schemas", "User", true) +// +// Returns ok=false if the pointer does not point under /components/{type}/{name} +func extractComponentTypeAndName(ptr string) (typ, name string, ok bool) { + const prefix = "/components/" + if !strings.HasPrefix(ptr, prefix) { + return "", "", false + } + + parts := strings.Split(ptr, "/") + // Expect at least: "", "components", "{type}", "{name}", ... + if len(parts) < 4 { + return "", "", false + } + + typ = parts[2] + name = unescapeJSONPointerToken(parts[3]) + if typ == "" || name == "" { + return "", "", false + } + return typ, name, true +} + +// unescapeJSONPointerToken reverses JSON Pointer escaping (~1 => /, ~0 => ~) +func unescapeJSONPointerToken(s string) string { + // Per RFC 6901: ~1 is '/', ~0 is '~'. Order matters: replace ~1 first, then ~0. + s = strings.ReplaceAll(s, "~1", "/") + s = strings.ReplaceAll(s, "~0", "~") + return s +} + +// countTracked returns the total number of referenced components recorded in the tracker. +// This is used to detect a fixed point during reachability expansion. +func countTracked(tr *referencedComponentTracker) int { + if tr == nil { + return 0 + } + total := 0 + total += len(tr.schemas) + total += len(tr.responses) + total += len(tr.parameters) + total += len(tr.examples) + total += len(tr.requestBodies) + total += len(tr.headers) + total += len(tr.securitySchemes) + total += len(tr.links) + total += len(tr.callbacks) + total += len(tr.pathItems) + return total +} diff --git a/openapi/clean_paths_reachability_test.go b/openapi/clean_paths_reachability_test.go new file mode 100644 index 0000000..606e9e4 --- /dev/null +++ b/openapi/clean_paths_reachability_test.go @@ -0,0 +1,220 @@ +package openapi_test + +import ( + "bytes" + "strings" + "testing" + + "github.com/speakeasy-api/openapi/openapi" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Ensures Clean only preserves components reachable from /paths and security, +// and removes components that are only referenced from within components. 
+// +// Covers scenarios missed previously: +// - Keep schemas transitively reachable from operations (A -> B) +// - Remove self-referential and component-only cycles (Self, Cycle1, Cycle2) +// - Keep security schemes referenced by name via top-level security +// - Remove unused security schemes +func TestClean_ReachabilityFromPathsAndSecurity_Success(t *testing.T) { + t.Parallel() + ctx := t.Context() + + const yml = ` +openapi: 3.1.0 +info: + title: Test API + version: 1.0.0 +security: + - ApiKeyAuth: [] +paths: + /keep: + get: + responses: + "200": + description: ok + content: + application/json: + schema: + $ref: "#/components/schemas/A" +components: + schemas: + A: + type: object + properties: + b: + $ref: "#/components/schemas/B" + B: + type: string + Self: + $ref: "#/components/schemas/Self" + Cycle1: + $ref: "#/components/schemas/Cycle2" + Cycle2: + $ref: "#/components/schemas/Cycle1" + securitySchemes: + ApiKeyAuth: + type: apiKey + in: header + name: X-API-Key + UnusedScheme: + type: http + scheme: bearer +` + + // Unmarshal + doc, validationErrs, err := openapi.Unmarshal(ctx, strings.NewReader(yml)) + require.NoError(t, err, "unmarshal should succeed") + require.Empty(t, validationErrs, "input should be valid") + + // Clean + err = openapi.Clean(ctx, doc) + require.NoError(t, err, "clean should succeed") + + // Marshal and assert against expected YAML output + var buf bytes.Buffer + err = openapi.Marshal(ctx, doc, &buf) + require.NoError(t, err, "marshal should succeed") + actual := buf.String() + + const expected = `openapi: 3.1.0 +info: + title: Test API + version: 1.0.0 +security: + - ApiKeyAuth: [] +paths: + /keep: + get: + responses: + "200": + description: ok + content: + application/json: + schema: + $ref: "#/components/schemas/A" +components: + schemas: + A: + type: object + properties: + b: + $ref: "#/components/schemas/B" + B: + type: string + securitySchemes: + ApiKeyAuth: + type: apiKey + in: header + name: X-API-Key +` + + assert.Equal(t, expected, actual, "Clean should retain only reachable components (A, B) and ApiKeyAuth") +} + +// Ensures that when no paths (or top-level/operation security) reference components, +// purely self-referential or component-only cycles are all removed and the entire +// components section is dropped. 
+func TestClean_RemoveOnlySelfReferencedComponents_NoPaths_Success(t *testing.T) { + t.Parallel() + ctx := t.Context() + + const yml = ` +openapi: 3.1.0 +info: + title: Only Self-Referenced Components + version: 1.0.0 +paths: {} +components: + schemas: + Self: + $ref: "#/components/schemas/Self" + LoopA: + $ref: "#/components/schemas/LoopB" + LoopB: + $ref: "#/components/schemas/LoopA" + responses: + OnlyComponentResponse: + description: Component-only, not referenced from any path + content: + application/json: + schema: + $ref: "#/components/schemas/Self" + parameters: + OnlyComponentParameter: + name: "p" + in: query + schema: + type: string + requestBodies: + OnlyComponentRequestBody: + required: false + content: + application/json: + schema: + type: object + headers: + OnlyComponentHeader: + schema: + type: string + examples: + OnlyComponentExample: + value: + ok: true + links: + OnlyComponentLink: + description: Not referenced from paths + parameters: + id: "$response.body#/id" + callbacks: + OnlyComponentCallback: + "{$request.body#/cb}": + post: + requestBody: + content: + application/json: + schema: + type: object + responses: + "200": + description: ok + pathItems: + OnlyComponentPathItem: + get: + responses: + "200": + description: ok + securitySchemes: + UnusedApiKey: + type: apiKey + in: header + name: X-API-Key + UnusedBearer: + type: http + scheme: bearer +` + + doc, validationErrs, err := openapi.Unmarshal(ctx, strings.NewReader(yml)) + require.NoError(t, err, "unmarshal should succeed") + require.Empty(t, validationErrs, "input should be valid") + + err = openapi.Clean(ctx, doc) + require.NoError(t, err, "clean should succeed") + + // Marshal and assert against expected YAML output + var buf bytes.Buffer + err = openapi.Marshal(ctx, doc, &buf) + require.NoError(t, err, "marshal should succeed") + actual := buf.String() + + const expected = `openapi: 3.1.0 +info: + title: Only Self-Referenced Components + version: 1.0.0 +paths: {} +` + + assert.Equal(t, expected, actual, "All components should be removed when only self/component-only references exist") +} diff --git a/openapi/clean_tags_test.go b/openapi/clean_tags_test.go new file mode 100644 index 0000000..0b894d2 --- /dev/null +++ b/openapi/clean_tags_test.go @@ -0,0 +1,133 @@ +package openapi_test + +import ( + "bytes" + "strings" + "testing" + + "github.com/speakeasy-api/openapi/openapi" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// When top-level tags include tags unused by any operation, Clean should remove the unused ones +// and keep only those referenced by operations, preserving order. 
+func TestClean_RemoveUnusedTopLevelTags_KeepReferenced_PreserveOrder(t *testing.T) { + t.Parallel() + ctx := t.Context() + + const yml = ` +openapi: 3.1.0 +info: + title: Tags Test + version: 1.0.0 +tags: + - name: users + description: "Users related operations" + - name: admin + description: "Administrative operations" + - name: pets + description: "Pet operations" +paths: + /users: + get: + tags: ["users"] + responses: + "200": + description: ok + /admin: + post: + tags: ["admin"] + responses: + "201": + description: created +` + + doc, validationErrs, err := openapi.Unmarshal(ctx, strings.NewReader(yml)) + require.NoError(t, err, "unmarshal should succeed") + require.Empty(t, validationErrs) + + err = openapi.Clean(ctx, doc) + require.NoError(t, err, "clean should succeed") + + var buf bytes.Buffer + err = openapi.Marshal(ctx, doc, &buf) + require.NoError(t, err, "marshal should succeed") + actual := buf.String() + + // Expect only users and admin tags remain (pets removed), preserve original order + const expected = `openapi: 3.1.0 +info: + title: Tags Test + version: 1.0.0 +tags: + - name: users + description: "Users related operations" + - name: admin + description: "Administrative operations" +paths: + /users: + get: + tags: ["users"] + responses: + "200": + description: ok + /admin: + post: + tags: ["admin"] + responses: + "201": + description: created +` + + assert.Equal(t, expected, actual, "unused top-level tags should be removed; referenced tags kept in original order") +} + +// When no operation references any tag, Clean should remove the entire top-level tags array. +func TestClean_RemoveAllTopLevelTags_WhenUnused(t *testing.T) { + t.Parallel() + ctx := t.Context() + + const yml = ` +openapi: 3.1.0 +info: + title: Tags Test + version: 1.0.0 +tags: + - name: users + - name: admin +paths: + /ping: + get: + responses: + "200": + description: pong +` + + doc, validationErrs, err := openapi.Unmarshal(ctx, strings.NewReader(yml)) + require.NoError(t, err, "unmarshal should succeed") + require.Empty(t, validationErrs) + + err = openapi.Clean(ctx, doc) + require.NoError(t, err, "clean should succeed") + + var buf bytes.Buffer + err = openapi.Marshal(ctx, doc, &buf) + require.NoError(t, err, "marshal should succeed") + actual := buf.String() + + // Expect the tags array to be removed completely + const expected = `openapi: 3.1.0 +info: + title: Tags Test + version: 1.0.0 +paths: + /ping: + get: + responses: + "200": + description: pong +` + + assert.Equal(t, expected, actual, "top-level tags should be removed entirely when unused") +} diff --git a/openapi/reference.go b/openapi/reference.go index ceee4c7..da75067 100644 --- a/openapi/reference.go +++ b/openapi/reference.go @@ -633,7 +633,7 @@ func joinReferenceChain(chain []string) string { return result } -func unmarshaler[T any, V interfaces.Validator[T], C marshaller.CoreModeler](o *OpenAPI) func(context.Context, *yaml.Node, bool) (*Reference[T, V, C], []error, error) { +func unmarshaler[T any, V interfaces.Validator[T], C marshaller.CoreModeler](_ *OpenAPI) func(context.Context, *yaml.Node, bool) (*Reference[T, V, C], []error, error) { return func(ctx context.Context, node *yaml.Node, skipValidation bool) (*Reference[T, V, C], []error, error) { var ref Reference[T, V, C] validationErrs, err := marshaller.UnmarshalNode(ctx, "reference", node, &ref) From b1f1b31b82b92e7f47fa0082ce0fbd0efe86cbbd Mon Sep 17 00:00:00 2001 From: Tristan Cartledge Date: Thu, 16 Oct 2025 16:25:32 +1000 Subject: [PATCH 2/5] chore(openapi): address PR 
feedback - extract helper functions and add circular reachability test --- openapi/clean.go | 43 ++++++++++------ openapi/clean_paths_reachability_test.go | 64 ++++++++++++++++++++++++ 2 files changed, 91 insertions(+), 16 deletions(-) diff --git a/openapi/clean.go b/openapi/clean.go index 7cadfa9..7554b7c 100644 --- a/openapi/clean.go +++ b/openapi/clean.go @@ -313,6 +313,31 @@ func trackSecuritySchemeReference(ref *ReferencedSecurityScheme, tracker map[str return nil } +// trackOperationTags tracks operation tag names into the tracker +func trackOperationTags(op *Operation, tracker *referencedComponentTracker) error { + if op == nil || tracker == nil { + return nil + } + for _, tag := range op.GetTags() { + if tracker.tags == nil { + tracker.tags = make(map[string]bool) + } + tracker.tags[tag] = true + } + return nil +} + +// trackSecurityRequirementNames tracks security scheme names referenced by a security requirement +func trackSecurityRequirementNames(req *SecurityRequirement, tracker map[string]bool) error { + if req == nil || tracker == nil { + return nil + } + for schemeName := range req.All() { + tracker[schemeName] = true + } + return nil +} + // extractComponentName extracts the component name from a reference string func extractComponentName(refStr, componentType string) string { prefix := "#/components/" + componentType + "/" @@ -575,25 +600,11 @@ func walkAndTrackWithFilter(ctx context.Context, doc *OpenAPI, tracker *referenc }, // Track operation tags (only under allowed locations) Operation: func(op *Operation) error { - if op == nil { - return nil - } - for _, tag := range op.GetTags() { - if tracker.tags == nil { - tracker.tags = make(map[string]bool) - } - tracker.tags[tag] = true - } - return nil + return trackOperationTags(op, tracker) }, // Track security requirements (special case for security schemes) Security: func(req *SecurityRequirement) error { - if req != nil { - for schemeName := range req.All() { - tracker.securitySchemes[schemeName] = true - } - } - return nil + return trackSecurityRequirementNames(req, tracker.securitySchemes) }, }) if err != nil { diff --git a/openapi/clean_paths_reachability_test.go b/openapi/clean_paths_reachability_test.go index 606e9e4..73abd25 100644 --- a/openapi/clean_paths_reachability_test.go +++ b/openapi/clean_paths_reachability_test.go @@ -218,3 +218,67 @@ paths: {} assert.Equal(t, expected, actual, "All components should be removed when only self/component-only references exist") } + +func TestClean_Reachability_CircularSchemas_Success(t *testing.T) { + t.Parallel() + ctx := t.Context() + + const yml = ` +openapi: 3.1.0 +info: + title: Test API + version: 1.0.0 +paths: + /keep: + get: + responses: + "200": + description: ok + content: + application/json: + schema: + $ref: "#/components/schemas/Cycle1" +components: + schemas: + Cycle1: + $ref: "#/components/schemas/Cycle2" + Cycle2: + $ref: "#/components/schemas/Cycle1" +` + + doc, validationErrs, err := openapi.Unmarshal(ctx, strings.NewReader(yml)) + require.NoError(t, err, "unmarshal should succeed") + require.Empty(t, validationErrs, "input should be valid") + + err = openapi.Clean(ctx, doc) + require.NoError(t, err, "clean should succeed") + + var buf bytes.Buffer + err = openapi.Marshal(ctx, doc, &buf) + require.NoError(t, err, "marshal should succeed") + actual := buf.String() + + const expected = `openapi: 3.1.0 +info: + title: Test API + version: 1.0.0 +paths: + /keep: + get: + responses: + "200": + description: ok + content: + application/json: + schema: + $ref: 
"#/components/schemas/Cycle1" +components: + schemas: + Cycle1: + $ref: "#/components/schemas/Cycle2" + Cycle2: + $ref: "#/components/schemas/Cycle1" +` + + assert.Equal(t, expected, actual, "Clean should keep circularly referenced components reachable from paths without hanging") +} From eb52ac534cb6ddaf2582b140a73fd7d22b4014dc Mon Sep 17 00:00:00 2001 From: Tristan Cartledge Date: Thu, 16 Oct 2025 17:38:32 +1000 Subject: [PATCH 3/5] feat: add Swagger 2.0 support with upgrade to OpenAPI 3.0 utility --- README.md | 16 +- arazzo/arazzo_test.go | 61 +- arazzo/core/criterion_syncchanges_test.go | 45 + arazzo/core/reusable_test.go | 57 + arazzo/criterion/criterion_validate_test.go | 476 ++++++ cmd/openapi/commands/swagger/README.md | 120 ++ cmd/openapi/commands/swagger/root.go | 256 +++ cmd/openapi/main.go | 14 + internal/testutils/utils.go | 65 + json/json.go | 502 +++++- json/json_test.go | 944 +++++++++-- jsonschema/oas3/core/factory_registration.go | 11 + marshaller/coremodel.go | 24 +- marshaller/unmarshaller.go | 19 +- mise-tasks/test-coverage | 5 +- openapi/operation.go | 1 + sequencedmap/map.go | 98 +- swagger/core/externaldocs.go | 15 + swagger/core/factory_registration.go | 76 + swagger/core/info.go | 38 + swagger/core/operation.go | 25 + swagger/core/parameter.go | 69 + swagger/core/paths.go | 90 ++ swagger/core/reference.go | 120 ++ swagger/core/response.go | 93 ++ swagger/core/security.go | 34 + swagger/core/swagger.go | 30 + swagger/core/tag.go | 16 + swagger/externaldocs.go | 70 + swagger/factory_registration.go | 83 + swagger/info.go | 251 +++ swagger/info_validate_test.go | 284 ++++ swagger/marshalling.go | 63 + swagger/operation.go | 223 +++ swagger/operation_validate_test.go | 116 ++ swagger/parameter.go | 400 +++++ swagger/parameter_test.go | 428 +++++ swagger/paths.go | 194 +++ swagger/reference.go | 142 ++ swagger/response.go | 259 +++ swagger/response_validate_test.go | 211 +++ swagger/roundtrip_test.go | 125 ++ swagger/security.go | 292 ++++ swagger/security_validate_test.go | 215 +++ swagger/swagger.go | 331 ++++ swagger/swagger_test.go | 173 ++ swagger/swagger_validate_test.go | 1155 ++++++++++++++ swagger/tag.go | 77 + swagger/tag_validate_test.go | 177 ++ swagger/testdata/test.swagger.json | 684 ++++++++ swagger/testdata/walk.swagger.json | 216 +++ swagger/upgrade.go | 867 ++++++++++ swagger/upgrade_test.go | 1506 ++++++++++++++++++ swagger/walk.go | 685 ++++++++ swagger/walk_matching.go | 142 ++ swagger/walk_test.go | 889 +++++++++++ yml/config.go | 88 +- yml/config_test.go | 134 +- 58 files changed, 13491 insertions(+), 309 deletions(-) create mode 100644 arazzo/core/criterion_syncchanges_test.go create mode 100644 arazzo/core/reusable_test.go create mode 100644 arazzo/criterion/criterion_validate_test.go create mode 100644 cmd/openapi/commands/swagger/README.md create mode 100644 cmd/openapi/commands/swagger/root.go create mode 100644 swagger/core/externaldocs.go create mode 100644 swagger/core/factory_registration.go create mode 100644 swagger/core/info.go create mode 100644 swagger/core/operation.go create mode 100644 swagger/core/parameter.go create mode 100644 swagger/core/paths.go create mode 100644 swagger/core/reference.go create mode 100644 swagger/core/response.go create mode 100644 swagger/core/security.go create mode 100644 swagger/core/swagger.go create mode 100644 swagger/core/tag.go create mode 100644 swagger/externaldocs.go create mode 100644 swagger/factory_registration.go create mode 100644 swagger/info.go create mode 100644 
swagger/info_validate_test.go create mode 100644 swagger/marshalling.go create mode 100644 swagger/operation.go create mode 100644 swagger/operation_validate_test.go create mode 100644 swagger/parameter.go create mode 100644 swagger/parameter_test.go create mode 100644 swagger/paths.go create mode 100644 swagger/reference.go create mode 100644 swagger/response.go create mode 100644 swagger/response_validate_test.go create mode 100644 swagger/roundtrip_test.go create mode 100644 swagger/security.go create mode 100644 swagger/security_validate_test.go create mode 100644 swagger/swagger.go create mode 100644 swagger/swagger_test.go create mode 100644 swagger/swagger_validate_test.go create mode 100644 swagger/tag.go create mode 100644 swagger/tag_validate_test.go create mode 100644 swagger/testdata/test.swagger.json create mode 100644 swagger/testdata/walk.swagger.json create mode 100644 swagger/upgrade.go create mode 100644 swagger/upgrade_test.go create mode 100644 swagger/walk.go create mode 100644 swagger/walk_matching.go create mode 100644 swagger/walk_test.go diff --git a/README.md b/README.md index 73a61ea..a1502b8 100644 --- a/README.md +++ b/README.md @@ -69,6 +69,10 @@ The `arazzo` package provides an API for working with Arazzo documents including The `openapi` package provides an API for working with OpenAPI documents including reading, creating, mutating, walking, validating and upgrading them. Supports both OpenAPI 3.0.x and 3.1.x specifications. +### [swagger](./swagger) + +The `swagger` package provides an API for working with Swagger 2.0 documents including reading, creating, mutating, walking, validating, and upgrading them to OpenAPI 3.0. + ### [overlay](./overlay) The `overlay` package provides an API for working with OpenAPI Overlays including applying overlays to specifications, comparing specifications to generate overlays, and validating overlay documents. @@ -93,7 +97,7 @@ go install github.com/speakeasy-api/openapi/cmd/openapi@latest ### Usage -The CLI provides three main command groups: +The CLI provides four main command groups: - **`openapi spec`** - Commands for working with OpenAPI specifications ([documentation](./cmd/openapi/commands/openapi/README.md)) - `bootstrap` - Create a new OpenAPI document with best practice examples @@ -109,6 +113,10 @@ The CLI provides three main command groups: - `upgrade` - Upgrade an OpenAPI specification to the latest supported version - `validate` - Validate an OpenAPI specification document +- **`openapi swagger`** - Commands for working with Swagger 2.0 documents ([documentation](./cmd/openapi/commands/swagger/README.md)) + - `validate` - Validate a Swagger 2.0 specification document + - `upgrade` - Upgrade a Swagger 2.0 specification to OpenAPI 3.0 + - **`openapi arazzo`** - Commands for working with Arazzo workflow documents ([documentation](./cmd/openapi/commands/arazzo/README.md)) - `validate` - Validate an Arazzo workflow document @@ -137,6 +145,12 @@ openapi overlay apply --overlay overlay.yaml --schema spec.yaml # Validate an Arazzo workflow document openapi arazzo validate ./workflow.arazzo.yaml + +# Validate a Swagger 2.0 document +openapi swagger validate ./api.swagger.yaml + +# Upgrade Swagger 2.0 to OpenAPI 3.0 +openapi swagger upgrade ./api.swagger.yaml ./openapi.yaml ``` For detailed usage instructions for each command group, see the individual documentation linked above. 
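Read together with the `openapi spec` commands updated earlier in this series, the new `openapi swagger` commands let a Swagger 2.0 document be upgraded, cleaned, and validated in sequence. A minimal sketch follows; the file names are illustrative, and each invocation mirrors the usage shown in the README changes above:

```bash
# Validate the Swagger 2.0 source, upgrade it to OpenAPI 3.0,
# then clean and validate the upgraded document
openapi swagger validate ./api.swagger.yaml
openapi swagger upgrade ./api.swagger.yaml ./openapi.yaml
openapi spec clean ./openapi.yaml ./openapi.clean.yaml
openapi spec validate ./openapi.clean.yaml
```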
diff --git a/arazzo/arazzo_test.go b/arazzo/arazzo_test.go index c199a23..7937b57 100644 --- a/arazzo/arazzo_test.go +++ b/arazzo/arazzo_test.go @@ -2,14 +2,10 @@ package arazzo_test import ( "bytes" - "crypto/sha256" - "encoding/hex" "errors" "fmt" "io" - "net/http" "os" - "path/filepath" "strings" "testing" @@ -18,6 +14,7 @@ import ( "github.com/speakeasy-api/openapi/arazzo/criterion" "github.com/speakeasy-api/openapi/expression" "github.com/speakeasy-api/openapi/extensions" + "github.com/speakeasy-api/openapi/internal/testutils" "github.com/speakeasy-api/openapi/jsonpointer" "github.com/speakeasy-api/openapi/jsonschema/oas3" jsonschema_core "github.com/speakeasy-api/openapi/jsonschema/oas3/core" @@ -726,7 +723,7 @@ func TestArazzo_StressTests_Validate(t *testing.T) { require.NoError(t, err) } else { var err error - r, err = downloadFile(tt.args.location) + r, err = testutils.DownloadFile(tt.args.location, "ARAZZO_CACHE_DIR", "speakeasy-api_arazzo") require.NoError(t, err) } defer r.Close() @@ -759,7 +756,7 @@ func TestArazzo_StressTests_RoundTrip(t *testing.T) { require.NoError(t, err) } else { var err error - r, err = downloadFile(tt.args.location) + r, err = testutils.DownloadFile(tt.args.location, "ARAZZO_CACHE_DIR", "speakeasy-api_arazzo") require.NoError(t, err) } defer r.Close() @@ -787,58 +784,6 @@ func TestArazzo_StressTests_RoundTrip(t *testing.T) { } } -func downloadFile(url string) (io.ReadCloser, error) { - // Use environment variable for cache directory, fallback to system temp dir - cacheDir := os.Getenv("ARAZZO_CACHE_DIR") - if cacheDir == "" { - cacheDir = os.TempDir() - } - tempDir := filepath.Join(cacheDir, "speakeasy-api_arazzo") - - if err := os.MkdirAll(tempDir, os.ModePerm); err != nil { - return nil, err - } - - // hash url to create a unique filename - hash := sha256.Sum256([]byte(url)) - filename := hex.EncodeToString(hash[:]) - - filepath := filepath.Join(tempDir, filename) - - // check if file exists and return it otherwise download it - r, err := os.Open(filepath) - if err == nil { - return r, nil - } - - resp, err := http.Get(url) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - // Read all data from response body - data, err := io.ReadAll(resp.Body) - if err != nil { - return nil, err - } - - // Write data to cache file - f, err := os.OpenFile(filepath, os.O_CREATE|os.O_WRONLY, 0o644) - if err != nil { - return nil, err - } - defer f.Close() - - _, err = f.Write(data) - if err != nil { - return nil, err - } - - // Return the data as a ReadCloser - return io.NopCloser(bytes.NewReader(data)), nil -} - func roundTripYamlOnly(data []byte) ([]byte, error) { var node yaml.Node diff --git a/arazzo/core/criterion_syncchanges_test.go b/arazzo/core/criterion_syncchanges_test.go new file mode 100644 index 0000000..7a78888 --- /dev/null +++ b/arazzo/core/criterion_syncchanges_test.go @@ -0,0 +1,45 @@ +package core + +import ( + "testing" + + "github.com/speakeasy-api/openapi/pointer" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gopkg.in/yaml.v3" +) + +func TestCriterionTypeUnion_SyncChanges_WithStringType_Success(t *testing.T) { + t.Parallel() + + yamlContent := "simple" + var node yaml.Node + err := yaml.Unmarshal([]byte(yamlContent), &node) + require.NoError(t, err, "unmarshal should succeed") + + var ctu CriterionTypeUnion + validationErrs, err := ctu.Unmarshal(t.Context(), "test", node.Content[0]) + require.NoError(t, err, "unmarshal should succeed") + require.Empty(t, validationErrs, "validation errors 
should be empty") + + model := CriterionTypeUnion{ + Type: pointer.From("simple"), + } + + resultNode, err := ctu.SyncChanges(t.Context(), model, node.Content[0]) + require.NoError(t, err, "SyncChanges should succeed") + assert.NotNil(t, resultNode, "result node should not be nil") +} + +func TestCriterionTypeUnion_SyncChanges_NonStruct_Error(t *testing.T) { + t.Parallel() + + var node yaml.Node + err := yaml.Unmarshal([]byte("simple"), &node) + require.NoError(t, err, "unmarshal should succeed") + + ctu := CriterionTypeUnion{} + _, err = ctu.SyncChanges(t.Context(), "not a struct", node.Content[0]) + require.Error(t, err, "SyncChanges should fail") + assert.Contains(t, err.Error(), "CriterionTypeUnion.SyncChanges expected a struct, got string", "error message should match") +} diff --git a/arazzo/core/reusable_test.go b/arazzo/core/reusable_test.go new file mode 100644 index 0000000..86d8c48 --- /dev/null +++ b/arazzo/core/reusable_test.go @@ -0,0 +1,57 @@ +package core + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gopkg.in/yaml.v3" +) + +func TestReusable_Unmarshal_WithReference_Success(t *testing.T) { + t.Parallel() + + yamlContent := `reference: '#/components/parameters/userId'` + + var node yaml.Node + err := yaml.Unmarshal([]byte(yamlContent), &node) + require.NoError(t, err, "unmarshal should succeed") + + var reusable Reusable[*Parameter] + validationErrs, err := reusable.Unmarshal(t.Context(), "test", node.Content[0]) + require.NoError(t, err, "unmarshal should succeed") + require.Empty(t, validationErrs, "validation errors should be empty") + assert.True(t, reusable.GetValid(), "reusable should be valid") + assert.True(t, reusable.Reference.Present, "reference should be present") + assert.NotNil(t, reusable.Reference.Value, "reference value should not be nil") +} + +func TestReusable_Unmarshal_NonMappingNode_Error(t *testing.T) { + t.Parallel() + + yamlContent := "- item1\n- item2" + + var node yaml.Node + err := yaml.Unmarshal([]byte(yamlContent), &node) + require.NoError(t, err, "unmarshal should succeed") + + var reusable Reusable[*Parameter] + validationErrs, err := reusable.Unmarshal(t.Context(), "test", node.Content[0]) + require.NoError(t, err, "unmarshal error should be nil") + require.NotEmpty(t, validationErrs, "validation errors should not be empty") + assert.Contains(t, validationErrs[0].Error(), "reusable expected object", "error message should match") + assert.False(t, reusable.GetValid(), "reusable should not be valid") +} + +func TestReusable_SyncChanges_NonStruct_Error(t *testing.T) { + t.Parallel() + + var node yaml.Node + err := yaml.Unmarshal([]byte(`reference: '#/test'`), &node) + require.NoError(t, err, "unmarshal should succeed") + + reusable := Reusable[*Parameter]{} + _, err = reusable.SyncChanges(t.Context(), "not a struct", node.Content[0]) + require.Error(t, err, "SyncChanges should fail") + assert.Contains(t, err.Error(), "Reusable.SyncChanges expected a struct, got string", "error message should match") +} diff --git a/arazzo/criterion/criterion_validate_test.go b/arazzo/criterion/criterion_validate_test.go new file mode 100644 index 0000000..a34e28f --- /dev/null +++ b/arazzo/criterion/criterion_validate_test.go @@ -0,0 +1,476 @@ +package criterion_test + +import ( + "testing" + + "github.com/speakeasy-api/openapi/arazzo/criterion" + "github.com/speakeasy-api/openapi/expression" + "github.com/speakeasy-api/openapi/pointer" + "github.com/stretchr/testify/assert" + 
"github.com/stretchr/testify/require" +) + +func TestCriterionExpressionType_Validate_Success(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cet *criterion.CriterionExpressionType + }{ + { + name: "valid jsonpath with correct version", + cet: &criterion.CriterionExpressionType{ + Type: criterion.CriterionTypeJsonPath, + Version: criterion.CriterionTypeVersionDraftGoessnerDispatchJsonPath00, + }, + }, + { + name: "valid xpath with version 3.0", + cet: &criterion.CriterionExpressionType{ + Type: criterion.CriterionTypeXPath, + Version: criterion.CriterionTypeVersionXPath30, + }, + }, + { + name: "valid xpath with version 2.0", + cet: &criterion.CriterionExpressionType{ + Type: criterion.CriterionTypeXPath, + Version: criterion.CriterionTypeVersionXPath20, + }, + }, + { + name: "valid xpath with version 1.0", + cet: &criterion.CriterionExpressionType{ + Type: criterion.CriterionTypeXPath, + Version: criterion.CriterionTypeVersionXPath10, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + errs := tt.cet.Validate() + assert.Empty(t, errs, "validation should succeed") + assert.True(t, tt.cet.Valid, "criterion expression type should be valid") + }) + } +} + +func TestCriterionExpressionType_Validate_Error(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cet *criterion.CriterionExpressionType + expectedError string + }{ + { + name: "invalid jsonpath version", + cet: &criterion.CriterionExpressionType{ + Type: criterion.CriterionTypeJsonPath, + Version: "invalid-version", + }, + expectedError: "version must be one of [draft-goessner-dispatch-jsonpath-00]", + }, + { + name: "invalid xpath version", + cet: &criterion.CriterionExpressionType{ + Type: criterion.CriterionTypeXPath, + Version: "invalid-version", + }, + expectedError: "version must be one of [xpath-30, xpath-20, xpath-10]", + }, + { + name: "invalid type", + cet: &criterion.CriterionExpressionType{ + Type: "invalid-type", + Version: criterion.CriterionTypeVersionNone, + }, + expectedError: "type must be one of [jsonpath, xpath]", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + errs := tt.cet.Validate() + require.NotEmpty(t, errs, "validation should fail") + assert.Contains(t, errs[0].Error(), tt.expectedError, "error message should match") + assert.False(t, tt.cet.Valid, "criterion expression type should not be valid") + }) + } +} + +func TestCriterionExpressionType_IsTypeProvided(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cet *criterion.CriterionExpressionType + expected bool + }{ + { + name: "nil criterion expression type", + cet: nil, + expected: false, + }, + { + name: "empty type", + cet: &criterion.CriterionExpressionType{ + Type: "", + }, + expected: false, + }, + { + name: "type provided", + cet: &criterion.CriterionExpressionType{ + Type: criterion.CriterionTypeJsonPath, + }, + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + result := tt.cet.IsTypeProvided() + assert.Equal(t, tt.expected, result, "IsTypeProvided should return expected value") + }) + } +} + +func TestCriterionTypeUnion_GetCore(t *testing.T) { + t.Parallel() + + ctu := &criterion.CriterionTypeUnion{ + Type: pointer.From(criterion.CriterionTypeSimple), + } + + core := ctu.GetCore() + assert.NotNil(t, core, "GetCore should return non-nil value") +} + +func TestCriterionTypeUnion_IsTypeProvided(t *testing.T) { + t.Parallel() + + 
tests := []struct { + name string + ctu *criterion.CriterionTypeUnion + expected bool + }{ + { + name: "nil criterion type union", + ctu: nil, + expected: false, + }, + { + name: "empty criterion type union", + ctu: &criterion.CriterionTypeUnion{}, + expected: false, + }, + { + name: "type provided as string", + ctu: &criterion.CriterionTypeUnion{ + Type: pointer.From(criterion.CriterionTypeSimple), + }, + expected: true, + }, + { + name: "type provided as expression type", + ctu: &criterion.CriterionTypeUnion{ + ExpressionType: &criterion.CriterionExpressionType{ + Type: criterion.CriterionTypeJsonPath, + }, + }, + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + result := tt.ctu.IsTypeProvided() + assert.Equal(t, tt.expected, result, "IsTypeProvided should return expected value") + }) + } +} + +func TestCriterionTypeUnion_GetType(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + ctu criterion.CriterionTypeUnion + expected criterion.CriterionType + }{ + { + name: "empty criterion type union returns simple", + ctu: criterion.CriterionTypeUnion{}, + expected: criterion.CriterionTypeSimple, + }, + { + name: "type provided as string", + ctu: criterion.CriterionTypeUnion{ + Type: pointer.From(criterion.CriterionTypeRegex), + }, + expected: criterion.CriterionTypeRegex, + }, + { + name: "type provided as expression type", + ctu: criterion.CriterionTypeUnion{ + ExpressionType: &criterion.CriterionExpressionType{ + Type: criterion.CriterionTypeJsonPath, + }, + }, + expected: criterion.CriterionTypeJsonPath, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + result := tt.ctu.GetType() + assert.Equal(t, tt.expected, result, "GetType should return expected value") + }) + } +} + +func TestCriterionTypeUnion_GetVersion(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + ctu criterion.CriterionTypeUnion + expected criterion.CriterionTypeVersion + }{ + { + name: "empty criterion type union returns none", + ctu: criterion.CriterionTypeUnion{}, + expected: criterion.CriterionTypeVersionNone, + }, + { + name: "type provided as string returns none", + ctu: criterion.CriterionTypeUnion{ + Type: pointer.From(criterion.CriterionTypeRegex), + }, + expected: criterion.CriterionTypeVersionNone, + }, + { + name: "type provided as expression type with version", + ctu: criterion.CriterionTypeUnion{ + ExpressionType: &criterion.CriterionExpressionType{ + Type: criterion.CriterionTypeJsonPath, + Version: criterion.CriterionTypeVersionDraftGoessnerDispatchJsonPath00, + }, + }, + expected: criterion.CriterionTypeVersionDraftGoessnerDispatchJsonPath00, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + result := tt.ctu.GetVersion() + assert.Equal(t, tt.expected, result, "GetVersion should return expected value") + }) + } +} + +func TestCriterion_GetCondition_Success(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + criterion *criterion.Criterion + expectedCondition *criterion.Condition + }{ + { + name: "valid simple condition", + criterion: &criterion.Criterion{ + Condition: "$statusCode == 200", + }, + expectedCondition: &criterion.Condition{ + Expression: expression.Expression("$statusCode"), + Operator: criterion.OperatorEQ, + Value: "200", + }, + }, + { + name: "raw value returns nil", + criterion: &criterion.Criterion{ + Condition: "some raw value", + }, + expectedCondition: nil, + }, + } + + for _, tt 
:= range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + err := tt.criterion.Sync(t.Context()) + require.NoError(t, err, "sync should succeed") + + cond, err := tt.criterion.GetCondition() + require.NoError(t, err, "GetCondition should succeed") + + if tt.expectedCondition == nil { + assert.Nil(t, cond, "condition should be nil") + } else { + require.NotNil(t, cond, "condition should not be nil") + assert.Equal(t, tt.expectedCondition.Expression, cond.Expression, "expression should match") + assert.Equal(t, tt.expectedCondition.Operator, cond.Operator, "operator should match") + assert.Equal(t, tt.expectedCondition.Value, cond.Value, "value should match") + } + }) + } +} + +func TestCriterion_Validate_WithTypes(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + criterion *criterion.Criterion + wantError bool + }{ + { + name: "valid simple type with context", + criterion: &criterion.Criterion{ + Context: pointer.From(expression.Expression("$response.body")), + Condition: "$statusCode == 200", + Type: criterion.CriterionTypeUnion{ + Type: pointer.From(criterion.CriterionTypeSimple), + }, + }, + wantError: false, + }, + { + name: "valid simple condition without explicit type", + criterion: &criterion.Criterion{ + Condition: "$statusCode == 200", + }, + wantError: false, + }, + { + name: "valid regex type", + criterion: &criterion.Criterion{ + Context: pointer.From(expression.Expression("$response.body")), + Condition: "^[a-z]+$", + Type: criterion.CriterionTypeUnion{ + Type: pointer.From(criterion.CriterionTypeRegex), + }, + }, + wantError: false, + }, + { + name: "invalid regex pattern", + criterion: &criterion.Criterion{ + Context: pointer.From(expression.Expression("$response.body")), + Condition: "[invalid", + Type: criterion.CriterionTypeUnion{ + Type: pointer.From(criterion.CriterionTypeRegex), + }, + }, + wantError: true, + }, + { + name: "valid jsonpath type", + criterion: &criterion.Criterion{ + Context: pointer.From(expression.Expression("$response.body")), + Condition: "$[?count(@.pets) > 0]", + Type: criterion.CriterionTypeUnion{ + Type: pointer.From(criterion.CriterionTypeJsonPath), + }, + }, + wantError: false, + }, + { + name: "invalid jsonpath expression", + criterion: &criterion.Criterion{ + Context: pointer.From(expression.Expression("$response.body")), + Condition: "$[invalid jsonpath", + Type: criterion.CriterionTypeUnion{ + Type: pointer.From(criterion.CriterionTypeJsonPath), + }, + }, + wantError: true, + }, + { + name: "xpath type validation skipped", + criterion: &criterion.Criterion{ + Context: pointer.From(expression.Expression("$response.body")), + Condition: "//book[@category='web']", + Type: criterion.CriterionTypeUnion{ + Type: pointer.From(criterion.CriterionTypeXPath), + }, + }, + wantError: false, + }, + { + name: "invalid type", + criterion: &criterion.Criterion{ + Condition: "$statusCode == 200", + Type: criterion.CriterionTypeUnion{ + Type: pointer.From(criterion.CriterionType("invalid")), + }, + }, + wantError: true, + }, + { + name: "missing context when type is set", + criterion: &criterion.Criterion{ + Condition: "$statusCode == 200", + Type: criterion.CriterionTypeUnion{ + Type: pointer.From(criterion.CriterionTypeJsonPath), + }, + }, + wantError: true, + }, + { + name: "invalid context expression", + criterion: &criterion.Criterion{ + Context: pointer.From(expression.Expression("invalid_expression")), + Condition: "$[?count(@.pets) > 0]", + Type: criterion.CriterionTypeUnion{ + Type: 
pointer.From(criterion.CriterionTypeJsonPath), + }, + }, + wantError: true, + }, + { + name: "missing condition", + criterion: &criterion.Criterion{ + Type: criterion.CriterionTypeUnion{ + Type: pointer.From(criterion.CriterionTypeSimple), + }, + }, + wantError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + err := tt.criterion.Sync(t.Context()) + require.NoError(t, err, "sync should succeed") + + errs := tt.criterion.Validate() + if tt.wantError { + assert.NotEmpty(t, errs, "validation should fail") + assert.False(t, tt.criterion.Valid, "criterion should not be valid") + } else { + assert.Empty(t, errs, "validation should succeed") + assert.True(t, tt.criterion.Valid, "criterion should be valid") + } + }) + } +} diff --git a/cmd/openapi/commands/swagger/README.md b/cmd/openapi/commands/swagger/README.md new file mode 100644 index 0000000..114ed09 --- /dev/null +++ b/cmd/openapi/commands/swagger/README.md @@ -0,0 +1,120 @@ +# Swagger Commands + +Commands for working with Swagger 2.0 (OpenAPI v2) specifications. + +Swagger 2.0 documents describe REST APIs prior to OpenAPI 3.x. These commands help you validate and upgrade Swagger documents. + +## Table of Contents + +- [Table of Contents](#table-of-contents) +- [Available Commands](#available-commands) + - [`validate`](#validate) + - [`upgrade`](#upgrade) +- [What is Swagger 2.0?](#what-is-swagger-20) +- [Common Options](#common-options) +- [Output Formats](#output-formats) +- [Examples](#examples) + - [Validate a Swagger document](#validate-a-swagger-document) + - [Upgrade Swagger to OpenAPI 3.0](#upgrade-swagger-to-openapi-30) + - [In-place upgrade](#in-place-upgrade) + - [Pipe-friendly usage](#pipe-friendly-usage) + +## Available Commands + +### `validate` + +Validate a Swagger 2.0 (OpenAPI v2) specification document for compliance. + +```bash +openapi swagger validate +``` + +This command checks for: + +- Structural validity according to the Swagger 2.0 Specification +- Required fields and proper data types +- Reference resolution and consistency +- Schema validation rules + +Exits with a non-zero status code when validation fails. + +### `upgrade` + +Convert a Swagger 2.0 document to OpenAPI 3.0 (3.0.0). + +```bash +openapi swagger upgrade [output-file] +``` + +The upgrade process includes: + +- Converting host/basePath/schemes to `servers` +- Transforming parameters, request bodies, and responses to OAS3 structures +- Mapping `definitions` to `components.schemas` +- Migrating `securityDefinitions` to `components.securitySchemes` +- Rewriting `$ref` targets from `#/definitions/...` to `#/components/schemas/...` + +Behavior: + +- If no `output-file` is provided, upgraded output is written to stdout (pipe-friendly) +- If `output-file` is provided, writes the upgraded document to that file +- If `--write`/`-w` is provided, upgrades in-place (overwrites the input file) + +## What is Swagger 2.0? + +Swagger 2.0 is an older version of the API description format now standardized as OpenAPI 3.x. This CLI supports validating Swagger 2.0 specs and upgrading them to OpenAPI 3.0 for compatibility with modern tooling and features. 
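+
+For library users, the same conversion the `upgrade` command performs is available programmatically. The sketch below is a minimal example, not a drop-in recipe; it assumes the `swagger.Unmarshal`, `swagger.Upgrade`, and `marshaller.Marshal` APIs that the CLI command in this repository is built on, and uses the same illustrative file path as the CLI examples above:
+
+```go
+package main
+
+import (
+	"context"
+	"os"
+
+	"github.com/speakeasy-api/openapi/marshaller"
+	sw "github.com/speakeasy-api/openapi/swagger"
+)
+
+func main() {
+	ctx := context.Background()
+
+	// Open the Swagger 2.0 document to convert (path as in the CLI examples above).
+	f, err := os.Open("api.swagger.yaml")
+	if err != nil {
+		panic(err)
+	}
+	defer f.Close()
+
+	// Parse the document; validation errors are returned separately and do not block the upgrade.
+	doc, validationErrs, err := sw.Unmarshal(ctx, f)
+	if err != nil {
+		panic(err)
+	}
+	_ = validationErrs
+
+	// Convert Swagger 2.0 to OpenAPI 3.0 and write the result to stdout.
+	oas, err := sw.Upgrade(ctx, doc)
+	if err != nil {
+		panic(err)
+	}
+	if err := marshaller.Marshal(ctx, oas, os.Stdout); err != nil {
+		panic(err)
+	}
+}
+```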
+ +## Common Options + +All commands support these common options: + +- `-h, --help`: Show help for the command +- `-v, --verbose`: Enable verbose output (global flag) + +Upgrade-specific options: + +- `-w, --write`: Write result in-place to input file (overwrites the input) + +## Output Formats + +- Input files may be YAML or JSON +- Output respects YAML/JSON based on the marshaller and target file extension (when writing to a file) +- Stdout output is designed to be pipe-friendly + +## Examples + +### Validate a Swagger document + +```bash +# Validate a JSON Swagger document +openapi swagger validate ./api.swagger.json + +# Validate a YAML Swagger document +openapi swagger validate ./api.swagger.yaml +``` + +### Upgrade Swagger to OpenAPI 3.0 + +```bash +# Upgrade and write to stdout +openapi swagger upgrade ./api.swagger.yaml + +# Upgrade and write to a specific file +openapi swagger upgrade ./api.swagger.yaml ./openapi.yaml +``` + +### In-place upgrade + +```bash +# Overwrite the input file with the upgraded OpenAPI 3.0 document +openapi swagger upgrade -w ./api.swagger.yaml +``` + +### Pipe-friendly usage + +```bash +# Upgrade and then validate with the OpenAPI validator +openapi swagger upgrade ./api.swagger.yaml | openapi spec validate - + +# Upgrade and bundle +openapi swagger upgrade ./api.swagger.yaml | openapi spec bundle - ./openapi-bundled.yaml diff --git a/cmd/openapi/commands/swagger/root.go b/cmd/openapi/commands/swagger/root.go new file mode 100644 index 0000000..bd7aec2 --- /dev/null +++ b/cmd/openapi/commands/swagger/root.go @@ -0,0 +1,256 @@ +package swagger + +import ( + "context" + "errors" + "fmt" + "os" + "path/filepath" + + "github.com/speakeasy-api/openapi/marshaller" + "github.com/speakeasy-api/openapi/openapi" + sw "github.com/speakeasy-api/openapi/swagger" + "github.com/spf13/cobra" +) + +var validateCmd = &cobra.Command{ + Use: "validate ", + Short: "Validate a Swagger 2.0 specification document", + Long: `Validate a Swagger 2.0 (OpenAPI v2) specification document for compliance. + +This command will parse and validate the provided Swagger document, checking for: +- Structural validity according to the Swagger 2.0 Specification +- Required fields and proper data types +- Reference resolution and consistency +- Schema validation rules`, + Args: cobra.ExactArgs(1), + Run: runValidate, +} + +var upgradeCmd = &cobra.Command{ + Use: "upgrade [output-file]", + Short: "Upgrade a Swagger 2.0 specification to OpenAPI 3.0", + Long: `Convert a Swagger 2.0 (OpenAPI v2) document to OpenAPI 3.0 (3.0.0). + +The upgrade process includes: +- Converting host/basePath/schemes to servers +- Transforming parameters, request bodies, and responses to OAS3 structures +- Mapping definitions to components.schemas +- Migrating securityDefinitions to components.securitySchemes +- Rewriting $ref targets to OAS3 component locations + +Output options: +- No output file specified: writes to stdout (pipe-friendly) +- Output file specified: writes to the specified file +- --write flag: writes in-place to the input file`, + Args: cobra.RangeArgs(1, 2), + Run: runUpgrade, +} + +var writeInPlace bool + +func init() { + upgradeCmd.Flags().BoolVarP(&writeInPlace, "write", "w", false, "write result in-place to input file") +} + +// Apply registers the swagger command group on the provided parent command. 
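+//
+// A typical wiring, mirroring cmd/openapi/main.go in this change, creates a parent "swagger"
+// command, registers the subcommands on it via Apply, and adds the parent to the CLI root:
+//
+//	swaggerCmds := &cobra.Command{Use: "swagger"}
+//	swagger.Apply(swaggerCmds)
+//	rootCmd.AddCommand(swaggerCmds)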
+func Apply(rootCmd *cobra.Command) { + rootCmd.AddCommand(validateCmd) + rootCmd.AddCommand(upgradeCmd) +} + +func runValidate(cmd *cobra.Command, args []string) { + ctx := cmd.Context() + file := args[0] + + if err := validateSwagger(ctx, file); err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } +} + +func validateSwagger(ctx context.Context, file string) error { + cleanFile := filepath.Clean(file) + fmt.Printf("Validating Swagger document: %s\n", cleanFile) + + f, err := os.Open(cleanFile) + if err != nil { + return fmt.Errorf("failed to open file: %w", err) + } + defer f.Close() + + _, validationErrors, err := sw.Unmarshal(ctx, f) + if err != nil { + return fmt.Errorf("failed to unmarshal file: %w", err) + } + + if len(validationErrors) == 0 { + fmt.Printf("✅ Swagger document is valid - 0 errors\n") + return nil + } + + fmt.Printf("❌ Swagger document is invalid - %d errors:\n\n", len(validationErrors)) + + for i, validationErr := range validationErrors { + fmt.Printf("%d. %s\n", i+1, validationErr.Error()) + } + + return errors.New("swagger document validation failed") +} + +// SwaggerProcessor handles IO for converting Swagger -> OpenAPI +type SwaggerProcessor struct { + InputFile string + OutputFile string + WriteToStdout bool +} + +// NewSwaggerProcessor creates a new processor with the given input and output files +func NewSwaggerProcessor(inputFile, outputFile string, writeInPlace bool) (*SwaggerProcessor, error) { + var finalOutputFile string + + if writeInPlace { + if outputFile != "" { + return nil, errors.New("cannot specify output file when using --write flag") + } + finalOutputFile = inputFile + } else { + finalOutputFile = outputFile + } + + return &SwaggerProcessor{ + InputFile: inputFile, + OutputFile: finalOutputFile, + WriteToStdout: finalOutputFile == "", + }, nil +} + +// LoadDocument loads and parses a Swagger 2.0 document from the input file +func (p *SwaggerProcessor) LoadDocument(ctx context.Context) (*sw.Swagger, []error, error) { + cleanInputFile := filepath.Clean(p.InputFile) + + // Only print status messages if not writing to stdout (keep stdout clean for piping) + if !p.WriteToStdout { + fmt.Printf("Processing Swagger document: %s\n", cleanInputFile) + } + + f, err := os.Open(cleanInputFile) + if err != nil { + return nil, nil, fmt.Errorf("failed to open input file: %w", err) + } + defer f.Close() + + doc, validationErrors, err := sw.Unmarshal(ctx, f) + if err != nil { + return nil, nil, fmt.Errorf("failed to unmarshal Swagger document: %w", err) + } + if doc == nil { + return nil, nil, errors.New("failed to parse Swagger document: document is nil") + } + + return doc, validationErrors, nil +} + +// ReportValidationErrors reports validation errors if not writing to stdout +func (p *SwaggerProcessor) ReportValidationErrors(validationErrors []error) { + if len(validationErrors) > 0 && !p.WriteToStdout { + fmt.Printf("⚠️ Found %d validation errors in original document:\n", len(validationErrors)) + for i, validationErr := range validationErrors { + fmt.Printf(" %d. 
%s\n", i+1, validationErr.Error()) + } + fmt.Println() + } +} + +// WriteOpenAPIDocument writes the converted OpenAPI document to the output destination +func (p *SwaggerProcessor) WriteOpenAPIDocument(ctx context.Context, doc *openapi.OpenAPI) error { + if p.WriteToStdout { + // Write to stdout (pipe-friendly) + return marshaller.Marshal(ctx, doc, os.Stdout) + } + + // Write to file + cleanOutputFile := filepath.Clean(p.OutputFile) + outFile, err := os.Create(cleanOutputFile) + if err != nil { + return fmt.Errorf("failed to create output file: %w", err) + } + defer outFile.Close() + + if err := marshaller.Marshal(ctx, doc, outFile); err != nil { + return fmt.Errorf("failed to write document: %w", err) + } + + fmt.Printf("📄 Document written to: %s\n", cleanOutputFile) + + return nil +} + +// PrintSuccess prints a success message if not writing to stdout +func (p *SwaggerProcessor) PrintSuccess(message string) { + if !p.WriteToStdout { + fmt.Printf("✅ %s\n", message) + } +} + +// PrintInfo prints an info message if not writing to stdout +func (p *SwaggerProcessor) PrintInfo(message string) { + if !p.WriteToStdout { + fmt.Printf("📋 %s\n", message) + } +} + +// PrintWarning prints a warning message if not writing to stdout +func (p *SwaggerProcessor) PrintWarning(message string) { + if !p.WriteToStdout { + fmt.Printf("⚠️ Warning: %s\n", message) + } +} + +func runUpgrade(cmd *cobra.Command, args []string) { + ctx := cmd.Context() + inputFile := args[0] + + var outputFile string + if len(args) > 1 { + outputFile = args[1] + } + + processor, err := NewSwaggerProcessor(inputFile, outputFile, writeInPlace) + if err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } + + if err := upgradeSwagger(ctx, processor); err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } +} + +func upgradeSwagger(ctx context.Context, processor *SwaggerProcessor) error { + // Load the Swagger document + swDoc, validationErrors, err := processor.LoadDocument(ctx) + if err != nil { + return err + } + if swDoc == nil { + return errors.New("failed to parse Swagger document: document is nil") + } + + // Report validation errors but continue with upgrade + processor.ReportValidationErrors(validationErrors) + + // Perform the upgrade (Swagger 2.0 -> OpenAPI 3.0) + oasDoc, err := sw.Upgrade(ctx, swDoc) + if err != nil { + return fmt.Errorf("failed to upgrade Swagger document: %w", err) + } + if oasDoc == nil { + return errors.New("upgrade returned a nil document") + } + + processor.PrintSuccess(fmt.Sprintf("Successfully upgraded to OpenAPI %s", oasDoc.OpenAPI)) + + return processor.WriteOpenAPIDocument(ctx, oasDoc) +} diff --git a/cmd/openapi/main.go b/cmd/openapi/main.go index 1541b1e..9410bf6 100644 --- a/cmd/openapi/main.go +++ b/cmd/openapi/main.go @@ -9,6 +9,7 @@ import ( arazzoCmd "github.com/speakeasy-api/openapi/cmd/openapi/commands/arazzo" openapiCmd "github.com/speakeasy-api/openapi/cmd/openapi/commands/openapi" overlayCmd "github.com/speakeasy-api/openapi/cmd/openapi/commands/overlay" + swaggerCmd "github.com/speakeasy-api/openapi/cmd/openapi/commands/swagger" "github.com/spf13/cobra" ) @@ -105,6 +106,15 @@ OpenAPI specifications define REST APIs in a standard format. These commands help you validate and work with OpenAPI documents.`, } +var swaggerCmds = &cobra.Command{ + Use: "swagger", + Short: "Work with Swagger 2.0 (OpenAPI v2) specifications", + Long: `Commands for working with Swagger 2.0 (OpenAPI v2) specifications. 
+ +Swagger 2.0 documents describe REST APIs prior to OpenAPI 3.x. +These commands help you validate and upgrade Swagger documents.`, +} + var arazzoCmds = &cobra.Command{ Use: "arazzo", Short: "Work with Arazzo workflow documents", @@ -138,6 +148,9 @@ func init() { // Add OpenAPI spec validation command openapiCmd.Apply(openapiCmds) + // Add Swagger 2.0 commands + swaggerCmd.Apply(swaggerCmds) + // Add Arazzo workflow validation command arazzoCmd.Apply(arazzoCmds) @@ -146,6 +159,7 @@ func init() { // Add all commands to root rootCmd.AddCommand(openapiCmds) + rootCmd.AddCommand(swaggerCmds) rootCmd.AddCommand(arazzoCmds) rootCmd.AddCommand(overlayCmds) diff --git a/internal/testutils/utils.go b/internal/testutils/utils.go index 81dfa76..154d51b 100644 --- a/internal/testutils/utils.go +++ b/internal/testutils/utils.go @@ -1,7 +1,14 @@ package testutils import ( + "bytes" + "crypto/sha256" + "encoding/hex" + "io" "iter" + "net/http" + "os" + "path/filepath" "reflect" "strconv" "testing" @@ -114,3 +121,61 @@ func AssertEqualSequencedMap(t *testing.T, expected, actual SequencedMap) { assert.EqualExportedValues(t, v, actualV) } } + +// DownloadFile downloads a file from a URL and caches it to avoid re-downloading. +// Uses the provided cacheEnvVar for cache location, fallback to system temp dir. +// The cacheDirName is used as the subdirectory name under the cache directory. +func DownloadFile(url, cacheEnvVar, cacheDirName string) (io.ReadCloser, error) { + // Use environment variable for cache directory, fallback to system temp dir + cacheDir := os.Getenv(cacheEnvVar) + if cacheDir == "" { + cacheDir = os.TempDir() + } + tempDir := filepath.Join(cacheDir, cacheDirName) + + if err := os.MkdirAll(tempDir, 0o750); err != nil { + return nil, err + } + + // hash url to create a unique filename + hash := sha256.Sum256([]byte(url)) + filename := hex.EncodeToString(hash[:]) + + filepath := filepath.Join(tempDir, filename) + + // check if file exists and return it otherwise download it + r, err := os.Open(filepath) // #nosec G304 -- filepath is controlled by caller in tests + if err == nil { + return r, nil + } + + resp, err := http.Get(url) // #nosec G107 -- url is controlled by caller in tests + if err != nil { + return nil, err + } + if resp == nil { + return nil, io.ErrUnexpectedEOF + } + defer resp.Body.Close() + + // Read all data from response body + data, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + // Write data to cache file + f, err := os.OpenFile(filepath, os.O_CREATE|os.O_WRONLY, 0o600) // #nosec G304 -- filepath is controlled by caller in tests + if err != nil { + return nil, err + } + defer f.Close() + + _, err = f.Write(data) + if err != nil { + return nil, err + } + + // Return the data as a ReadCloser + return io.NopCloser(bytes.NewReader(data)), nil +} diff --git a/json/json.go b/json/json.go index 7bc9aad..0f6c76e 100644 --- a/json/json.go +++ b/json/json.go @@ -1,141 +1,495 @@ -// Package json provides utilities for working with JSON. package json import ( + "bytes" "encoding/json" "fmt" "io" - "reflect" + "strconv" "strings" - "github.com/speakeasy-api/openapi/sequencedmap" "github.com/speakeasy-api/openapi/yml" "gopkg.in/yaml.v3" ) -// YAMLToJSON will convert the provided YAML node to JSON in a stable way not reordering keys. +// YAMLToJSON converts a YAML node to JSON using a custom JSON writer that preserves formatting. 
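+// For example, compact input such as {"a":1,"b":[1,2]} is re-emitted compactly rather than
+// re-indented (see the JSON round-trip tests in json_test.go).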
+// This approach is particularly useful when the YAML nodes already represent JSON input (since +// the yaml decoder can parse JSON), and we want to preserve the original JSON formatting. func YAMLToJSON(node *yaml.Node, indentation int, buffer io.Writer) error { - v, err := handleYAMLNode(node) - if err != nil { + return YAMLToJSONWithConfig(node, " ", indentation, true, buffer) +} + +// YAMLToJSONWithConfig converts YAML to JSON with full control over formatting. +// When the input nodes already have JSON-style formatting (from JSON input), this preserves it +// by analyzing Line and Column metadata to recreate the original formatting. +func YAMLToJSONWithConfig(node *yaml.Node, indent string, indentCount int, trailingNewline bool, buffer io.Writer) error { + if node == nil { + return nil + } + + // Build the indent string + indentStr := strings.Repeat(indent, indentCount) + + // Create a JSON writer context + ctx := &jsonWriteContext{ + indent: indentStr, + buffer: &bytes.Buffer{}, + currentCol: 0, + forceCompact: indentCount == 0, // Force compact mode when no indentation + } + + // Write the JSON + if err := writeJSONNode(ctx, node, 0); err != nil { return err } - e := json.NewEncoder(buffer) - e.SetIndent("", strings.Repeat(" ", indentation)) + // Get the output + output := ctx.buffer.Bytes() + + // Add or remove trailing newline as requested + if trailingNewline && (len(output) == 0 || output[len(output)-1] != '\n') { + output = append(output, '\n') + } else if !trailingNewline && len(output) > 0 && output[len(output)-1] == '\n' { + output = output[:len(output)-1] + } + + _, err := buffer.Write(output) + return err +} + +type jsonWriteContext struct { + indent string + buffer *bytes.Buffer + currentCol int + forceCompact bool // When true, always output compact format +} + +// isSingleLineFlowNode checks if a flow-style node is on a single line +func isSingleLineFlowNode(node *yaml.Node) bool { + if node.Style != yaml.FlowStyle { + return false + } + + if len(node.Content) == 0 { + return true + } + + // Check if all children are on the same line as the parent + nodeLine := node.Line + for _, child := range node.Content { + if child.Line != nodeLine { + return false + } + } - return e.Encode(v) + return true } -// YAMLToJSONCompatibleGoType will convert the provided YAML node to a compatible Go type ready for json marshalling in a stable way not reordering keys. -// Provided to allow a custom JSON marshalling implementation to be used. 
-func YAMLToJSONCompatibleGoType(node *yaml.Node) (any, error) { - v, err := handleYAMLNode(node) - if err != nil { - return nil, err +// hasSpaceAfterColon checks if there's a space after the colon in a mapping node +// Returns true unless we can definitively detect NO space (compact JSON) +func hasSpaceAfterColon(node *yaml.Node) bool { + if node.Kind != yaml.MappingNode || len(node.Content) < 2 { + return true // Default to having space } - return v, nil + key := node.Content[0] + value := node.Content[1] + + if key.Line != value.Line { + return true // Multi-line, doesn't matter + } + + // Based on inspection: YAML flow uses Style=Default, JSON uses Style=DoubleQuoted + // Calculate the width of the key in the source + var keyWidth int + if key.Style == yaml.DoubleQuotedStyle || key.Style == yaml.SingleQuotedStyle { + // Quoted keys (JSON): need to account for quotes + // Column points to opening quote, width includes both quotes + keyWidth = len(strconv.Quote(key.Value)) + } else { + // Unquoted keys (YAML flow-style) + keyWidth = len(key.Value) + } + + // Expected column for value with NO space after colon: + // key.Column + keyWidth + 1 (for colon) + expectedNoSpaceCol := key.Column + keyWidth + 1 + + // If value starts AFTER expectedNoSpaceCol, there's a space + // If value starts AT expectedNoSpaceCol, there's NO space + return value.Column > expectedNoSpaceCol } -func handleYAMLNode(node *yaml.Node) (any, error) { +// hasSpaceAfterComma checks if there's a space after commas in a sequence node +// Returns true unless we can definitively detect NO space (compact JSON) +func hasSpaceAfterComma(node *yaml.Node) bool { + if node.Kind != yaml.SequenceNode || len(node.Content) < 2 { + return true // Default to having space + } + + first := node.Content[0] + second := node.Content[1] + + if first.Line != second.Line { + return true // Multi-line, doesn't matter + } + + // Calculate width of first element in source + var firstWidth int + if first.Kind == yaml.ScalarNode { + if first.Style == yaml.DoubleQuotedStyle || first.Style == yaml.SingleQuotedStyle { + // Quoted strings (JSON): account for quotes + firstWidth = len(strconv.Quote(first.Value)) + } else { + // Unquoted values (YAML flow-style or numbers) + firstWidth = len(first.Value) + } + } else { + // For nested structures, default to having space + return true + } + + // Expected column for second element with NO space after comma: + // first.Column + firstWidth + 1 (for comma) + expectedNoSpaceCol := first.Column + firstWidth + 1 + + // If second starts AFTER expectedNoSpaceCol, there's a space + return second.Column > expectedNoSpaceCol +} + +func (ctx *jsonWriteContext) write(s string) { + ctx.buffer.WriteString(s) + // Track column (simplified - doesn't handle newlines in s) + ctx.currentCol += len(s) +} + +func (ctx *jsonWriteContext) writeByte(b byte) { + ctx.buffer.WriteByte(b) + if b == '\n' { + ctx.currentCol = 0 + } else { + ctx.currentCol++ + } +} + +func writeJSONNode(ctx *jsonWriteContext, node *yaml.Node, depth int) error { if node == nil { - return nil, nil + return nil } switch node.Kind { case yaml.DocumentNode: - return handleYAMLNode(node.Content[0]) - case yaml.SequenceNode: - return handleSequenceNode(node) + // Unwrap document node + if len(node.Content) > 0 { + return writeJSONNode(ctx, node.Content[0], depth) + } + return nil + case yaml.MappingNode: - return handleMappingNode(node) + return writeJSONObject(ctx, node, depth) + + case yaml.SequenceNode: + return writeJSONArray(ctx, node, depth) + case 
yaml.ScalarNode: - return handleScalarNode(node) + return writeJSONScalar(ctx, node) + case yaml.AliasNode: - return handleYAMLNode(node.Alias) + // Resolve alias using yml package helper + resolved := yml.ResolveAlias(node) + if resolved != nil { + return writeJSONNode(ctx, resolved, depth) + } + return nil + default: - return nil, fmt.Errorf("unknown node kind: %s", yml.NodeKindToString(node.Kind)) + return fmt.Errorf("unknown node kind: %v", node.Kind) } } -func handleMappingNode(node *yaml.Node) (any, error) { - v := sequencedmap.New[string, any]() - for i, n := range node.Content { - if i%2 == 0 { - continue +func writeJSONObject(ctx *jsonWriteContext, node *yaml.Node, depth int) error { + if len(node.Content) == 0 { + ctx.write("{}") + return nil + } + + // Resolve merge keys first + mergedContent := resolveMergeKeys(node.Content) + + // Check if this is a single-line flow node (preserve spacing) + isSingleLine := isSingleLineFlowNode(node) + preserveSpacing := isSingleLine && node.Style == yaml.FlowStyle + + // Determine if we should format as multi-line + isMultiLine := !preserveSpacing && + ((node.Style != yaml.FlowStyle && !ctx.forceCompact) || + shouldBeMultiLine(ctx, node, mergedContent)) + + ctx.writeByte('{') + + if isMultiLine { + ctx.writeByte('\n') + } + + firstItem := true + // Process key-value pairs (using merged content) + for i := 0; i < len(mergedContent); i += 2 { + if i+1 >= len(mergedContent) { + break + } + + keyNode := mergedContent[i] + valueNode := mergedContent[i+1] + + // Add comma before this item if not the first + if !firstItem { + ctx.writeByte(',') + if isMultiLine { + ctx.writeByte('\n') + } else if !ctx.forceCompact { + // For single-line flow nodes, preserve original spacing + if preserveSpacing && !hasSpaceAfterColon(node) { + // No space in original + } else { + // Add space (default or original had space) + ctx.writeByte(' ') + } + } } - keyNode := node.Content[i-1] - kv, err := handleYAMLNode(keyNode) - if err != nil { - return nil, err + firstItem = false + + // Add indentation for multi-line + if isMultiLine { + ctx.write(strings.Repeat(ctx.indent, depth+1)) } - if reflect.TypeOf(kv).Kind() != reflect.String { - keyData, err := json.Marshal(kv) - if err != nil { - return nil, err + // Write key - always as a quoted string (JSON requirement) + ctx.write(quoteJSONString(keyNode.Value)) + + // Add space after colon based on context + switch { + case ctx.forceCompact: + ctx.write(":") + case preserveSpacing: + // Preserve original spacing for single-line flow nodes + if hasSpaceAfterColon(node) { + ctx.write(": ") + } else { + ctx.write(":") } - kv = string(keyData) + default: + // Default: add space + ctx.write(": ") + } + + // Write value + if err := writeJSONNode(ctx, valueNode, depth+1); err != nil { + return err } + } - keyStr := fmt.Sprintf("%v", kv) + if isMultiLine { + ctx.writeByte('\n') + ctx.write(strings.Repeat(ctx.indent, depth)) + } + ctx.writeByte('}') + + return nil +} - // Handle YAML merge key (<<) - if keyStr == "<<" { - vv, err := handleYAMLNode(n) - if err != nil { - return nil, err +func writeJSONArray(ctx *jsonWriteContext, node *yaml.Node, depth int) error { + if len(node.Content) == 0 { + ctx.write("[]") + return nil + } + + // Check if this is a single-line flow node (preserve spacing) + isSingleLine := isSingleLineFlowNode(node) + preserveSpacing := isSingleLine && node.Style == yaml.FlowStyle + + // Determine if we should format as multi-line + isMultiLine := !preserveSpacing && + ((node.Style != yaml.FlowStyle && 
!ctx.forceCompact) || + shouldBeMultiLine(ctx, node, node.Content)) + + ctx.writeByte('[') + + if isMultiLine { + ctx.writeByte('\n') + } + + for i, child := range node.Content { + // Add indentation for multi-line + if isMultiLine { + ctx.write(strings.Repeat(ctx.indent, depth+1)) + } + + if err := writeJSONNode(ctx, child, depth+1); err != nil { + return err + } + + // Add comma if not the last item + if i+1 < len(node.Content) { + ctx.writeByte(',') + if isMultiLine { + ctx.writeByte('\n') + } else if !ctx.forceCompact { + // For single-line flow nodes, preserve original spacing + if preserveSpacing && !hasSpaceAfterComma(node) { + // No space in original + } else { + // Add space (default or original had space) + ctx.writeByte(' ') + } } + } else if isMultiLine { + ctx.writeByte('\n') + } + } - // Merge the values from the referenced map - if mergeMap, ok := vv.(*sequencedmap.Map[string, any]); ok { - for mergeKey, mergeValue := range mergeMap.All() { - // Only set if the key doesn't already exist (merge keys have lower priority) - if !v.Has(mergeKey) { - v.Set(mergeKey, mergeValue) + if isMultiLine { + ctx.write(strings.Repeat(ctx.indent, depth)) + } + ctx.writeByte(']') + + return nil +} + +func writeJSONScalar(ctx *jsonWriteContext, node *yaml.Node) error { + switch node.Tag { + case "!!str": + // JSON string - must be quoted and escaped + ctx.write(quoteJSONString(node.Value)) + return nil + + case "!!int", "!!float": + // Numbers - write as-is + ctx.write(node.Value) + return nil + + case "!!bool": + // Booleans + ctx.write(node.Value) + return nil + + case "!!null": + // Null + ctx.write("null") + return nil + + default: + // Default to quoted string + ctx.write(quoteJSONString(node.Value)) + return nil + } +} + +// resolveMergeKeys processes YAML merge keys (<<) and returns content with merged values +func resolveMergeKeys(content []*yaml.Node) []*yaml.Node { + if len(content) == 0 { + return content + } + + result := make([]*yaml.Node, 0, len(content)) + mergedKeys := make(map[string]bool) // Track which keys have been merged + + // First pass: collect all merge key content + var mergeContent []*yaml.Node + for i := 0; i < len(content); i += 2 { + if i+1 >= len(content) { + break + } + + keyNode := content[i] + valueNode := content[i+1] + + // Check for merge key + if keyNode.Value == "<<" { + // Resolve the alias to get the merged content + resolved := yml.ResolveAlias(valueNode) + if resolved != nil && resolved.Kind == yaml.MappingNode { + // Add all key-value pairs from the merged content + for j := 0; j < len(resolved.Content); j += 2 { + if j+1 < len(resolved.Content) { + mergeKey := resolved.Content[j] + mergeValue := resolved.Content[j+1] + if !mergedKeys[mergeKey.Value] { + mergeContent = append(mergeContent, mergeKey, mergeValue) + mergedKeys[mergeKey.Value] = true + } } } } - continue } + } + + // Second pass: add merged content first, then original content (original overrides merged) + result = append(result, mergeContent...) 
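+	// If a key appears both in a merge target and explicitly in this mapping, both pairs are kept:
+	// the merged pair is emitted first and the explicit pair is appended later by the loop below.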
+ + // Add non-merge keys + for i := 0; i < len(content); i += 2 { + if i+1 >= len(content) { + break + } + + keyNode := content[i] + valueNode := content[i+1] - vv, err := handleYAMLNode(n) - if err != nil { - return nil, err + // Skip merge keys themselves + if keyNode.Value == "<<" { + continue } - v.Set(keyStr, vv) + // Add this key-value pair (it will override any merged value with same key) + result = append(result, keyNode, valueNode) } - return v, nil + return result } -func handleSequenceNode(node *yaml.Node) (any, error) { - var s []yaml.Node +// shouldBeMultiLine determines if a node's children should be formatted on multiple lines +// by checking if the first child is on a different line than the parent, OR if children +// are on different lines from each other +func shouldBeMultiLine(ctx *jsonWriteContext, parent *yaml.Node, children []*yaml.Node) bool { + // Force compact if requested + if ctx.forceCompact { + return false + } - if err := node.Decode(&s); err != nil { - return nil, err + if len(children) == 0 { + return false } - v := make([]any, len(s)) - for i, n := range s { - vv, err := handleYAMLNode(&n) - if err != nil { - return nil, err - } + // Check if first child is on different line than parent + firstChild := children[0] + if firstChild.Line != parent.Line { + return true + } - v[i] = vv + // Also check if children are on different lines from each other + for _, node := range children[1:] { + if node.Line != firstChild.Line { + return true + } } - return v, nil + return false } -func handleScalarNode(node *yaml.Node) (any, error) { - var v any +// quoteJSONString properly quotes and escapes a string for JSON output +func quoteJSONString(s string) string { + // Use json.Encoder with SetEscapeHTML(false) for proper JSON escaping without HTML entity encoding + // This handles nul characters correctly (\u0000) while keeping & as & instead of \u0026 + var buf bytes.Buffer + enc := json.NewEncoder(&buf) + enc.SetEscapeHTML(false) - if err := node.Decode(&v); err != nil { - return nil, err + if err := enc.Encode(s); err != nil { + // Fallback to strconv.Quote if encoding fails (shouldn't happen for strings) + return strconv.Quote(s) } - return v, nil + // Encoder.Encode adds a newline, so we need to trim it + result := buf.String() + return strings.TrimSuffix(result, "\n") } diff --git a/json/json_test.go b/json/json_test.go index a274e03..ca6dbbf 100644 --- a/json/json_test.go +++ b/json/json_test.go @@ -5,7 +5,6 @@ import ( "testing" "github.com/speakeasy-api/openapi/json" - "github.com/speakeasy-api/openapi/sequencedmap" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "gopkg.in/yaml.v3" @@ -266,130 +265,250 @@ func TestYAMLToJSON_Error(t *testing.T) { } } -func TestYAMLToJSONCompatibleGoType_Success(t *testing.T) { +func TestYAMLToJSONWithConfig_Success(t *testing.T) { t.Parallel() tests := []struct { - name string - yamlInput string - wantAny any + name string + yamlInput string + indent string + indentCount int + expectedJSON string }{ { - name: "simple string", - yamlInput: `hello world`, - wantAny: "hello world", - }, - { - name: "simple number", - yamlInput: `42`, - wantAny: 42, + name: "2 spaces indentation", + yamlInput: `name: John +age: 30`, + indent: " ", + indentCount: 2, + expectedJSON: `{ + "name": "John", + "age": 30 +} +`, }, { - name: "simple boolean", - yamlInput: `true`, - wantAny: true, + name: "4 spaces indentation", + yamlInput: `name: John +age: 30`, + indent: " ", + indentCount: 4, + expectedJSON: `{ + "name": "John", 
+ "age": 30 +} +`, }, { - name: "null value", - yamlInput: `null`, - wantAny: nil, + name: "single tab indentation", + yamlInput: `name: John +age: 30`, + indent: "\t", + indentCount: 1, + expectedJSON: "{\n\t\"name\": \"John\",\n\t\"age\": 30\n}\n", }, { - name: "simple object", + name: "double tab indentation", yamlInput: `name: John age: 30`, - wantAny: sequencedmap.New( - sequencedmap.NewElem("name", any("John")), - sequencedmap.NewElem("age", any(30)), - ), + indent: "\t", + indentCount: 2, + expectedJSON: "{\n\t\t\"name\": \"John\",\n\t\t\"age\": 30\n}\n", }, { - name: "nested object", + name: "tabs with nested object", yamlInput: `person: name: John - age: 30`, - wantAny: sequencedmap.New( - sequencedmap.NewElem("person", any(sequencedmap.New( - sequencedmap.NewElem("name", any("John")), - sequencedmap.NewElem("age", any(30)), - ))), - ), + age: 30 + address: + city: New York`, + indent: "\t", + indentCount: 1, + expectedJSON: "{\n\t\"person\": {\n\t\t\"name\": \"John\",\n\t\t\"age\": 30,\n\t\t\"address\": {\n\t\t\t\"city\": \"New York\"\n\t\t}\n\t}\n}\n", }, { - name: "simple array", - yamlInput: `- apple -- banana -- cherry`, - wantAny: []any{"apple", "banana", "cherry"}, + name: "tabs with array", + yamlInput: `items: + - apple + - banana + - cherry`, + indent: "\t", + indentCount: 1, + expectedJSON: "{\n\t\"items\": [\n\t\t\"apple\",\n\t\t\"banana\",\n\t\t\"cherry\"\n\t]\n}\n", }, { - name: "array of objects", + name: "zero indentation (compact)", + yamlInput: `name: John +age: 30`, + indent: " ", + indentCount: 0, + expectedJSON: `{"name":"John","age":30} +`, + }, + { + name: "3 spaces indentation", + yamlInput: `name: John +age: 30`, + indent: " ", + indentCount: 3, + expectedJSON: `{ + "name": "John", + "age": 30 +} +`, + }, + { + name: "tabs with array of objects", yamlInput: `- name: John age: 30 - name: Jane age: 25`, - wantAny: []any{ - sequencedmap.New( - sequencedmap.NewElem("name", any("John")), - sequencedmap.NewElem("age", any(30)), - ), - sequencedmap.New( - sequencedmap.NewElem("name", any("Jane")), - sequencedmap.NewElem("age", any(25)), - ), - }, + indent: "\t", + indentCount: 1, + expectedJSON: "[\n\t{\n\t\t\"name\": \"John\",\n\t\t\"age\": 30\n\t},\n\t{\n\t\t\"name\": \"Jane\",\n\t\t\"age\": 25\n\t}\n]\n", }, { - name: "preserves key order", + name: "scalar string with tabs", + yamlInput: `hello world`, + indent: "\t", + indentCount: 1, + expectedJSON: "\"hello world\"\n", + }, + { + name: "tabs preserving key order", yamlInput: `zebra: last apple: first middle: second`, - wantAny: sequencedmap.New( - sequencedmap.NewElem("zebra", any("last")), - sequencedmap.NewElem("apple", any("first")), - sequencedmap.NewElem("middle", any("second")), - ), + indent: "\t", + indentCount: 1, + expectedJSON: "{\n\t\"zebra\": \"last\",\n\t\"apple\": \"first\",\n\t\"middle\": \"second\"\n}\n", }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var node yaml.Node + err := yaml.Unmarshal([]byte(tt.yamlInput), &node) + require.NoError(t, err, "failed to parse YAML input") + + var buffer bytes.Buffer + err = json.YAMLToJSONWithConfig(&node, tt.indent, tt.indentCount, true, &buffer) + require.NoError(t, err, "YAMLToJSONWithIndentation should not return error") + + actualJSON := buffer.String() + assert.Equal(t, tt.expectedJSON, actualJSON, "JSON output should match expected") + }) + } +} + +func TestYAMLToJSON_ArrayFormats(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + yamlInput string + expectedJSON string + 
indentation int + description string + }{ { - name: "numeric keys converted to strings", - yamlInput: `1: one -2: two -3: three`, - wantAny: sequencedmap.New( - sequencedmap.NewElem("1", any("one")), - sequencedmap.NewElem("2", any("two")), - sequencedmap.NewElem("3", any("three")), - ), + name: "flow style array - single line compact", + yamlInput: `items: [apple, banana, cherry]`, + expectedJSON: `{ + "items": ["apple", "banana", "cherry"] +} +`, + indentation: 2, + description: "Flow style arrays remain compact in JSON (preserves source formatting)", }, { - name: "empty object", - yamlInput: `{}`, - wantAny: sequencedmap.New[string, any](), + name: "block style array - already multi-line", + yamlInput: `items: + - apple + - banana + - cherry`, + expectedJSON: `{ + "items": [ + "apple", + "banana", + "cherry" + ] +} +`, + indentation: 2, + description: "Block style arrays remain multi-line in JSON", }, { - name: "empty array", - yamlInput: `[]`, - wantAny: []any{}, + name: "nested flow arrays", + yamlInput: `matrix: [[1, 2], [3, 4], [5, 6]]`, + expectedJSON: `{ + "matrix": [[1, 2], [3, 4], [5, 6]] +} +`, + indentation: 2, + description: "Nested flow arrays remain compact (preserves source formatting)", }, { - name: "yaml alias", - yamlInput: `defaults: &defaults - timeout: 30 - retries: 3 -production: - <<: *defaults - host: prod.example.com`, - wantAny: sequencedmap.New( - sequencedmap.NewElem("defaults", any(sequencedmap.New( - sequencedmap.NewElem("timeout", any(30)), - sequencedmap.NewElem("retries", any(3)), - ))), - sequencedmap.NewElem("production", any(sequencedmap.New( - sequencedmap.NewElem("timeout", any(30)), - sequencedmap.NewElem("retries", any(3)), - sequencedmap.NewElem("host", any("prod.example.com")), - ))), - ), + name: "mixed flow and block arrays", + yamlInput: `config: + inline: [1, 2, 3] + block: + - a + - b + - c`, + expectedJSON: `{ + "config": { + "inline": [1, 2, 3], + "block": [ + "a", + "b", + "c" + ] + } +} +`, + indentation: 2, + description: "Mixed flow and block style arrays - flow stays compact, block expands", + }, + { + name: "empty array flow style", + yamlInput: `empty: []`, + expectedJSON: `{ + "empty": [] +} +`, + indentation: 2, + description: "Empty flow style arrays - root expands, value stays compact", + }, + { + name: "single element flow array", + yamlInput: `single: [one]`, + expectedJSON: `{ + "single": ["one"] +} +`, + indentation: 2, + description: "Single element flow arrays remain compact", + }, + { + name: "array of objects in flow style", + yamlInput: `users: [{name: John, age: 30}, {name: Jane, age: 25}]`, + expectedJSON: `{ + "users": [{"name": "John", "age": 30}, {"name": "Jane", "age": 25}] +} +`, + indentation: 2, + description: "Flow style array of objects remains compact", + }, + { + name: "compact indentation with arrays", + yamlInput: `data: [1, 2, 3]`, + expectedJSON: `{"data":[1,2,3]} +`, + indentation: 0, + description: "Compact mode (indent=0) produces single-line arrays", }, } @@ -401,26 +520,135 @@ production: err := yaml.Unmarshal([]byte(tt.yamlInput), &node) require.NoError(t, err, "failed to parse YAML input") - actual, err := json.YAMLToJSONCompatibleGoType(&node) - require.NoError(t, err, "YAMLToJSONCompatibleGoType should not return error") + var buffer bytes.Buffer + err = json.YAMLToJSON(&node, tt.indentation, &buffer) + require.NoError(t, err, "YAMLToJSON should not return error") - assert.Equal(t, tt.wantAny, actual, "result should match expected value") + actualJSON := buffer.String() + assert.Equal(t, 
tt.expectedJSON, actualJSON, tt.description) }) } } -func TestYAMLToJSONCompatibleGoType_Error(t *testing.T) { +func TestYAMLToJSON_ObjectFormats(t *testing.T) { t.Parallel() tests := []struct { - name string - node *yaml.Node - wantError bool + name string + yamlInput string + expectedJSON string + indentation int + description string }{ { - name: "nil node", - node: nil, - wantError: false, // nil node returns nil, nil + name: "flow style object - single line compact", + yamlInput: `person: {name: John, age: 30, city: NYC}`, + expectedJSON: `{ + "person": {"name": "John", "age": 30, "city": "NYC"} +} +`, + indentation: 2, + description: "Flow style objects remain compact (preserves source formatting)", + }, + { + name: "block style object - already multi-line", + yamlInput: `person: + name: John + age: 30 + city: NYC`, + expectedJSON: `{ + "person": { + "name": "John", + "age": 30, + "city": "NYC" + } +} +`, + indentation: 2, + description: "Block style objects remain multi-line in JSON", + }, + { + name: "nested flow objects", + yamlInput: `data: {user: {name: John, email: john@example.com}, meta: {version: 1}}`, + expectedJSON: `{ + "data": {"user": {"name": "John", "email": "john@example.com"}, "meta": {"version": 1}} +} +`, + indentation: 2, + description: "Nested flow objects remain compact (preserves source formatting)", + }, + { + name: "mixed flow and block objects", + yamlInput: `config: + inline: {a: 1, b: 2} + block: + c: 3 + d: 4`, + expectedJSON: `{ + "config": { + "inline": {"a": 1, "b": 2}, + "block": { + "c": 3, + "d": 4 + } + } +} +`, + indentation: 2, + description: "Mixed flow and block style objects - flow stays compact, block expands", + }, + { + name: "empty object flow style", + yamlInput: `empty: {}`, + expectedJSON: `{ + "empty": {} +} +`, + indentation: 2, + description: "Empty flow style objects - root expands, value stays compact", + }, + { + name: "single property flow object", + yamlInput: `config: {key: value}`, + expectedJSON: `{ + "config": {"key": "value"} +} +`, + indentation: 2, + description: "Single property flow objects remain compact", + }, + { + name: "compact indentation with objects", + yamlInput: `data: {a: 1, b: 2}`, + expectedJSON: `{"data":{"a":1,"b":2}} +`, + indentation: 0, + description: "Compact mode (indent=0) produces single-line objects", + }, + { + name: "deeply nested mixed styles", + yamlInput: `root: + level1: {a: 1, b: [1, 2, 3]} + level2: + c: {x: 10, y: 20} + d: + - {id: 1} + - {id: 2}`, + expectedJSON: `{ + "root": { + "level1": {"a": 1, "b": [1, 2, 3]}, + "level2": { + "c": {"x": 10, "y": 20}, + "d": [ + {"id": 1}, + {"id": 2} + ] + } + } +} +`, + indentation: 2, + description: "Deeply nested mixed flow and block styles - flow stays compact", }, } @@ -428,16 +656,530 @@ func TestYAMLToJSONCompatibleGoType_Error(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - result, err := json.YAMLToJSONCompatibleGoType(tt.node) + var node yaml.Node + err := yaml.Unmarshal([]byte(tt.yamlInput), &node) + require.NoError(t, err, "failed to parse YAML input") + + var buffer bytes.Buffer + err = json.YAMLToJSON(&node, tt.indentation, &buffer) + require.NoError(t, err, "YAMLToJSON should not return error") + + actualJSON := buffer.String() + assert.Equal(t, tt.expectedJSON, actualJSON, tt.description) + }) + } +} + +func TestYAMLToJSON_ComplexMixedFormats(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + yamlInput string + expectedJSON string + indentation int + description string + }{ + { + name: 
"swagger-like structure with mixed formats", + yamlInput: `swagger: "2.0" +info: {title: API, version: 1.0.0} +paths: + /users: + get: + tags: [users] + responses: + "200": + description: Success`, + expectedJSON: `{ + "swagger": "2.0", + "info": {"title": "API", "version": "1.0.0"}, + "paths": { + "/users": { + "get": { + "tags": ["users"], + "responses": { + "200": { + "description": "Success" + } + } + } + } + } +} +`, + indentation: 2, + description: "Real-world API spec with mixed flow/block styles", + }, + { + name: "configuration file with inline arrays", + yamlInput: `server: + ports: [80, 443, 8080] + hosts: [localhost, example.com] + options: {timeout: 30, retries: 3}`, + expectedJSON: `{ + "server": { + "ports": [80, 443, 8080], + "hosts": ["localhost", "example.com"], + "options": {"timeout": 30, "retries": 3} + } +} +`, + indentation: 2, + description: "Config file with inline arrays and objects", + }, + { + name: "matrix data structure", + yamlInput: `matrix: + - [1, 0, 0] + - [0, 1, 0] + - [0, 0, 1]`, + expectedJSON: `{ + "matrix": [ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1] + ] +} +`, + indentation: 2, + description: "Matrix represented as array of flow arrays", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var node yaml.Node + err := yaml.Unmarshal([]byte(tt.yamlInput), &node) + require.NoError(t, err, "failed to parse YAML input") + + var buffer bytes.Buffer + err = json.YAMLToJSON(&node, tt.indentation, &buffer) + require.NoError(t, err, "YAMLToJSON should not return error") + + actualJSON := buffer.String() + assert.Equal(t, tt.expectedJSON, actualJSON, tt.description) + }) + } +} + +func TestYAMLToJSONWithConfig_Error(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + node *yaml.Node + indent string + indentCount int + wantError bool + }{ + { + name: "nil node", + node: nil, + indent: " ", + indentCount: 2, + wantError: false, // nil node is handled gracefully + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var buffer bytes.Buffer + err := json.YAMLToJSONWithConfig(tt.node, tt.indent, tt.indentCount, true, &buffer) if tt.wantError { assert.Error(t, err, "expected error for invalid input") } else { assert.NoError(t, err, "expected no error") - if tt.node == nil { - assert.Nil(t, result, "nil node should return nil result") - } } }) } } + +func TestJSONRoundTrip_PreservesFormatting(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + inputJSON string + indent string + indentCount int + description string + }{ + { + name: "compact single-line object", + inputJSON: `{"name":"John","age":30,"active":true}` + "\n", + indent: " ", + indentCount: 2, + description: "Compact JSON should stay compact", + }, + { + name: "compact single-line array", + inputJSON: `[1,2,3,4,5]` + "\n", + indent: " ", + indentCount: 2, + description: "Compact arrays stay compact", + }, + { + name: "compact nested structures", + inputJSON: `{"user":{"name":"John","address":{"city":"NYC","zip":10001}},"tags":["dev","admin"]}` + "\n", + indent: " ", + indentCount: 2, + description: "Deeply nested compact JSON stays compact", + }, + { + name: "pretty 2-space indentation", + inputJSON: `{ + "name": "John", + "age": 30, + "address": { + "city": "New York", + "zip": 10001 + }, + "tags": [ + "developer", + "admin" + ] +} +`, + indent: " ", + indentCount: 2, + description: "Pretty JSON with 2-space indent preserved", + }, + { + name: "pretty 4-space indentation", + inputJSON: `{ 
+ "name": "John", + "age": 30, + "nested": { + "level1": { + "level2": "value" + } + } +} +`, + indent: " ", + indentCount: 4, + description: "Pretty JSON with 4-space indent preserved", + }, + { + name: "tab indentation", + inputJSON: "{\n\t\"name\": \"John\",\n\t\"age\": 30,\n\t\"address\": {\n\t\t\"city\": \"NYC\"\n\t}\n}\n", + indent: "\t", + indentCount: 1, + description: "Tab-indented JSON preserved", + }, + { + name: "mixed formatting - compact and pretty", + inputJSON: `{ + "compact": {"a": 1, "b": 2}, + "pretty": { + "c": 3, + "d": 4 + }, + "array": [1, 2, 3], + "prettyArray": [ + "one", + "two" + ] +} +`, + indent: " ", + indentCount: 2, + description: "Mixed compact and pretty formatting preserved", + }, + { + name: "array of compact objects", + inputJSON: `[ + {"id": 1, "name": "Alice"}, + {"id": 2, "name": "Bob"}, + {"id": 3, "name": "Charlie"} +] +`, + indent: " ", + indentCount: 2, + description: "Array of compact objects in pretty array", + }, + { + name: "compact array of pretty objects", + inputJSON: `[{"name": "Alice", "age": 30}, {"name": "Bob", "age": 25}]` + "\n", + indent: " ", + indentCount: 2, + description: "Compact root with inline objects", + }, + { + name: "empty structures", + inputJSON: `{ + "emptyObject": {}, + "emptyArray": [], + "nestedEmpty": { + "obj": {}, + "arr": [] + } +} +`, + indent: " ", + indentCount: 2, + description: "Empty objects and arrays preserved", + }, + { + name: "all JSON data types", + inputJSON: `{ + "string": "hello", + "number": 42, + "float": 3.14, + "boolean": true, + "null": null, + "array": [1, 2, 3], + "object": {"nested": "value"} +} +`, + indent: " ", + indentCount: 2, + description: "All JSON data types with formatting", + }, + { + name: "deeply nested mixed formatting", + inputJSON: `{ + "level1": { + "compact": {"a": 1, "b": 2}, + "level2": { + "array": [1, 2, 3], + "prettyArray": [ + {"id": 1}, + {"id": 2} + ], + "level3": { + "compact": {"x": 10}, + "pretty": { + "y": 20 + } + } + } + } +} +`, + indent: " ", + indentCount: 2, + description: "Complex nested structure with mixed formatting", + }, + { + name: "single-line pretty root with nested structures", + inputJSON: `{"users": [{"name": "John", "age": 30}], "count": 1}` + "\n", + indent: " ", + indentCount: 2, + description: "Compact root with inline nested structures", + }, + { + name: "string escaping preserved", + inputJSON: `{ + "quote": "He said \"hello\"", + "newline": "line1\nline2", + "tab": "col1\tcol2", + "backslash": "path\\to\\file" +} +`, + indent: " ", + indentCount: 2, + description: "String escaping maintained through roundtrip", + }, + { + name: "numeric precision", + inputJSON: `{ + "int": 42, + "float": 3.14159, + "scientific": 1.23e-4, + "negative": -273.15, + "zero": 0 +} +`, + indent: " ", + indentCount: 2, + description: "Numeric values preserved exactly", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + // Parse JSON as YAML (yaml parser handles JSON) + var node yaml.Node + err := yaml.Unmarshal([]byte(tt.inputJSON), &node) + require.NoError(t, err, "failed to parse input JSON") + + // Convert back to JSON with specified indentation + var buffer bytes.Buffer + err = json.YAMLToJSONWithConfig(&node, tt.indent, tt.indentCount, true, &buffer) + require.NoError(t, err, "YAMLToJSONWithConfig should not return error") + + actualJSON := buffer.String() + + // The output should exactly match the input + assert.Equal(t, tt.inputJSON, actualJSON, tt.description) + }) + } +} + +func 
TestJSONRoundTrip_SwaggerLikeDocuments(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + inputJSON string + description string + }{ + { + name: "minimal swagger with mixed formatting", + inputJSON: `{ + "swagger": "2.0", + "info": {"title": "API", "version": "1.0.0"}, + "paths": { + "/users": { + "get": { + "responses": { + "200": {"description": "Success"} + } + } + } + } +} +`, + description: "Swagger-like doc with mixed compact/pretty", + }, + { + name: "fully compact swagger", + inputJSON: `{"swagger":"2.0","info":{"title":"API","version":"1.0.0"},"paths":{"/users":{"get":{"responses":{"200":{"description":"OK"}}}}}} +`, + description: "Fully compact Swagger stays compact", + }, + { + name: "pretty swagger with compact inline objects", + inputJSON: `{ + "swagger": "2.0", + "info": { + "title": "My API", + "version": "1.0.0" + }, + "paths": { + "/users": { + "get": { + "tags": ["users"], + "parameters": [ + {"name": "id", "in": "query", "type": "string"}, + {"name": "limit", "in": "query", "type": "integer"} + ], + "responses": { + "200": { + "description": "Success", + "schema": {"type": "array", "items": {"$ref": "#/definitions/User"}} + } + } + } + } + } +} +`, + description: "Realistic API spec with inline parameter objects", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + // Parse JSON + var node yaml.Node + err := yaml.Unmarshal([]byte(tt.inputJSON), &node) + require.NoError(t, err, "failed to parse input JSON") + + // Convert back to JSON + var buffer bytes.Buffer + err = json.YAMLToJSON(&node, 2, &buffer) + require.NoError(t, err, "YAMLToJSON should not return error") + + actualJSON := buffer.String() + + // Should match exactly + assert.Equal(t, tt.inputJSON, actualJSON, tt.description) + }) + } +} + +func TestJSONRoundTrip_DifferentIndentations(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + inputJSON string + indent string + indentCount int + description string + }{ + { + name: "2-space to 2-space", + inputJSON: `{ + "key": "value", + "nested": { + "inner": "data" + } +} +`, + indent: " ", + indentCount: 2, + description: "2-space indentation preserved", + }, + { + name: "4-space to 4-space", + inputJSON: `{ + "key": "value", + "nested": { + "inner": "data" + } +} +`, + indent: " ", + indentCount: 4, + description: "4-space indentation preserved", + }, + { + name: "tabs to tabs", + inputJSON: "{\n\t\"key\": \"value\",\n\t\"nested\": {\n\t\t\"inner\": \"data\"\n\t}\n}\n", + indent: "\t", + indentCount: 1, + description: "Tab indentation preserved", + }, + { + name: "compact to compact", + inputJSON: `{"key":"value","nested":{"inner":"data"}}` + "\n", + indent: " ", + indentCount: 0, + description: "Compact JSON preserved with indent=0", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + // Parse JSON + var node yaml.Node + err := yaml.Unmarshal([]byte(tt.inputJSON), &node) + require.NoError(t, err, "failed to parse input JSON") + + // Convert with specified indentation + var buffer bytes.Buffer + err = json.YAMLToJSONWithConfig(&node, tt.indent, tt.indentCount, true, &buffer) + require.NoError(t, err, "conversion should succeed") + + actualJSON := buffer.String() + + // Should match exactly + assert.Equal(t, tt.inputJSON, actualJSON, tt.description) + }) + } +} diff --git a/jsonschema/oas3/core/factory_registration.go b/jsonschema/oas3/core/factory_registration.go index 96c9c31..6f710f6 100644 --- 
a/jsonschema/oas3/core/factory_registration.go +++ b/jsonschema/oas3/core/factory_registration.go @@ -2,6 +2,7 @@ package core import ( "github.com/speakeasy-api/openapi/marshaller" + "github.com/speakeasy-api/openapi/sequencedmap" "github.com/speakeasy-api/openapi/values/core" ) @@ -18,4 +19,14 @@ func init() { marshaller.RegisterType(func() *core.EitherValue[[]marshaller.Node[string], string] { return &core.EitherValue[[]marshaller.Node[string], string]{} }) + + // Register Node-wrapped EitherValue for additionalProperties + marshaller.RegisterType(func() *marshaller.Node[*core.EitherValue[Schema, bool]] { + return &marshaller.Node[*core.EitherValue[Schema, bool]]{} + }) + + // Register sequencedmap for additionalProperties (used in properties field) + marshaller.RegisterType(func() *sequencedmap.Map[string, marshaller.Node[*core.EitherValue[Schema, bool]]] { + return &sequencedmap.Map[string, marshaller.Node[*core.EitherValue[Schema, bool]]]{} + }) } diff --git a/marshaller/coremodel.go b/marshaller/coremodel.go index 20db53a..75f80c9 100644 --- a/marshaller/coremodel.go +++ b/marshaller/coremodel.go @@ -17,6 +17,7 @@ import ( type CoreModeler interface { GetRootNode() *yaml.Node SetRootNode(rootNode *yaml.Node) + SetDocumentNode(documentNode *yaml.Node) GetValid() bool GetValidYaml() bool SetValid(valid, validYaml bool) @@ -30,6 +31,7 @@ type CoreModeler interface { type CoreModel struct { RootNode *yaml.Node // RootNode is the node that was unmarshaled into this model + DocumentNode *yaml.Node // DocumentNode is the top-level document node (only set for top-level models) - contains header comments Valid bool // Valid indicates whether the model passed validation, ie all its required fields were present and ValidYaml is true ValidYaml bool // ValidYaml indicates whether the model's underlying YAML representation is valid, for example a mapping node was received for a model Config *yml.Config // Generally only set on the top-level model that was unmarshaled @@ -51,6 +53,15 @@ func (c CoreModel) GetRootNodeLine() int { func (c *CoreModel) SetRootNode(rootNode *yaml.Node) { c.RootNode = rootNode + + // If we have a DocumentNode, update its content to point to the new RootNode + if c.DocumentNode != nil && c.DocumentNode.Kind == yaml.DocumentNode { + c.DocumentNode.Content = []*yaml.Node{rootNode} + } +} + +func (c *CoreModel) SetDocumentNode(documentNode *yaml.Node) { + c.DocumentNode = documentNode } func (c CoreModel) GetValid() bool { @@ -152,20 +163,27 @@ func (c *CoreModel) GetJSONPath(topLevelRootNode *yaml.Node) string { func (c *CoreModel) Marshal(ctx context.Context, w io.Writer) error { cfg := yml.GetConfigFromContext(ctx) + // Use DocumentNode if available for YAML output (to preserve comments) + // For JSON output, use RootNode since JSON doesn't support comments + nodeToMarshal := c.RootNode + if c.DocumentNode != nil && cfg.OutputFormat == yml.OutputFormatYAML { + nodeToMarshal = c.DocumentNode + } + switch cfg.OutputFormat { case yml.OutputFormatYAML: // Check if we need to reset node styles (original was JSON, now want YAML) if cfg.OriginalFormat == yml.OutputFormatJSON && cfg.OutputFormat == yml.OutputFormatYAML { - resetNodeStylesForYAML(c.RootNode, cfg) + resetNodeStylesForYAML(nodeToMarshal, cfg) } enc := yaml.NewEncoder(w) enc.SetIndent(cfg.Indentation) - if err := enc.Encode(c.RootNode); err != nil { + if err := enc.Encode(nodeToMarshal); err != nil { return err } case yml.OutputFormatJSON: - if err := json.YAMLToJSON(c.RootNode, cfg.Indentation, w); err != nil { + 
if err := json.YAMLToJSONWithConfig(nodeToMarshal, cfg.IndentationStyle.ToIndent(), cfg.Indentation, cfg.TrailingNewline, w); err != nil { return err } default: diff --git a/marshaller/unmarshaller.go b/marshaller/unmarshaller.go index 74cf0f5..2b6c958 100644 --- a/marshaller/unmarshaller.go +++ b/marshaller/unmarshaller.go @@ -81,12 +81,29 @@ func UnmarshalNode[T any](ctx context.Context, parentName string, node *yaml.Nod } func UnmarshalCore(ctx context.Context, parentName string, node *yaml.Node, out any) ([]error, error) { + // Store the DocumentNode if this is a top-level document and the output implements CoreModeler + var documentNode *yaml.Node if node.Kind == yaml.DocumentNode { if len(node.Content) != 1 { return nil, fmt.Errorf("expected 1 node, got %d at line %d, column %d", len(node.Content), node.Line, node.Column) } - return UnmarshalCore(ctx, parentName, node.Content[0], out) + // Save the document node for potential use by CoreModeler implementations + documentNode = node + node = node.Content[0] + } + + // Set DocumentNode on CoreModeler implementations after unwrapping + if documentNode != nil { + v := reflect.ValueOf(out) + if v.Kind() == reflect.Ptr && !v.IsNil() { + v = v.Elem() + } + if implementsInterface(v, coreModelerType) { + if coreModeler, ok := v.Addr().Interface().(CoreModeler); ok { + coreModeler.SetDocumentNode(documentNode) + } + } } v := reflect.ValueOf(out) diff --git a/mise-tasks/test-coverage b/mise-tasks/test-coverage index 064c44a..ed079fd 100755 --- a/mise-tasks/test-coverage +++ b/mise-tasks/test-coverage @@ -1,8 +1,11 @@ #!/usr/bin/env bash set -uo pipefail +# Use provided packages or default to all packages +PACKAGES="${@:-./...}" + echo "🧪 Running tests with coverage using gotestsum..." -if ! gotestsum --format testname -- -race -coverprofile=coverage.out -covermode=atomic ./...; then +if ! gotestsum --format testname -- -race -coverprofile=coverage.out -covermode=atomic ${PACKAGES}; then echo "❌ Tests failed!" exit 1 fi diff --git a/openapi/operation.go b/openapi/operation.go index ffba3a4..b3c19f9 100644 --- a/openapi/operation.go +++ b/openapi/operation.go @@ -169,6 +169,7 @@ func (o *Operation) Validate(ctx context.Context, opts ...validation.Option) []e errs = append(errs, securityRequirement.Validate(ctx, opts...)...) } + // TODO allow validation of parameters, this isn't done at the moment as we would need to resolve references for _, parameter := range o.Parameters { errs = append(errs, parameter.Validate(ctx, opts...)...) } diff --git a/sequencedmap/map.go b/sequencedmap/map.go index c5ea7ad..fdf00f4 100644 --- a/sequencedmap/map.go +++ b/sequencedmap/map.go @@ -588,38 +588,98 @@ func (m *Map[K, V]) NavigateWithKey(key string) (any, error) { return v, nil } -// MarshalJSON returns the JSON representation of the map. func (m *Map[K, V]) MarshalJSON() ([]byte, error) { if m == nil { return []byte("null"), nil } - // TODO there might be a more efficient way to serialize this but this is fine for now - var buf bytes.Buffer + var ( + buf bytes.Buffer + keyBuf bytes.Buffer + valBuf bytes.Buffer - buf.WriteString("{") + keyEnc = json.NewEncoder(&keyBuf) + valEnc = json.NewEncoder(&valBuf) + ) - for i, element := range m.l { - ks := fmt.Sprintf("%v", element.Key) - kb, err := json.Marshal(ks) - if err != nil { - return nil, err + // Preserve your “no HTML escaping” requirement. + keyEnc.SetEscapeHTML(false) + valEnc.SetEscapeHTML(false) + + // Heuristic growth to reduce reallocations for large maps. 
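+	// (Illustrative arithmetic: with the 32-byte-per-entry guess below, a 100-entry map pre-grows the buffer by ~3.2 KB.)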
+ // We don't know value sizes; this just cuts down growth churn. + // Tweak the multiplier based on your typical key/value sizes. + if n := len(m.l); n > 0 { + // ~= `{` + `}` + commas/colons + rough key/value payloads + buf.Grow(n * 32) + } + + buf.WriteByte('{') + + first := true + for _, element := range m.l { + // ---- comma ---- + if first { + first = false + } else { + buf.WriteByte(',') } - buf.Write(kb) - buf.WriteString(":") - vb, err := json.Marshal(element.Value) - if err != nil { - return nil, err + + // ---- key ---- + var ks string + switch k := any(element.Key).(type) { + case string: + ks = k + case fmt.Stringer: + ks = k.String() + default: + ks = fmt.Sprint(element.Key) } - buf.Write(vb) - if i < len(m.l)-1 { - buf.WriteString(",") + keyBuf.Reset() + if err := keyEnc.Encode(ks); err != nil { + return nil, err + } + // Remove trailing newline from encoded key + keyBytes := keyBuf.Bytes() + if len(keyBytes) > 0 && keyBytes[len(keyBytes)-1] == '\n' { + keyBytes = keyBytes[:len(keyBytes)-1] + } + buf.Write(keyBytes) + buf.WriteByte(':') + + // ---- value ---- + // Fast-path: json.RawMessage means "already-JSON" → write as-is. + switch v := any(element.Value).(type) { + case json.RawMessage: + // v is nil/empty → write "null" to match json.Marshal(nil) + if v == nil { + buf.WriteString("null") + } else { + buf.Write([]byte(v)) + } + case *json.RawMessage: + if v == nil || *v == nil { + buf.WriteString("null") + } else { + buf.Write([]byte(*v)) + } + default: + // Fallback: regular encode without HTML escaping. + valBuf.Reset() + if err := valEnc.Encode(element.Value); err != nil { + return nil, err + } + // Remove trailing newline from encoded value + valBytes := valBuf.Bytes() + if len(valBytes) > 0 && valBytes[len(valBytes)-1] == '\n' { + valBytes = valBytes[:len(valBytes)-1] + } + buf.Write(valBytes) } } - buf.WriteString("}") - + buf.WriteByte('}') return buf.Bytes(), nil } diff --git a/swagger/core/externaldocs.go b/swagger/core/externaldocs.go new file mode 100644 index 0000000..59c77b9 --- /dev/null +++ b/swagger/core/externaldocs.go @@ -0,0 +1,15 @@ +package core + +import ( + "github.com/speakeasy-api/openapi/extensions/core" + "github.com/speakeasy-api/openapi/marshaller" +) + +// ExternalDocumentation allows referencing an external resource for extended documentation. 
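+//
+// Illustrative YAML sketch (values are arbitrary examples, not defaults):
+//
+//	externalDocs:
+//	  description: Find more info here
+//	  url: https://example.com/docs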
+type ExternalDocumentation struct { + marshaller.CoreModel `model:"externalDocumentation"` + + Description marshaller.Node[*string] `key:"description"` + URL marshaller.Node[string] `key:"url"` + Extensions core.Extensions `key:"extensions"` +} diff --git a/swagger/core/factory_registration.go b/swagger/core/factory_registration.go new file mode 100644 index 0000000..93edb94 --- /dev/null +++ b/swagger/core/factory_registration.go @@ -0,0 +1,76 @@ +package core + +import ( + "github.com/speakeasy-api/openapi/marshaller" + "github.com/speakeasy-api/openapi/sequencedmap" +) + +// init registers all Swagger 2.0 core types with the marshaller factory system +func init() { + // Register main Swagger types + marshaller.RegisterType(func() *Swagger { return &Swagger{} }) + marshaller.RegisterType(func() *Info { return &Info{} }) + marshaller.RegisterType(func() *Contact { return &Contact{} }) + marshaller.RegisterType(func() *License { return &License{} }) + marshaller.RegisterType(func() *Paths { return &Paths{} }) + marshaller.RegisterType(func() *PathItem { return &PathItem{} }) + marshaller.RegisterType(func() *Operation { return &Operation{} }) + marshaller.RegisterType(func() *Parameter { return &Parameter{} }) + marshaller.RegisterType(func() *Items { return &Items{} }) + marshaller.RegisterType(func() *Responses { return &Responses{} }) + marshaller.RegisterType(func() *Response { return &Response{} }) + marshaller.RegisterType(func() *Header { return &Header{} }) + marshaller.RegisterType(func() *SecurityScheme { return &SecurityScheme{} }) + marshaller.RegisterType(func() *SecurityRequirement { return &SecurityRequirement{} }) + marshaller.RegisterType(func() *Tag { return &Tag{} }) + marshaller.RegisterType(func() *ExternalDocumentation { return &ExternalDocumentation{} }) + + // Register Reference types + marshaller.RegisterType(func() *Reference[*Parameter] { return &Reference[*Parameter]{} }) + marshaller.RegisterType(func() *Reference[*Response] { return &Reference[*Response]{} }) + marshaller.RegisterType(func() *Reference[*PathItem] { return &Reference[*PathItem]{} }) + marshaller.RegisterType(func() *Reference[*SecurityScheme] { return &Reference[*SecurityScheme]{} }) + + // Register Node-wrapped types + marshaller.RegisterType(func() *marshaller.Node[*PathItem] { return &marshaller.Node[*PathItem]{} }) + marshaller.RegisterType(func() *marshaller.Node[*Operation] { return &marshaller.Node[*Operation]{} }) + marshaller.RegisterType(func() *marshaller.Node[*Parameter] { return &marshaller.Node[*Parameter]{} }) + marshaller.RegisterType(func() *marshaller.Node[*Response] { return &marshaller.Node[*Response]{} }) + marshaller.RegisterType(func() *marshaller.Node[*Reference[*Parameter]] { return &marshaller.Node[*Reference[*Parameter]]{} }) + marshaller.RegisterType(func() *marshaller.Node[*Reference[*Response]] { return &marshaller.Node[*Reference[*Response]]{} }) + marshaller.RegisterType(func() *marshaller.Node[*SecurityScheme] { return &marshaller.Node[*SecurityScheme]{} }) + marshaller.RegisterType(func() *marshaller.Node[*Header] { return &marshaller.Node[*Header]{} }) + marshaller.RegisterType(func() *marshaller.Node[[]string] { return &marshaller.Node[[]string]{} }) + + // Register sequencedmap types used in swagger/core + marshaller.RegisterType(func() *sequencedmap.Map[string, marshaller.Node[*PathItem]] { + return &sequencedmap.Map[string, marshaller.Node[*PathItem]]{} + }) + marshaller.RegisterType(func() *sequencedmap.Map[string, marshaller.Node[*Operation]] { + return 
&sequencedmap.Map[string, marshaller.Node[*Operation]]{} + }) + marshaller.RegisterType(func() *sequencedmap.Map[string, marshaller.Node[*Parameter]] { + return &sequencedmap.Map[string, marshaller.Node[*Parameter]]{} + }) + marshaller.RegisterType(func() *sequencedmap.Map[string, marshaller.Node[*Response]] { + return &sequencedmap.Map[string, marshaller.Node[*Response]]{} + }) + marshaller.RegisterType(func() *sequencedmap.Map[string, marshaller.Node[*Reference[*Parameter]]] { + return &sequencedmap.Map[string, marshaller.Node[*Reference[*Parameter]]]{} + }) + marshaller.RegisterType(func() *sequencedmap.Map[string, marshaller.Node[*Reference[*Response]]] { + return &sequencedmap.Map[string, marshaller.Node[*Reference[*Response]]]{} + }) + marshaller.RegisterType(func() *sequencedmap.Map[string, marshaller.Node[*SecurityScheme]] { + return &sequencedmap.Map[string, marshaller.Node[*SecurityScheme]]{} + }) + marshaller.RegisterType(func() *sequencedmap.Map[string, marshaller.Node[*Header]] { + return &sequencedmap.Map[string, marshaller.Node[*Header]]{} + }) + marshaller.RegisterType(func() *sequencedmap.Map[string, marshaller.Node[string]] { + return &sequencedmap.Map[string, marshaller.Node[string]]{} + }) + marshaller.RegisterType(func() *sequencedmap.Map[string, marshaller.Node[[]string]] { + return &sequencedmap.Map[string, marshaller.Node[[]string]]{} + }) +} diff --git a/swagger/core/info.go b/swagger/core/info.go new file mode 100644 index 0000000..d7d08bd --- /dev/null +++ b/swagger/core/info.go @@ -0,0 +1,38 @@ +package core + +import ( + "github.com/speakeasy-api/openapi/extensions/core" + "github.com/speakeasy-api/openapi/marshaller" +) + +// Info provides metadata about the API. +type Info struct { + marshaller.CoreModel `model:"info"` + + Title marshaller.Node[string] `key:"title"` + Description marshaller.Node[*string] `key:"description"` + TermsOfService marshaller.Node[*string] `key:"termsOfService"` + Contact marshaller.Node[*Contact] `key:"contact"` + License marshaller.Node[*License] `key:"license"` + Version marshaller.Node[string] `key:"version"` + Extensions core.Extensions `key:"extensions"` +} + +// Contact information for the exposed API. +type Contact struct { + marshaller.CoreModel `model:"contact"` + + Name marshaller.Node[*string] `key:"name"` + URL marshaller.Node[*string] `key:"url"` + Email marshaller.Node[*string] `key:"email"` + Extensions core.Extensions `key:"extensions"` +} + +// License information for the exposed API. +type License struct { + marshaller.CoreModel `model:"license"` + + Name marshaller.Node[string] `key:"name"` + URL marshaller.Node[*string] `key:"url"` + Extensions core.Extensions `key:"extensions"` +} diff --git a/swagger/core/operation.go b/swagger/core/operation.go new file mode 100644 index 0000000..9e203a7 --- /dev/null +++ b/swagger/core/operation.go @@ -0,0 +1,25 @@ +package core + +import ( + "github.com/speakeasy-api/openapi/extensions/core" + "github.com/speakeasy-api/openapi/marshaller" +) + +// Operation describes a single API operation on a path. 
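+//
+// Illustrative YAML sketch of an operation (values are arbitrary examples):
+//
+//	get:
+//	  tags: [users]
+//	  summary: List users
+//	  operationId: listUsers
+//	  produces: [application/json]
+//	  responses:
+//	    "200":
+//	      description: Success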
+type Operation struct { + marshaller.CoreModel `model:"operation"` + + Tags marshaller.Node[[]string] `key:"tags"` + Summary marshaller.Node[*string] `key:"summary"` + Description marshaller.Node[*string] `key:"description"` + ExternalDocs marshaller.Node[*ExternalDocumentation] `key:"externalDocs"` + OperationID marshaller.Node[*string] `key:"operationId"` + Consumes marshaller.Node[[]string] `key:"consumes"` + Produces marshaller.Node[[]string] `key:"produces"` + Parameters marshaller.Node[[]marshaller.Node[*Reference[*Parameter]]] `key:"parameters"` + Responses marshaller.Node[Responses] `key:"responses"` + Schemes marshaller.Node[[]string] `key:"schemes"` + Deprecated marshaller.Node[*bool] `key:"deprecated"` + Security marshaller.Node[[]marshaller.Node[*SecurityRequirement]] `key:"security"` + Extensions core.Extensions `key:"extensions"` +} diff --git a/swagger/core/parameter.go b/swagger/core/parameter.go new file mode 100644 index 0000000..5bf05c4 --- /dev/null +++ b/swagger/core/parameter.go @@ -0,0 +1,69 @@ +package core + +import ( + "github.com/speakeasy-api/openapi/extensions/core" + oascore "github.com/speakeasy-api/openapi/jsonschema/oas3/core" + "github.com/speakeasy-api/openapi/marshaller" + values "github.com/speakeasy-api/openapi/values/core" +) + +// Parameter describes a single operation parameter. +type Parameter struct { + marshaller.CoreModel `model:"parameter"` + + // Common fields for all parameter types + Name marshaller.Node[string] `key:"name"` + In marshaller.Node[string] `key:"in"` + Description marshaller.Node[*string] `key:"description"` + Required marshaller.Node[*bool] `key:"required"` + + // For body parameters + Schema marshaller.Node[oascore.JSONSchema] `key:"schema"` + + // For non-body parameters + Type marshaller.Node[*string] `key:"type"` + Format marshaller.Node[*string] `key:"format"` + AllowEmptyValue marshaller.Node[*bool] `key:"allowEmptyValue"` + Items marshaller.Node[*Items] `key:"items"` + CollectionFormat marshaller.Node[*string] `key:"collectionFormat"` + Default marshaller.Node[values.Value] `key:"default"` + Maximum marshaller.Node[*float64] `key:"maximum"` + ExclusiveMaximum marshaller.Node[*bool] `key:"exclusiveMaximum"` + Minimum marshaller.Node[*float64] `key:"minimum"` + ExclusiveMinimum marshaller.Node[*bool] `key:"exclusiveMinimum"` + MaxLength marshaller.Node[*int64] `key:"maxLength"` + MinLength marshaller.Node[*int64] `key:"minLength"` + Pattern marshaller.Node[*string] `key:"pattern"` + MaxItems marshaller.Node[*int64] `key:"maxItems"` + MinItems marshaller.Node[*int64] `key:"minItems"` + UniqueItems marshaller.Node[*bool] `key:"uniqueItems"` + Enum marshaller.Node[[]marshaller.Node[values.Value]] `key:"enum"` + MultipleOf marshaller.Node[*float64] `key:"multipleOf"` + + Extensions core.Extensions `key:"extensions"` +} + +// Items is a limited subset of JSON-Schema's items object for array parameters. 
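+//
+// Illustrative YAML sketch of an array parameter using items (values are arbitrary):
+//
+//	type: array
+//	items:
+//	  type: string
+//	collectionFormat: csv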
+type Items struct { + marshaller.CoreModel `model:"items"` + + Type marshaller.Node[string] `key:"type"` + Format marshaller.Node[*string] `key:"format"` + Items marshaller.Node[*Items] `key:"items"` + CollectionFormat marshaller.Node[*string] `key:"collectionFormat"` + Default marshaller.Node[values.Value] `key:"default"` + Maximum marshaller.Node[*float64] `key:"maximum"` + ExclusiveMaximum marshaller.Node[*bool] `key:"exclusiveMaximum"` + Minimum marshaller.Node[*float64] `key:"minimum"` + ExclusiveMinimum marshaller.Node[*bool] `key:"exclusiveMinimum"` + MaxLength marshaller.Node[*int64] `key:"maxLength"` + MinLength marshaller.Node[*int64] `key:"minLength"` + Pattern marshaller.Node[*string] `key:"pattern"` + MaxItems marshaller.Node[*int64] `key:"maxItems"` + MinItems marshaller.Node[*int64] `key:"minItems"` + UniqueItems marshaller.Node[*bool] `key:"uniqueItems"` + Enum marshaller.Node[[]marshaller.Node[values.Value]] `key:"enum"` + MultipleOf marshaller.Node[*float64] `key:"multipleOf"` + + Extensions core.Extensions `key:"extensions"` +} diff --git a/swagger/core/paths.go b/swagger/core/paths.go new file mode 100644 index 0000000..c05967f --- /dev/null +++ b/swagger/core/paths.go @@ -0,0 +1,90 @@ +package core + +import ( + "github.com/speakeasy-api/openapi/extensions/core" + "github.com/speakeasy-api/openapi/marshaller" + "github.com/speakeasy-api/openapi/sequencedmap" + "gopkg.in/yaml.v3" +) + +// Paths holds the relative paths to the individual endpoints. +type Paths struct { + marshaller.CoreModel `model:"paths"` + *sequencedmap.Map[string, marshaller.Node[*PathItem]] + + Extensions core.Extensions `key:"extensions"` +} + +func NewPaths() *Paths { + return &Paths{ + Map: sequencedmap.New[string, marshaller.Node[*PathItem]](), + } +} + +func (p *Paths) GetMapKeyNodeOrRoot(key string, rootNode *yaml.Node) *yaml.Node { + if !p.IsInitialized() { + return rootNode + } + + if p.RootNode == nil { + return rootNode + } + + for i := 0; i < len(p.RootNode.Content); i += 2 { + if p.RootNode.Content[i].Value == key { + return p.RootNode.Content[i] + } + } + + return rootNode +} + +func (p *Paths) GetMapKeyNodeOrRootLine(key string, rootNode *yaml.Node) int { + node := p.GetMapKeyNodeOrRoot(key, rootNode) + if node == nil { + return -1 + } + return node.Line +} + +// PathItem describes the operations available on a single path. 
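+//
+// Illustrative YAML sketch of a path item (values are arbitrary examples):
+//
+//	/users/{id}:
+//	  parameters:
+//	    - name: id
+//	      in: path
+//	      required: true
+//	      type: string
+//	  get:
+//	    responses:
+//	      "200":
+//	        description: Success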
+type PathItem struct { + marshaller.CoreModel `model:"pathItem"` + *sequencedmap.Map[string, marshaller.Node[*Operation]] + + Ref marshaller.Node[*string] `key:"$ref"` + Parameters marshaller.Node[[]*Reference[*Parameter]] `key:"parameters"` + Extensions core.Extensions `key:"extensions"` +} + +func NewPathItem() *PathItem { + return &PathItem{ + Map: sequencedmap.New[string, marshaller.Node[*Operation]](), + } +} + +func (p *PathItem) GetMapKeyNodeOrRoot(key string, rootNode *yaml.Node) *yaml.Node { + if !p.IsInitialized() { + return rootNode + } + + if p.RootNode == nil { + return rootNode + } + + for i := 0; i < len(p.RootNode.Content); i += 2 { + if p.RootNode.Content[i].Value == key { + return p.RootNode.Content[i] + } + } + + return rootNode +} + +func (p *PathItem) GetMapKeyNodeOrRootLine(key string, rootNode *yaml.Node) int { + node := p.GetMapKeyNodeOrRoot(key, rootNode) + if node == nil { + return -1 + } + return node.Line +} diff --git a/swagger/core/reference.go b/swagger/core/reference.go new file mode 100644 index 0000000..d95f5fb --- /dev/null +++ b/swagger/core/reference.go @@ -0,0 +1,120 @@ +package core + +import ( + "context" + "errors" + "fmt" + "reflect" + + "github.com/speakeasy-api/openapi/internal/interfaces" + "github.com/speakeasy-api/openapi/marshaller" + "github.com/speakeasy-api/openapi/validation" + "github.com/speakeasy-api/openapi/yml" + "gopkg.in/yaml.v3" +) + +// Reference represents either a reference to a component or an inline object. +type Reference[T marshaller.CoreModeler] struct { + marshaller.CoreModel `model:"reference"` + + Reference marshaller.Node[*string] `key:"$ref"` + Object T `populatorValue:"true"` +} + +var _ interfaces.CoreModel = (*Reference[*Parameter])(nil) + +func (r *Reference[T]) Unmarshal(ctx context.Context, parentName string, node *yaml.Node) ([]error, error) { + resolvedNode := yml.ResolveAlias(node) + if resolvedNode == nil { + return nil, errors.New("node is nil") + } + + if resolvedNode.Kind != yaml.MappingNode { + r.SetValid(false, false) + return []error{validation.NewValidationError(validation.NewTypeMismatchError(parentName, "reference expected object, got %s", yml.NodeKindToString(resolvedNode.Kind)), resolvedNode)}, nil + } + + if _, _, ok := yml.GetMapElementNodes(ctx, resolvedNode, "$ref"); ok { + return marshaller.UnmarshalModel(ctx, node, r) + } + + var obj T + validationErrs, err := marshaller.UnmarshalCore(ctx, parentName, node, &obj) + if err != nil { + return nil, err + } + + r.Object = obj + r.SetValid(r.Object.GetValid(), r.Object.GetValidYaml() && len(validationErrs) == 0) + + return validationErrs, nil +} + +func (r *Reference[T]) SyncChanges(ctx context.Context, model any, valueNode *yaml.Node) (*yaml.Node, error) { + mv := reflect.ValueOf(model) + + if mv.Kind() == reflect.Ptr { + mv = mv.Elem() + } + + if mv.Kind() != reflect.Struct { + return nil, fmt.Errorf("Reference.SyncChanges expected a struct, got %s", mv.Kind()) + } + + of := mv.FieldByName("Object") + rf := mv.FieldByName("Reference") + + hasObject := !of.IsZero() + hasReference := !rf.IsZero() && !rf.IsNil() + + if hasObject && !hasReference { + // Inlined case + r.Reference = marshaller.Node[*string]{} + + var err error + valueNode, err = marshaller.SyncValue(ctx, of.Interface(), &r.Object, valueNode, false) + if err != nil { + return nil, err + } + + if valueNode != nil && valueNode.Kind == yaml.MappingNode { + newContent := make([]*yaml.Node, 0, len(valueNode.Content)) + for i := 0; i < len(valueNode.Content); i += 2 { + if i+1 < 
len(valueNode.Content) && valueNode.Content[i].Value != "$ref" { + newContent = append(newContent, valueNode.Content[i], valueNode.Content[i+1]) + } + } + valueNode.Content = newContent + } + + r.SetValid(r.Object.GetValid(), r.Object.GetValidYaml()) + } else { + // Reference case + var zero T + r.Object = zero + + var err error + valueNode, err = marshaller.SyncValue(ctx, model, r, valueNode, true) + if err != nil { + return nil, err + } + + if valueNode != nil && valueNode.Kind == yaml.MappingNode { + newContent := make([]*yaml.Node, 0, len(valueNode.Content)) + for i := 0; i < len(valueNode.Content); i += 2 { + if i+1 < len(valueNode.Content) { + key := valueNode.Content[i].Value + if key == "$ref" { + newContent = append(newContent, valueNode.Content[i], valueNode.Content[i+1]) + } + } + } + valueNode.Content = newContent + } + + r.SetValid(true, true) + } + + r.SetRootNode(valueNode) + return valueNode, nil +} diff --git a/swagger/core/response.go b/swagger/core/response.go new file mode 100644 index 0000000..a321012 --- /dev/null +++ b/swagger/core/response.go @@ -0,0 +1,93 @@ +package core + +import ( + "github.com/speakeasy-api/openapi/extensions/core" + oascore "github.com/speakeasy-api/openapi/jsonschema/oas3/core" + "github.com/speakeasy-api/openapi/marshaller" + "github.com/speakeasy-api/openapi/sequencedmap" + values "github.com/speakeasy-api/openapi/values/core" + "gopkg.in/yaml.v3" +) + +// Responses is a container for the expected responses of an operation. +type Responses struct { + marshaller.CoreModel `model:"responses"` + *sequencedmap.Map[string, marshaller.Node[*Reference[*Response]]] + + Default marshaller.Node[*Reference[*Response]] `key:"default"` + Extensions core.Extensions `key:"extensions"` +} + +func NewResponses() *Responses { + return &Responses{ + Map: sequencedmap.New[string, marshaller.Node[*Reference[*Response]]](), + } +} + +func (r *Responses) GetMapKeyNodeOrRoot(key string, rootNode *yaml.Node) *yaml.Node { + if !r.IsInitialized() { + return rootNode + } + + if r.RootNode == nil { + return rootNode + } + + for i := 0; i < len(r.RootNode.Content); i += 2 { + if r.RootNode.Content[i].Value == key { + return r.RootNode.Content[i] + } + } + + return rootNode +} + +func (r *Responses) GetMapKeyNodeOrRootLine(key string, rootNode *yaml.Node) int { + node := r.GetMapKeyNodeOrRoot(key, rootNode) + if node == nil { + return -1 + } + return node.Line +} + +// Response describes a single response from an API operation. +type Response struct { + marshaller.CoreModel `model:"response"` + + Description marshaller.Node[string] `key:"description"` + Schema marshaller.Node[oascore.JSONSchema] `key:"schema"` + Headers marshaller.Node[*Headers] `key:"headers"` + Examples marshaller.Node[*Examples] `key:"examples"` + Extensions core.Extensions `key:"extensions"` +} + +// Examples is a map of MIME types to example values. +type Examples = sequencedmap.Map[string, marshaller.Node[values.Value]] + +// Headers is a map of header names to header definitions. +type Headers = sequencedmap.Map[string, marshaller.Node[*Header]] + +// Header describes a single header in a response. 
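+//
+// Illustrative YAML sketch of a response header definition (values are arbitrary):
+//
+//	X-Rate-Limit-Remaining:
+//	  description: Requests left in the current window
+//	  type: integer
+//	  format: int32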
+type Header struct { + marshaller.CoreModel `model:"header"` + + Description marshaller.Node[*string] `key:"description"` + Type marshaller.Node[string] `key:"type"` + Format marshaller.Node[*string] `key:"format"` + Items marshaller.Node[*Items] `key:"items"` + CollectionFormat marshaller.Node[*string] `key:"collectionFormat"` + Default marshaller.Node[values.Value] `key:"default"` + Maximum marshaller.Node[*float64] `key:"maximum"` + ExclusiveMaximum marshaller.Node[*bool] `key:"exclusiveMaximum"` + Minimum marshaller.Node[*float64] `key:"minimum"` + ExclusiveMinimum marshaller.Node[*bool] `key:"exclusiveMinimum"` + MaxLength marshaller.Node[*int64] `key:"maxLength"` + MinLength marshaller.Node[*int64] `key:"minLength"` + Pattern marshaller.Node[*string] `key:"pattern"` + MaxItems marshaller.Node[*int64] `key:"maxItems"` + MinItems marshaller.Node[*int64] `key:"minItems"` + UniqueItems marshaller.Node[*bool] `key:"uniqueItems"` + Enum marshaller.Node[[]marshaller.Node[values.Value]] `key:"enum"` + MultipleOf marshaller.Node[*float64] `key:"multipleOf"` + Extensions core.Extensions `key:"extensions"` +} diff --git a/swagger/core/security.go b/swagger/core/security.go new file mode 100644 index 0000000..eade5da --- /dev/null +++ b/swagger/core/security.go @@ -0,0 +1,34 @@ +package core + +import ( + "github.com/speakeasy-api/openapi/extensions/core" + "github.com/speakeasy-api/openapi/marshaller" + "github.com/speakeasy-api/openapi/sequencedmap" +) + +// SecurityScheme defines a security scheme that can be used by the operations. +type SecurityScheme struct { + marshaller.CoreModel `model:"securityScheme"` + + Type marshaller.Node[string] `key:"type"` + Description marshaller.Node[*string] `key:"description"` + Name marshaller.Node[*string] `key:"name"` + In marshaller.Node[*string] `key:"in"` + Flow marshaller.Node[*string] `key:"flow"` + AuthorizationURL marshaller.Node[*string] `key:"authorizationUrl"` + TokenURL marshaller.Node[*string] `key:"tokenUrl"` + Scopes marshaller.Node[*sequencedmap.Map[string, marshaller.Node[string]]] `key:"scopes"` + Extensions core.Extensions `key:"extensions"` +} + +// SecurityRequirement lists the required security schemes to execute an operation. +type SecurityRequirement struct { + marshaller.CoreModel `model:"securityRequirement"` + *sequencedmap.Map[string, marshaller.Node[[]string]] +} + +func NewSecurityRequirement() *SecurityRequirement { + return &SecurityRequirement{ + Map: sequencedmap.New[string, marshaller.Node[[]string]](), + } +} diff --git a/swagger/core/swagger.go b/swagger/core/swagger.go new file mode 100644 index 0000000..a904d59 --- /dev/null +++ b/swagger/core/swagger.go @@ -0,0 +1,30 @@ +package core + +import ( + "github.com/speakeasy-api/openapi/extensions/core" + oascore "github.com/speakeasy-api/openapi/jsonschema/oas3/core" + "github.com/speakeasy-api/openapi/marshaller" + "github.com/speakeasy-api/openapi/sequencedmap" +) + +// Swagger is the root document object for the API specification (Swagger 2.0). 
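+//
+// Illustrative minimal document (values are arbitrary examples):
+//
+//	swagger: "2.0"
+//	info:
+//	  title: Example API
+//	  version: 1.0.0
+//	basePath: /v1
+//	paths:
+//	  /users:
+//	    get:
+//	      responses:
+//	        "200":
+//	          description: Success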
+type Swagger struct { + marshaller.CoreModel `model:"swagger"` + + Swagger marshaller.Node[string] `key:"swagger" required:"true"` + Info marshaller.Node[Info] `key:"info"` + Host marshaller.Node[*string] `key:"host"` + BasePath marshaller.Node[*string] `key:"basePath"` + Schemes marshaller.Node[[]string] `key:"schemes"` + Consumes marshaller.Node[[]string] `key:"consumes"` + Produces marshaller.Node[[]string] `key:"produces"` + Paths marshaller.Node[Paths] `key:"paths"` + Definitions marshaller.Node[*sequencedmap.Map[string, marshaller.Node[oascore.JSONSchema]]] `key:"definitions"` + Parameters marshaller.Node[*sequencedmap.Map[string, marshaller.Node[*Parameter]]] `key:"parameters"` + Responses marshaller.Node[*sequencedmap.Map[string, marshaller.Node[*Response]]] `key:"responses"` + SecurityDefinitions marshaller.Node[*sequencedmap.Map[string, marshaller.Node[*SecurityScheme]]] `key:"securityDefinitions"` + Security marshaller.Node[[]marshaller.Node[*SecurityRequirement]] `key:"security"` + Tags marshaller.Node[[]marshaller.Node[*Tag]] `key:"tags"` + ExternalDocs marshaller.Node[*ExternalDocumentation] `key:"externalDocs"` + Extensions core.Extensions `key:"extensions"` +} diff --git a/swagger/core/tag.go b/swagger/core/tag.go new file mode 100644 index 0000000..eae3138 --- /dev/null +++ b/swagger/core/tag.go @@ -0,0 +1,16 @@ +package core + +import ( + "github.com/speakeasy-api/openapi/extensions/core" + "github.com/speakeasy-api/openapi/marshaller" +) + +// Tag allows adding metadata to a single tag that is used by operations. +type Tag struct { + marshaller.CoreModel `model:"tag"` + + Name marshaller.Node[string] `key:"name"` + Description marshaller.Node[*string] `key:"description"` + ExternalDocs marshaller.Node[*ExternalDocumentation] `key:"externalDocs"` + Extensions core.Extensions `key:"extensions"` +} diff --git a/swagger/externaldocs.go b/swagger/externaldocs.go new file mode 100644 index 0000000..c85c8be --- /dev/null +++ b/swagger/externaldocs.go @@ -0,0 +1,70 @@ +package swagger + +import ( + "context" + "net/url" + + "github.com/speakeasy-api/openapi/extensions" + "github.com/speakeasy-api/openapi/internal/interfaces" + "github.com/speakeasy-api/openapi/marshaller" + "github.com/speakeasy-api/openapi/swagger/core" + "github.com/speakeasy-api/openapi/validation" +) + +// ExternalDocumentation allows referencing an external resource for extended documentation. +type ExternalDocumentation struct { + marshaller.Model[core.ExternalDocumentation] + + // Description is a short description of the target documentation. GFM syntax can be used for rich text representation. + Description *string + // URL is the URL for the target documentation. MUST be in the format of a URL. + URL string + // Extensions provides a list of extensions to the ExternalDocumentation object. + Extensions *extensions.Extensions +} + +var _ interfaces.Model[core.ExternalDocumentation] = (*ExternalDocumentation)(nil) + +// GetDescription returns the value of the Description field. Returns empty string if not set. +func (e *ExternalDocumentation) GetDescription() string { + if e == nil || e.Description == nil { + return "" + } + return *e.Description +} + +// GetURL returns the value of the URL field. Returns empty string if not set. +func (e *ExternalDocumentation) GetURL() string { + if e == nil { + return "" + } + return e.URL +} + +// GetExtensions returns the value of the Extensions field. Returns an empty extensions map if not set. 
+func (e *ExternalDocumentation) GetExtensions() *extensions.Extensions { + if e == nil || e.Extensions == nil { + return extensions.New() + } + return e.Extensions +} + +// Validate validates the ExternalDocumentation object against the Swagger Specification. +func (e *ExternalDocumentation) Validate(ctx context.Context, opts ...validation.Option) []error { + c := e.GetCore() + errs := []error{} + + if c.URL.Present && e.URL == "" { + errs = append(errs, validation.NewValueError(validation.NewMissingValueError("externalDocs.url is required"), c, c.URL)) + } + + if c.URL.Present { + if _, err := url.Parse(e.URL); err != nil { + errs = append(errs, validation.NewValueError(validation.NewValueValidationError("externalDocs.url is not a valid uri: %s", err), c, c.URL)) + } + } + + e.Valid = len(errs) == 0 && c.GetValid() + + return errs +} diff --git a/swagger/factory_registration.go b/swagger/factory_registration.go new file mode 100644 index 0000000..54fcb70 --- /dev/null +++ b/swagger/factory_registration.go @@ -0,0 +1,83 @@ +package swagger + +import ( + "github.com/speakeasy-api/openapi/jsonschema/oas3" + "github.com/speakeasy-api/openapi/marshaller" + "github.com/speakeasy-api/openapi/sequencedmap" + "github.com/speakeasy-api/openapi/values" +) + +// init registers all Swagger 2.0 wrapper types with the marshaller factory system +func init() { + // Register wrapper types + marshaller.RegisterType(func() *Swagger { return &Swagger{} }) + marshaller.RegisterType(func() *Info { return &Info{} }) + marshaller.RegisterType(func() *Contact { return &Contact{} }) + marshaller.RegisterType(func() *License { return &License{} }) + marshaller.RegisterType(func() *Paths { return &Paths{} }) + marshaller.RegisterType(func() *PathItem { return &PathItem{} }) + marshaller.RegisterType(func() *Operation { return &Operation{} }) + marshaller.RegisterType(func() *Parameter { return &Parameter{} }) + marshaller.RegisterType(func() *Items { return &Items{} }) + marshaller.RegisterType(func() *Responses { return &Responses{} }) + marshaller.RegisterType(func() *Response { return &Response{} }) + marshaller.RegisterType(func() *Header { return &Header{} }) + marshaller.RegisterType(func() *SecurityScheme { return &SecurityScheme{} }) + marshaller.RegisterType(func() *SecurityRequirement { return &SecurityRequirement{} }) + marshaller.RegisterType(func() *Tag { return &Tag{} }) + marshaller.RegisterType(func() *ExternalDocumentation { return &ExternalDocumentation{} }) + + // Register Reference types + marshaller.RegisterType(func() *ReferencedParameter { return &ReferencedParameter{} }) + marshaller.RegisterType(func() *ReferencedResponse { return &ReferencedResponse{} }) + + // Register Reference types + marshaller.RegisterType(func() *ReferencedParameter { return &ReferencedParameter{} }) + marshaller.RegisterType(func() *ReferencedResponse { return &ReferencedResponse{} }) + + // Register enum types + marshaller.RegisterType(func() *HTTPMethod { return new(HTTPMethod) }) + marshaller.RegisterType(func() *ParameterIn { return new(ParameterIn) }) + marshaller.RegisterType(func() *CollectionFormat { return new(CollectionFormat) }) + marshaller.RegisterType(func() *SecuritySchemeType { return new(SecuritySchemeType) }) + marshaller.RegisterType(func() *SecuritySchemeIn { return new(SecuritySchemeIn) }) + marshaller.RegisterType(func() *OAuth2Flow { return new(OAuth2Flow) }) + + // Register sequencedmap types used in swagger package + marshaller.RegisterType(func() *sequencedmap.Map[string, *PathItem] { + 
return &sequencedmap.Map[string, *PathItem]{} + }) + marshaller.RegisterType(func() *sequencedmap.Map[HTTPMethod, *Operation] { + return &sequencedmap.Map[HTTPMethod, *Operation]{} + }) + marshaller.RegisterType(func() *sequencedmap.Map[string, *Parameter] { + return &sequencedmap.Map[string, *Parameter]{} + }) + marshaller.RegisterType(func() *sequencedmap.Map[string, *Response] { + return &sequencedmap.Map[string, *Response]{} + }) + marshaller.RegisterType(func() *sequencedmap.Map[string, *ReferencedParameter] { + return &sequencedmap.Map[string, *ReferencedParameter]{} + }) + marshaller.RegisterType(func() *sequencedmap.Map[string, *ReferencedResponse] { + return &sequencedmap.Map[string, *ReferencedResponse]{} + }) + marshaller.RegisterType(func() *sequencedmap.Map[string, *oas3.JSONSchema[oas3.Concrete]] { + return &sequencedmap.Map[string, *oas3.JSONSchema[oas3.Concrete]]{} + }) + marshaller.RegisterType(func() *sequencedmap.Map[string, *SecurityScheme] { + return &sequencedmap.Map[string, *SecurityScheme]{} + }) + marshaller.RegisterType(func() *sequencedmap.Map[string, *Header] { + return &sequencedmap.Map[string, *Header]{} + }) + marshaller.RegisterType(func() *sequencedmap.Map[string, values.Value] { + return &sequencedmap.Map[string, values.Value]{} + }) + marshaller.RegisterType(func() *sequencedmap.Map[string, string] { + return &sequencedmap.Map[string, string]{} + }) + marshaller.RegisterType(func() *sequencedmap.Map[string, []string] { + return &sequencedmap.Map[string, []string]{} + }) +} diff --git a/swagger/info.go b/swagger/info.go new file mode 100644 index 0000000..80b864e --- /dev/null +++ b/swagger/info.go @@ -0,0 +1,251 @@ +package swagger + +import ( + "context" + "net/mail" + "net/url" + + "github.com/speakeasy-api/openapi/extensions" + "github.com/speakeasy-api/openapi/internal/interfaces" + "github.com/speakeasy-api/openapi/marshaller" + "github.com/speakeasy-api/openapi/swagger/core" + "github.com/speakeasy-api/openapi/validation" +) + +// Info provides metadata about the API. +type Info struct { + marshaller.Model[core.Info] + + // Title is the title of the application. + Title string + // Description is a short description of the application. GFM syntax can be used for rich text representation. + Description *string + // TermsOfService is the Terms of Service for the API. + TermsOfService *string + // Contact is the contact information for the exposed API. + Contact *Contact + // License is the license information for the exposed API. + License *License + // Version provides the version of the application API (not to be confused with the specification version). + Version string + // Extensions provides a list of extensions to the Info object. + Extensions *extensions.Extensions +} + +var _ interfaces.Model[core.Info] = (*Info)(nil) + +// GetTitle returns the value of the Title field. Returns empty string if not set. +func (i *Info) GetTitle() string { + if i == nil { + return "" + } + return i.Title +} + +// GetDescription returns the value of the Description field. Returns empty string if not set. +func (i *Info) GetDescription() string { + if i == nil || i.Description == nil { + return "" + } + return *i.Description +} + +// GetTermsOfService returns the value of the TermsOfService field. Returns empty string if not set. +func (i *Info) GetTermsOfService() string { + if i == nil || i.TermsOfService == nil { + return "" + } + return *i.TermsOfService +} + +// GetContact returns the value of the Contact field. Returns nil if not set. 
+func (i *Info) GetContact() *Contact { + if i == nil { + return nil + } + return i.Contact +} + +// GetLicense returns the value of the License field. Returns nil if not set. +func (i *Info) GetLicense() *License { + if i == nil { + return nil + } + return i.License +} + +// GetVersion returns the value of the Version field. Returns empty string if not set. +func (i *Info) GetVersion() string { + if i == nil { + return "" + } + return i.Version +} + +// GetExtensions returns the value of the Extensions field. Returns an empty extensions map if not set. +func (i *Info) GetExtensions() *extensions.Extensions { + if i == nil || i.Extensions == nil { + return extensions.New() + } + return i.Extensions +} + +// Validate validates the Info object against the Swagger Specification. +func (i *Info) Validate(ctx context.Context, opts ...validation.Option) []error { + c := i.GetCore() + errs := []error{} + + if c.Title.Present && i.Title == "" { + errs = append(errs, validation.NewValueError(validation.NewMissingValueError("info.title is required"), c, c.Title)) + } + + if c.Version.Present && i.Version == "" { + errs = append(errs, validation.NewValueError(validation.NewMissingValueError("info.version is required"), c, c.Version)) + } + + if c.TermsOfService.Present { + if _, err := url.Parse(*i.TermsOfService); err != nil { + errs = append(errs, validation.NewValueError(validation.NewValueValidationError("info.termsOfService is not a valid uri: %s", err), c, c.TermsOfService)) + } + } + + if c.Contact.Present { + errs = append(errs, i.Contact.Validate(ctx, opts...)...) + } + + if c.License.Present { + errs = append(errs, i.License.Validate(ctx, opts...)...) + } + + i.Valid = len(errs) == 0 && c.GetValid() + + return errs +} + +// Contact information for the exposed API. +type Contact struct { + marshaller.Model[core.Contact] + + // Name is the identifying name of the contact person/organization. + Name *string + // URL is the URL pointing to the contact information. MUST be in the format of a URL. + URL *string + // Email is the email address of the contact person/organization. MUST be in the format of an email address. + Email *string + // Extensions provides a list of extensions to the Contact object. + Extensions *extensions.Extensions +} + +var _ interfaces.Model[core.Contact] = (*Contact)(nil) + +// GetName returns the value of the Name field. Returns empty string if not set. +func (c *Contact) GetName() string { + if c == nil || c.Name == nil { + return "" + } + return *c.Name +} + +// GetURL returns the value of the URL field. Returns empty string if not set. +func (c *Contact) GetURL() string { + if c == nil || c.URL == nil { + return "" + } + return *c.URL +} + +// GetEmail returns the value of the Email field. Returns empty string if not set. +func (c *Contact) GetEmail() string { + if c == nil || c.Email == nil { + return "" + } + return *c.Email +} + +// GetExtensions returns the value of the Extensions field. Returns an empty extensions map if not set. +func (c *Contact) GetExtensions() *extensions.Extensions { + if c == nil || c.Extensions == nil { + return extensions.New() + } + return c.Extensions +} + +// Validate validates the Contact object against the Swagger Specification. 
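+//
+// Only fields that are present are checked: URL must parse as a URI and Email must
+// parse as an email address; absent fields produce no errors.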
+func (c *Contact) Validate(ctx context.Context, opts ...validation.Option) []error { + core := c.GetCore() + errs := []error{} + + if core.URL.Present { + if _, err := url.Parse(*c.URL); err != nil { + errs = append(errs, validation.NewValueError(validation.NewValueValidationError("contact.url is not a valid uri: %s", err), core, core.URL)) + } + } + + if core.Email.Present { + if _, err := mail.ParseAddress(*c.Email); err != nil { + errs = append(errs, validation.NewValueError(validation.NewValueValidationError("contact.email is not a valid email address: %s", err), core, core.Email)) + } + } + + c.Valid = len(errs) == 0 && core.GetValid() + + return errs +} + +// License information for the exposed API. +type License struct { + marshaller.Model[core.License] + + // Name is the license name used for the API. + Name string + // URL is a URL to the license used for the API. MUST be in the format of a URL. + URL *string + // Extensions provides a list of extensions to the License object. + Extensions *extensions.Extensions +} + +var _ interfaces.Model[core.License] = (*License)(nil) + +// GetName returns the value of the Name field. Returns empty string if not set. +func (l *License) GetName() string { + if l == nil { + return "" + } + return l.Name +} + +// GetURL returns the value of the URL field. Returns empty string if not set. +func (l *License) GetURL() string { + if l == nil || l.URL == nil { + return "" + } + return *l.URL +} + +// GetExtensions returns the value of the Extensions field. Returns an empty extensions map if not set. +func (l *License) GetExtensions() *extensions.Extensions { + if l == nil || l.Extensions == nil { + return extensions.New() + } + return l.Extensions +} + +// Validate validates the License object against the Swagger Specification. 
+func (l *License) Validate(ctx context.Context, opts ...validation.Option) []error { + core := l.GetCore() + errs := []error{} + + if core.Name.Present && l.Name == "" { + errs = append(errs, validation.NewValueError(validation.NewMissingValueError("license.name is required"), core, core.Name)) + } + + if core.URL.Present { + if _, err := url.Parse(*l.URL); err != nil { + errs = append(errs, validation.NewValueError(validation.NewValueValidationError("license.url is not a valid uri: %s", err), core, core.URL)) + } + } + + l.Valid = len(errs) == 0 && core.GetValid() + + return errs +} diff --git a/swagger/info_validate_test.go b/swagger/info_validate_test.go new file mode 100644 index 0000000..f12c20f --- /dev/null +++ b/swagger/info_validate_test.go @@ -0,0 +1,284 @@ +package swagger_test + +import ( + "bytes" + "strings" + "testing" + + "github.com/speakeasy-api/openapi/marshaller" + "github.com/speakeasy-api/openapi/swagger" + "github.com/stretchr/testify/require" +) + +func TestInfo_Validate_Success(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + yml string + }{ + { + name: "minimal_valid_info", + yml: `title: Test API +version: 1.0.0`, + }, + { + name: "complete_info", + yml: `title: Test API +version: 1.0.0 +description: A test API +termsOfService: https://example.com/terms +contact: + name: API Support + url: https://example.com/support + email: support@example.com +license: + name: Apache 2.0 + url: https://www.apache.org/licenses/LICENSE-2.0.html`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var info swagger.Info + + validationErrs, err := marshaller.Unmarshal(t.Context(), bytes.NewBufferString(tt.yml), &info) + require.NoError(t, err) + require.Empty(t, validationErrs) + + errs := info.Validate(t.Context()) + require.Empty(t, errs, "Expected no validation errors") + }) + } +} + +func TestInfo_Validate_Error(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + yml string + wantErrs []string + }{ + { + name: "missing_title", + yml: `version: 1.0.0`, + wantErrs: []string{"info.title is missing"}, + }, + { + name: "missing_version", + yml: `title: Test API`, + wantErrs: []string{"info.version is missing"}, + }, + { + name: "invalid_contact_email", + yml: `title: Test API +version: 1.0.0 +contact: + email: not-an-email`, + wantErrs: []string{"contact.email is not a valid email address"}, + }, + { + name: "missing_license_name", + yml: `title: Test API +version: 1.0.0 +license: + url: https://example.com/license`, + wantErrs: []string{"license.name is missing"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var info swagger.Info + + var allErrors []error + validationErrs, err := marshaller.Unmarshal(t.Context(), bytes.NewBufferString(tt.yml), &info) + require.NoError(t, err) + allErrors = append(allErrors, validationErrs...) + + validateErrs := info.Validate(t.Context()) + allErrors = append(allErrors, validateErrs...) 
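+			// Errors may surface at unmarshal time or from Validate; both sets were collected
+			// above and are matched against the expected substrings below.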
+ + require.NotEmpty(t, allErrors, "Expected validation errors") + + for _, wantErr := range tt.wantErrs { + found := false + for _, gotErr := range allErrors { + if gotErr != nil && strings.Contains(gotErr.Error(), wantErr) { + found = true + break + } + } + require.True(t, found, "Expected error containing '%s' not found in: %v", wantErr, allErrors) + } + }) + } +} + +func TestContact_Validate_Success(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + yml string + }{ + { + name: "valid_contact_with_all_fields", + yml: `name: API Support +url: https://example.com/support +email: support@example.com`, + }, + { + name: "valid_contact_with_name_only", + yml: `name: API Support`, + }, + { + name: "valid_contact_with_email_only", + yml: `email: support@example.com`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var contact swagger.Contact + + validationErrs, err := marshaller.Unmarshal(t.Context(), bytes.NewBufferString(tt.yml), &contact) + require.NoError(t, err) + require.Empty(t, validationErrs) + + errs := contact.Validate(t.Context()) + require.Empty(t, errs, "Expected no validation errors") + }) + } +} + +func TestContact_Validate_Error(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + yml string + wantErrs []string + }{ + { + name: "invalid_email", + yml: `email: not-an-email`, + wantErrs: []string{"contact.email is not a valid email address"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var contact swagger.Contact + + var allErrors []error + validationErrs, err := marshaller.Unmarshal(t.Context(), bytes.NewBufferString(tt.yml), &contact) + require.NoError(t, err) + allErrors = append(allErrors, validationErrs...) + + validateErrs := contact.Validate(t.Context()) + allErrors = append(allErrors, validateErrs...) 
+ + require.NotEmpty(t, allErrors, "Expected validation errors") + + for _, wantErr := range tt.wantErrs { + found := false + for _, gotErr := range allErrors { + if gotErr != nil && strings.Contains(gotErr.Error(), wantErr) { + found = true + break + } + } + require.True(t, found, "Expected error containing '%s' not found in: %v", wantErr, allErrors) + } + }) + } +} + +func TestLicense_Validate_Success(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + yml string + }{ + { + name: "valid_license_with_url", + yml: `name: Apache 2.0 +url: https://www.apache.org/licenses/LICENSE-2.0.html`, + }, + { + name: "valid_license_without_url", + yml: `name: MIT`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var license swagger.License + + validationErrs, err := marshaller.Unmarshal(t.Context(), bytes.NewBufferString(tt.yml), &license) + require.NoError(t, err) + require.Empty(t, validationErrs) + + errs := license.Validate(t.Context()) + require.Empty(t, errs, "Expected no validation errors") + }) + } +} + +func TestLicense_Validate_Error(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + yml string + wantErrs []string + }{ + { + name: "missing_name", + yml: `url: https://example.com/license`, + wantErrs: []string{"license.name is missing"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var license swagger.License + + var allErrors []error + validationErrs, err := marshaller.Unmarshal(t.Context(), bytes.NewBufferString(tt.yml), &license) + require.NoError(t, err) + allErrors = append(allErrors, validationErrs...) + + validateErrs := license.Validate(t.Context()) + allErrors = append(allErrors, validateErrs...) + + require.NotEmpty(t, allErrors, "Expected validation errors") + + for _, wantErr := range tt.wantErrs { + found := false + for _, gotErr := range allErrors { + if gotErr != nil && strings.Contains(gotErr.Error(), wantErr) { + found = true + break + } + } + require.True(t, found, "Expected error containing '%s' not found in: %v", wantErr, allErrors) + } + }) + } +} diff --git a/swagger/marshalling.go b/swagger/marshalling.go new file mode 100644 index 0000000..c135e88 --- /dev/null +++ b/swagger/marshalling.go @@ -0,0 +1,63 @@ +package swagger + +import ( + "context" + "io" + + "github.com/speakeasy-api/openapi/marshaller" + "github.com/speakeasy-api/openapi/swagger/core" + "github.com/speakeasy-api/openapi/validation" +) + +type Option[T any] func(o *T) + +type UnmarshalOptions struct { + skipValidation bool +} + +// WithSkipValidation will skip validation of the Swagger document during unmarshaling. +// Useful to quickly load a document that will be mutated and validated later. +func WithSkipValidation() Option[UnmarshalOptions] { + return func(o *UnmarshalOptions) { + o.skipValidation = true + } +} + +// Unmarshal will unmarshal and validate a Swagger 2.0 document from the provided io.Reader. +// Validation can be skipped by using swagger.WithSkipValidation() as one of the options when calling this function. 
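+//
+// A minimal usage sketch (illustrative only; error handling is elided and the
+// file name is hypothetical):
+//
+//	f, _ := os.Open("swagger.yaml")
+//	defer f.Close()
+//	doc, validationErrs, err := Unmarshal(ctx, f)
+//	// err reports failures to parse the document at all;
+//	// validationErrs reports Swagger 2.0 spec violations.
+//	_, _, _ = doc, validationErrs, err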
+func Unmarshal(ctx context.Context, doc io.Reader, opts ...Option[UnmarshalOptions]) (*Swagger, []error, error) {
+	o := UnmarshalOptions{}
+	for _, opt := range opts {
+		opt(&o)
+	}
+
+	var swagger Swagger
+
+	validationErrs, err := marshaller.Unmarshal(ctx, doc, &swagger)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	if o.skipValidation {
+		return &swagger, nil, nil
+	}
+
+	validationErrs = append(validationErrs, swagger.Validate(ctx)...)
+	validation.SortValidationErrors(validationErrs)
+
+	return &swagger, validationErrs, nil
+}
+
+// Marshal will marshal the provided Swagger document to the provided io.Writer.
+func Marshal(ctx context.Context, swagger *Swagger, w io.Writer) error {
+	return marshaller.Marshal(ctx, swagger, w)
+}
+
+// Sync will sync the high-level model to the core model.
+// This is useful when creating or mutating a high-level model and wanting access to the yaml nodes that back it.
+func Sync(ctx context.Context, model marshaller.Marshallable[core.Swagger]) error {
+	_, err := marshaller.SyncValue(ctx, model, model.GetCore(), model.GetRootNode(), false)
+	return err
+}
diff --git a/swagger/operation.go b/swagger/operation.go
new file mode 100644
index 0000000..90cb167
--- /dev/null
+++ b/swagger/operation.go
@@ -0,0 +1,223 @@
+package swagger
+
+import (
+	"context"
+	"mime"
+
+	"github.com/speakeasy-api/openapi/extensions"
+	"github.com/speakeasy-api/openapi/internal/interfaces"
+	"github.com/speakeasy-api/openapi/marshaller"
+	"github.com/speakeasy-api/openapi/swagger/core"
+	"github.com/speakeasy-api/openapi/validation"
+)
+
+// Operation describes a single API operation on a path.
+type Operation struct {
+	marshaller.Model[core.Operation]
+
+	// Tags is a list of tags for API documentation control.
+	Tags []string
+	// Summary is a short summary of what the operation does.
+	Summary *string
+	// Description is a verbose explanation of the operation behavior.
+	Description *string
+	// ExternalDocs is additional external documentation for this operation.
+	ExternalDocs *ExternalDocumentation
+	// OperationID is a unique string used to identify the operation.
+	OperationID *string
+	// Consumes is a list of MIME types the operation can consume.
+	Consumes []string
+	// Produces is a list of MIME types the operation can produce.
+	Produces []string
+	// Parameters is a list of parameters that are applicable for this operation.
+	Parameters []*ReferencedParameter
+	// Responses is the list of possible responses as they are returned from executing this operation.
+	Responses *Responses
+	// Schemes is the transfer protocol for the operation.
+	Schemes []string
+	// Deprecated declares this operation to be deprecated.
+	Deprecated *bool
+	// Security is a declaration of which security schemes are applied for this operation.
+	Security []*SecurityRequirement
+	// Extensions provides a list of extensions to the Operation object.
+	Extensions *extensions.Extensions
+}
+
+var _ interfaces.Model[core.Operation] = (*Operation)(nil)
+
+// GetTags returns the value of the Tags field. Returns nil if not set.
+func (o *Operation) GetTags() []string {
+	if o == nil {
+		return nil
+	}
+	return o.Tags
+}
+
+// GetSummary returns the value of the Summary field. Returns empty string if not set.
+func (o *Operation) GetSummary() string {
+	if o == nil || o.Summary == nil {
+		return ""
+	}
+	return *o.Summary
+}
+
+// GetDescription returns the value of the Description field. Returns empty string if not set.
+func (o *Operation) GetDescription() string { + if o == nil || o.Description == nil { + return "" + } + return *o.Description +} + +// GetExternalDocs returns the value of the ExternalDocs field. Returns nil if not set. +func (o *Operation) GetExternalDocs() *ExternalDocumentation { + if o == nil { + return nil + } + return o.ExternalDocs +} + +// GetOperationID returns the value of the OperationID field. Returns empty string if not set. +func (o *Operation) GetOperationID() string { + if o == nil || o.OperationID == nil { + return "" + } + return *o.OperationID +} + +// GetConsumes returns the value of the Consumes field. Returns nil if not set. +func (o *Operation) GetConsumes() []string { + if o == nil { + return nil + } + return o.Consumes +} + +// GetProduces returns the value of the Produces field. Returns nil if not set. +func (o *Operation) GetProduces() []string { + if o == nil { + return nil + } + return o.Produces +} + +// GetParameters returns the value of the Parameters field. Returns nil if not set. +func (o *Operation) GetParameters() []*ReferencedParameter { + if o == nil { + return nil + } + return o.Parameters +} + +// GetResponses returns the value of the Responses field. Returns nil if not set. +func (o *Operation) GetResponses() *Responses { + if o == nil { + return nil + } + return o.Responses +} + +// GetSchemes returns the value of the Schemes field. Returns nil if not set. +func (o *Operation) GetSchemes() []string { + if o == nil { + return nil + } + return o.Schemes +} + +// GetDeprecated returns the value of the Deprecated field. False by default if not set. +func (o *Operation) GetDeprecated() bool { + if o == nil || o.Deprecated == nil { + return false + } + return *o.Deprecated +} + +// GetSecurity returns the value of the Security field. Returns nil if not set. +func (o *Operation) GetSecurity() []*SecurityRequirement { + if o == nil { + return nil + } + return o.Security +} + +// GetExtensions returns the value of the Extensions field. Returns an empty extensions map if not set. +func (o *Operation) GetExtensions() *extensions.Extensions { + if o == nil || o.Extensions == nil { + return extensions.New() + } + return o.Extensions +} + +// Validate validates the Operation object against the Swagger Specification. +func (o *Operation) Validate(ctx context.Context, opts ...validation.Option) []error { + c := o.GetCore() + errs := []error{} + + if !c.Responses.Present { + errs = append(errs, validation.NewValueError(validation.NewMissingValueError("operation.responses is required"), c, c.Responses)) + } else if o.Responses != nil { + errs = append(errs, o.Responses.Validate(ctx, opts...)...) 
+ } + + // Validate schemes if present + if c.Schemes.Present { + validSchemes := []string{"http", "https", "ws", "wss"} + for _, scheme := range o.Schemes { + valid := false + for _, vs := range validSchemes { + if scheme == vs { + valid = true + break + } + } + if !valid { + errs = append(errs, validation.NewValueError( + validation.NewValueValidationError("operation.scheme must be one of [http, https, ws, wss], got '%s'", scheme), + c, c.Schemes)) + } + } + } + + // Validate consumes MIME types + if c.Consumes.Present { + for _, mimeType := range o.Consumes { + if _, _, err := mime.ParseMediaType(mimeType); err != nil { + errs = append(errs, validation.NewValueError( + validation.NewValueValidationError("operation.consumes contains invalid MIME type '%s': %s", mimeType, err), + c, c.Consumes)) + } + } + } + + // Validate produces MIME types + if c.Produces.Present { + for _, mimeType := range o.Produces { + if _, _, err := mime.ParseMediaType(mimeType); err != nil { + errs = append(errs, validation.NewValueError( + validation.NewValueValidationError("operation.produces contains invalid MIME type '%s': %s", mimeType, err), + c, c.Produces)) + } + } + } + + if c.ExternalDocs.Present && o.ExternalDocs != nil { + errs = append(errs, o.ExternalDocs.Validate(ctx, opts...)...) + } + + // TODO allow validation of parameter uniqueness and body parameter count, this isn't done at the moment as we would need to resolve references + // Pass operation as context for file type parameter validation + for _, param := range o.Parameters { + errs = append(errs, param.Validate(ctx, append(opts, validation.WithContextObject(o))...)...) + } + + // Pass operation's parent Swagger as context for security requirement validation + // Note: Swagger context should be provided by caller + for _, secReq := range o.Security { + errs = append(errs, secReq.Validate(ctx, opts...)...) 
+ } + + o.Valid = len(errs) == 0 && c.GetValid() + + return errs +} diff --git a/swagger/operation_validate_test.go b/swagger/operation_validate_test.go new file mode 100644 index 0000000..899b1ec --- /dev/null +++ b/swagger/operation_validate_test.go @@ -0,0 +1,116 @@ +package swagger_test + +import ( + "bytes" + "strings" + "testing" + + "github.com/speakeasy-api/openapi/marshaller" + "github.com/speakeasy-api/openapi/swagger" + "github.com/stretchr/testify/require" +) + +func TestOperation_Validate_Success(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + yml string + }{ + { + name: "minimal_valid_operation", + yml: `responses: + 200: + description: Success`, + }, + { + name: "complete_operation", + yml: `summary: Get users +description: Retrieve a list of users +operationId: getUsers +tags: + - users +consumes: + - application/json +produces: + - application/json +parameters: + - name: limit + in: query + type: integer +responses: + 200: + description: Success + 404: + description: Not found`, + }, + { + name: "operation_with_schemes", + yml: `schemes: + - https +responses: + 200: + description: Success`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var operation swagger.Operation + + validationErrs, err := marshaller.Unmarshal(t.Context(), bytes.NewBufferString(tt.yml), &operation) + require.NoError(t, err) + require.Empty(t, validationErrs) + + errs := operation.Validate(t.Context()) + require.Empty(t, errs, "Expected no validation errors") + }) + } +} + +func TestOperation_Validate_Error(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + yml string + wantErrs []string + }{ + { + name: "missing_responses", + yml: `summary: Test operation`, + wantErrs: []string{"operation.responses is required"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var operation swagger.Operation + + var allErrors []error + validationErrs, err := marshaller.Unmarshal(t.Context(), bytes.NewBufferString(tt.yml), &operation) + require.NoError(t, err) + allErrors = append(allErrors, validationErrs...) + + validateErrs := operation.Validate(t.Context()) + allErrors = append(allErrors, validateErrs...) + + require.NotEmpty(t, allErrors, "Expected validation errors") + + for _, wantErr := range tt.wantErrs { + found := false + for _, gotErr := range allErrors { + if gotErr != nil && strings.Contains(gotErr.Error(), wantErr) { + found = true + break + } + } + require.True(t, found, "Expected error containing '%s' not found in: %v", wantErr, allErrors) + } + }) + } +} diff --git a/swagger/parameter.go b/swagger/parameter.go new file mode 100644 index 0000000..0dc4600 --- /dev/null +++ b/swagger/parameter.go @@ -0,0 +1,400 @@ +package swagger + +import ( + "context" + "strings" + + "github.com/speakeasy-api/openapi/extensions" + "github.com/speakeasy-api/openapi/internal/interfaces" + "github.com/speakeasy-api/openapi/jsonschema/oas3" + "github.com/speakeasy-api/openapi/marshaller" + "github.com/speakeasy-api/openapi/swagger/core" + "github.com/speakeasy-api/openapi/validation" + "github.com/speakeasy-api/openapi/values" +) + +// ParameterIn represents the location of a parameter. +type ParameterIn string + +const ( + // ParameterInQuery represents a query parameter. + ParameterInQuery ParameterIn = "query" + // ParameterInHeader represents a header parameter. + ParameterInHeader ParameterIn = "header" + // ParameterInPath represents a path parameter. 
+ ParameterInPath ParameterIn = "path" + // ParameterInFormData represents a form data parameter. + ParameterInFormData ParameterIn = "formData" + // ParameterInBody represents a body parameter. + ParameterInBody ParameterIn = "body" +) + +// CollectionFormat represents how array parameters are serialized. +type CollectionFormat string + +const ( + // CollectionFormatCSV represents comma-separated values. + CollectionFormatCSV CollectionFormat = "csv" + // CollectionFormatSSV represents space-separated values. + CollectionFormatSSV CollectionFormat = "ssv" + // CollectionFormatTSV represents tab-separated values. + CollectionFormatTSV CollectionFormat = "tsv" + // CollectionFormatPipes represents pipe-separated values. + CollectionFormatPipes CollectionFormat = "pipes" + // CollectionFormatMulti represents multiple parameter instances. + CollectionFormatMulti CollectionFormat = "multi" +) + +// Parameter describes a single operation parameter. +type Parameter struct { + marshaller.Model[core.Parameter] + + // Name is the name of the parameter. + Name string + // In is the location of the parameter. + In ParameterIn + // Description is a brief description of the parameter. + Description *string + // Required determines whether this parameter is mandatory. + Required *bool + + // For body parameters + // Schema is the schema defining the type used for the body parameter. + Schema *oas3.JSONSchema[oas3.Referenceable] + + // For non-body parameters + // Type is the type of the parameter. + Type *string + // Format is the extending format for the type. + Format *string + // AllowEmptyValue sets the ability to pass empty-valued parameters (query or formData only). + AllowEmptyValue *bool + // Items describes the type of items in the array (if type is array). + Items *Items + // CollectionFormat determines the format of the array. + CollectionFormat *CollectionFormat + // Default declares the value the server will use if none is provided. + Default values.Value + // Maximum specifies the maximum value. + Maximum *float64 + // ExclusiveMaximum specifies if maximum is exclusive. + ExclusiveMaximum *bool + // Minimum specifies the minimum value. + Minimum *float64 + // ExclusiveMinimum specifies if minimum is exclusive. + ExclusiveMinimum *bool + // MaxLength specifies the maximum length. + MaxLength *int64 + // MinLength specifies the minimum length. + MinLength *int64 + // Pattern specifies a regex pattern the string must match. + Pattern *string + // MaxItems specifies the maximum number of items in an array. + MaxItems *int64 + // MinItems specifies the minimum number of items in an array. + MinItems *int64 + // UniqueItems specifies if all items must be unique. + UniqueItems *bool + // Enum specifies a list of allowed values. + Enum []values.Value + // MultipleOf specifies the value must be a multiple of this number. + MultipleOf *float64 + + // Extensions provides a list of extensions to the Parameter object. + Extensions *extensions.Extensions +} + +var _ interfaces.Model[core.Parameter] = (*Parameter)(nil) + +// GetName returns the value of the Name field. Returns empty string if not set. +func (p *Parameter) GetName() string { + if p == nil { + return "" + } + return p.Name +} + +// GetIn returns the value of the In field. +func (p *Parameter) GetIn() ParameterIn { + if p == nil { + return "" + } + return p.In +} + +// GetDescription returns the value of the Description field. Returns empty string if not set. 
+func (p *Parameter) GetDescription() string { + if p == nil || p.Description == nil { + return "" + } + return *p.Description +} + +// GetRequired returns the value of the Required field. False by default if not set. +func (p *Parameter) GetRequired() bool { + if p == nil || p.Required == nil { + return false + } + return *p.Required +} + +// GetSchema returns the value of the Schema field. Returns nil if not set. +func (p *Parameter) GetSchema() *oas3.JSONSchema[oas3.Referenceable] { + if p == nil { + return nil + } + return p.Schema +} + +// GetType returns the value of the Type field. Returns empty string if not set. +func (p *Parameter) GetType() string { + if p == nil || p.Type == nil { + return "" + } + return *p.Type +} + +// GetExtensions returns the value of the Extensions field. Returns an empty extensions map if not set. +func (p *Parameter) GetExtensions() *extensions.Extensions { + if p == nil || p.Extensions == nil { + return extensions.New() + } + return p.Extensions +} + +// Validate validates the Parameter object against the Swagger Specification. +func (p *Parameter) Validate(ctx context.Context, opts ...validation.Option) []error { + c := p.GetCore() + errs := []error{} + + if c.Name.Present && p.Name == "" { + errs = append(errs, validation.NewValueError(validation.NewMissingValueError("parameter.name is required"), c, c.Name)) + } + + if c.In.Present && p.In == "" { + errs = append(errs, validation.NewValueError(validation.NewMissingValueError("parameter.in is required"), c, c.In)) + } else if c.In.Present { + errs = append(errs, p.validateIn(c)...) + errs = append(errs, p.validateParameterType(ctx, c, opts...)...) + } + + // allowEmptyValue only valid for query or formData + if c.AllowEmptyValue.Present && p.In != ParameterInQuery && p.In != ParameterInFormData { + errs = append(errs, validation.NewValueError(validation.NewValueValidationError("parameter.allowEmptyValue is only valid for in=query or in=formData"), c, c.AllowEmptyValue)) + } + + // Validate items if present + if c.Items.Present && p.Items != nil { + errs = append(errs, p.Items.Validate(ctx, opts...)...) + } + + // Validate file type parameter consumes from operation context + if p.Type != nil && *p.Type == "file" { + validationOpts := validation.NewOptions(opts...) 
+ if operation := validation.GetContextObject[Operation](validationOpts); operation != nil { + opCore := operation.GetCore() + if !opCore.Consumes.Present || len(operation.Consumes) == 0 { + errs = append(errs, validation.NewValueError( + validation.NewValueValidationError("parameter with type=file requires operation to have consumes defined"), + c, c.Type)) + } else { + hasValidConsumes := false + for _, mimeType := range operation.Consumes { + if mimeType == "multipart/form-data" || mimeType == "application/x-www-form-urlencoded" { + hasValidConsumes = true + break + } + } + if !hasValidConsumes { + errs = append(errs, validation.NewValueError( + validation.NewValueValidationError("parameter with type=file requires operation consumes to be 'multipart/form-data' or 'application/x-www-form-urlencoded'"), + c, c.Type)) + } + } + } + } + + p.Valid = len(errs) == 0 && c.GetValid() + + return errs +} + +func (p *Parameter) validateIn(c *core.Parameter) []error { + errs := []error{} + + validIns := []ParameterIn{ParameterInQuery, ParameterInHeader, ParameterInPath, ParameterInFormData, ParameterInBody} + valid := false + for _, in := range validIns { + if p.In == in { + valid = true + break + } + } + if !valid { + errs = append(errs, validation.NewValueError(validation.NewValueValidationError("parameter.in must be one of [%s]", strings.Join([]string{string(ParameterInQuery), string(ParameterInHeader), string(ParameterInPath), string(ParameterInFormData), string(ParameterInBody)}, ", ")), c, c.In)) + } + + return errs +} + +func (p *Parameter) validateParameterType(ctx context.Context, c *core.Parameter, opts ...validation.Option) []error { + errs := []error{} + + // Path parameters must be required + if p.In == ParameterInPath && (!c.Required.Present || !p.GetRequired()) { + errs = append(errs, validation.NewValueError(validation.NewValueValidationError("parameter.in=path requires required=true"), c, c.Required)) + } + + // Body parameters require schema + if p.In == ParameterInBody { + if !c.Schema.Present { + errs = append(errs, validation.NewValueError(validation.NewMissingValueError("parameter.schema is required for in=body"), c, c.Schema)) + return errs + } + errs = append(errs, p.Schema.Validate(ctx, opts...)...) 
+ return errs + } + + // Non-body parameters require type + if !c.Type.Present { + errs = append(errs, validation.NewValueError(validation.NewMissingValueError("parameter.type is required for non-body parameters"), c, c.Type)) + return errs + } + + if c.Type.Present && (p.Type == nil || *p.Type == "") { + errs = append(errs, validation.NewValueError(validation.NewMissingValueError("parameter.type is required for non-body parameters"), c, c.Type)) + return errs + } + + if p.Type != nil { + validTypes := []string{"string", "number", "integer", "boolean", "array", "file"} + valid := false + for _, t := range validTypes { + if *p.Type == t { + valid = true + break + } + } + if !valid { + errs = append(errs, validation.NewValueError(validation.NewValueValidationError("parameter.type must be one of [%s]", strings.Join(validTypes, ", ")), c, c.Type)) + } + + // File type only allowed for formData + if *p.Type == "file" && p.In != ParameterInFormData { + errs = append(errs, validation.NewValueError(validation.NewValueValidationError("parameter.type=file requires in=formData"), c, c.Type)) + } + + // Array type requires items + if *p.Type == "array" && !c.Items.Present { + errs = append(errs, validation.NewValueError(validation.NewMissingValueError("parameter.items is required when type=array"), c, c.Items)) + } + + // Validate collectionFormat=multi only for query or formData + if p.CollectionFormat != nil && *p.CollectionFormat == CollectionFormatMulti { + if p.In != ParameterInQuery && p.In != ParameterInFormData { + errs = append(errs, validation.NewValueError( + validation.NewValueValidationError("collectionFormat='multi' is only valid for in=query or in=formData"), + c, c.CollectionFormat)) + } + } + } + + return errs +} + +// Items is a limited subset of JSON-Schema's items object for array parameters. +type Items struct { + marshaller.Model[core.Items] + + // Type is the internal type of the array. + Type string + // Format is the extending format for the type. + Format *string + // Items describes the type of items in nested arrays. + Items *Items + // CollectionFormat determines the format of the array. + CollectionFormat *CollectionFormat + // Default declares the value the server will use if none is provided. + Default values.Value + // Maximum specifies the maximum value. + Maximum *float64 + // ExclusiveMaximum specifies if maximum is exclusive. + ExclusiveMaximum *bool + // Minimum specifies the minimum value. + Minimum *float64 + // ExclusiveMinimum specifies if minimum is exclusive. + ExclusiveMinimum *bool + // MaxLength specifies the maximum length. + MaxLength *int64 + // MinLength specifies the minimum length. + MinLength *int64 + // Pattern specifies a regex pattern the string must match. + Pattern *string + // MaxItems specifies the maximum number of items in an array. + MaxItems *int64 + // MinItems specifies the minimum number of items in an array. + MinItems *int64 + // UniqueItems specifies if all items must be unique. + UniqueItems *bool + // Enum specifies a list of allowed values. + Enum []values.Value + // MultipleOf specifies the value must be a multiple of this number. + MultipleOf *float64 + + // Extensions provides a list of extensions to the Items object. + Extensions *extensions.Extensions +} + +var _ interfaces.Model[core.Items] = (*Items)(nil) + +// GetType returns the value of the Type field. Returns empty string if not set. 
+func (i *Items) GetType() string { + if i == nil { + return "" + } + return i.Type +} + +// GetExtensions returns the value of the Extensions field. Returns an empty extensions map if not set. +func (i *Items) GetExtensions() *extensions.Extensions { + if i == nil || i.Extensions == nil { + return extensions.New() + } + return i.Extensions +} + +// Validate validates the Items object against the Swagger Specification. +func (i *Items) Validate(ctx context.Context, opts ...validation.Option) []error { + c := i.GetCore() + errs := []error{} + + if c.Type.Present && i.Type == "" { + errs = append(errs, validation.NewValueError(validation.NewMissingValueError("items.type is required"), c, c.Type)) + } else if c.Type.Present { + validTypes := []string{"string", "number", "integer", "boolean", "array"} + valid := false + for _, t := range validTypes { + if i.Type == t { + valid = true + break + } + } + if !valid { + errs = append(errs, validation.NewValueError(validation.NewValueValidationError("items.type must be one of [%s]", strings.Join(validTypes, ", ")), c, c.Type)) + } + + // Array type requires items + if i.Type == "array" && !c.Items.Present { + errs = append(errs, validation.NewValueError(validation.NewMissingValueError("items.items is required when type=array"), c, c.Items)) + } + } + + // Validate nested items if present + if c.Items.Present && i.Items != nil { + errs = append(errs, i.Items.Validate(ctx, opts...)...) + } + + i.Valid = len(errs) == 0 && c.GetValid() + + return errs +} diff --git a/swagger/parameter_test.go b/swagger/parameter_test.go new file mode 100644 index 0000000..e18f980 --- /dev/null +++ b/swagger/parameter_test.go @@ -0,0 +1,428 @@ +package swagger_test + +import ( + "bytes" + "strings" + "testing" + + "github.com/speakeasy-api/openapi/marshaller" + "github.com/speakeasy-api/openapi/swagger" + "github.com/stretchr/testify/require" +) + +func TestParameter_Unmarshal_Success(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + yaml string + in swagger.ParameterIn + typ string + }{ + { + name: "path parameter", + yaml: ` +name: id +in: path +description: The ID +required: true +type: string +`, + in: swagger.ParameterInPath, + typ: "string", + }, + { + name: "query parameter with array", + yaml: ` +name: ids +in: query +type: array +items: + type: string +collectionFormat: csv +`, + in: swagger.ParameterInQuery, + typ: "array", + }, + { + name: "body parameter", + yaml: ` +name: user +in: body +description: User object +required: true +schema: + type: object + properties: + name: + type: string +`, + in: swagger.ParameterInBody, + }, + { + name: "header parameter", + yaml: ` +name: X-API-Key +in: header +type: string +required: true +`, + in: swagger.ParameterInHeader, + typ: "string", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var param swagger.Parameter + validationErrs, err := marshaller.Unmarshal(t.Context(), bytes.NewBufferString(tt.yaml), ¶m) + require.NoError(t, err, "unmarshal should succeed") + require.Empty(t, validationErrs, "should have no unmarshalling validation errors") + + require.Equal(t, tt.in, param.In, "should have correct location") + if tt.typ != "" { + require.NotNil(t, param.Type, "type should be set") + require.Equal(t, tt.typ, *param.Type, "should have correct type") + } + }) + } +} + +func TestParameter_Validate_Success(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + yml string + }{ + { + name: "valid path parameter", + yml: ` +name: id +in: 
path +required: true +type: string +`, + }, + { + name: "valid query parameter", + yml: ` +name: limit +in: query +type: integer +format: int32 +`, + }, + { + name: "valid body parameter", + yml: ` +name: body +in: body +schema: + type: object +`, + }, + { + name: "valid formData parameter", + yml: ` +name: file +in: formData +type: file +`, + }, + { + name: "valid header parameter", + yml: ` +name: X-API-Key +in: header +type: string +`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var param swagger.Parameter + validationErrs, err := marshaller.Unmarshal(t.Context(), bytes.NewBufferString(tt.yml), ¶m) + require.NoError(t, err) + require.Empty(t, validationErrs) + + errs := param.Validate(t.Context()) + require.Empty(t, errs, "expected no validation errors") + require.True(t, param.Valid, "expected parameter to be valid") + }) + } +} + +func TestParameter_Validate_Error(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + yml string + wantErrs []string + }{ + { + name: "missing name", + yml: ` +in: query +type: string +`, + wantErrs: []string{"parameter.name is missing"}, + }, + { + name: "empty name", + yml: ` +name: "" +in: query +type: string +`, + wantErrs: []string{"parameter.name is required"}, + }, + { + name: "missing in", + yml: ` +name: test +type: string +`, + wantErrs: []string{"parameter.in is missing"}, + }, + { + name: "path parameter not required", + yml: ` +name: userId +in: path +required: false +type: string +`, + wantErrs: []string{"parameter.in=path requires required=true"}, + }, + { + name: "path parameter missing required", + yml: ` +name: userId +in: path +type: string +`, + wantErrs: []string{"parameter.in=path requires required=true"}, + }, + { + name: "invalid parameter location", + yml: ` +name: test +in: invalid +type: string +`, + wantErrs: []string{"parameter.in must be one of"}, + }, + { + name: "body parameter without schema", + yml: ` +name: body +in: body +`, + wantErrs: []string{"parameter.schema is required for in=body"}, + }, + { + name: "non-body parameter without type", + yml: ` +name: id +in: query +`, + wantErrs: []string{"parameter.type is required for non-body parameters"}, + }, + { + name: "array parameter without items", + yml: ` +name: ids +in: query +type: array +`, + wantErrs: []string{"parameter.items is required when type=array"}, + }, + { + name: "file type not in formData", + yml: ` +name: file +in: query +type: file +`, + wantErrs: []string{"parameter.type=file requires in=formData"}, + }, + { + name: "invalid parameter type", + yml: ` +name: test +in: query +type: object +`, + wantErrs: []string{"parameter.type must be one of"}, + }, + { + name: "multiple validation errors", + yml: ` +name: "" +in: path +required: false +`, + wantErrs: []string{ + "parameter.name is required", + "parameter.in=path requires required=true", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var param swagger.Parameter + validationErrs, err := marshaller.Unmarshal(t.Context(), bytes.NewBufferString(tt.yml), ¶m) + require.NoError(t, err) + + // Collect all errors from both unmarshalling and validation + var allErrors []error + allErrors = append(allErrors, validationErrs...) + + validateErrs := param.Validate(t.Context()) + allErrors = append(allErrors, validateErrs...) 
+ + require.NotEmpty(t, allErrors, "expected validation errors") + + // Check that all expected error messages are present + var errMessages []string + for _, err := range allErrors { + errMessages = append(errMessages, err.Error()) + } + + for _, expectedErr := range tt.wantErrs { + found := false + for _, errMsg := range errMessages { + if strings.Contains(errMsg, expectedErr) { + found = true + break + } + } + require.True(t, found, "expected error message '%s' not found in: %v", expectedErr, errMessages) + } + }) + } +} + +func TestItems_Validate_Success(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + yml string + }{ + { + name: "simple string items", + yml: ` +type: string +`, + }, + { + name: "nested array items", + yml: ` +type: array +items: + type: integer +`, + }, + { + name: "items with format", + yml: ` +type: integer +format: int64 +`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var items swagger.Items + validationErrs, err := marshaller.Unmarshal(t.Context(), bytes.NewBufferString(tt.yml), &items) + require.NoError(t, err) + require.Empty(t, validationErrs) + + errs := items.Validate(t.Context()) + require.Empty(t, errs, "expected no validation errors") + require.True(t, items.Valid, "expected items to be valid") + }) + } +} + +func TestItems_Validate_Error(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + yml string + wantErrs []string + }{ + { + name: "missing type", + yml: ` +format: int32 +`, + wantErrs: []string{"items.type is missing"}, + }, + { + name: "array items without nested items", + yml: ` +type: array +`, + wantErrs: []string{"items.items is required when type=array"}, + }, + { + name: "invalid items type", + yml: ` +type: object +`, + wantErrs: []string{"items.type must be one of"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var items swagger.Items + validationErrs, err := marshaller.Unmarshal(t.Context(), bytes.NewBufferString(tt.yml), &items) + require.NoError(t, err) + + // Collect all errors from both unmarshalling and validation + var allErrors []error + allErrors = append(allErrors, validationErrs...) + + validateErrs := items.Validate(t.Context()) + allErrors = append(allErrors, validateErrs...) + + require.NotEmpty(t, allErrors, "expected validation errors") + + // Check that all expected error messages are present + var errMessages []string + for _, err := range allErrors { + errMessages = append(errMessages, err.Error()) + } + + for _, expectedErr := range tt.wantErrs { + found := false + for _, errMsg := range errMessages { + if strings.Contains(errMsg, expectedErr) { + found = true + break + } + } + require.True(t, found, "expected error message '%s' not found in: %v", expectedErr, errMessages) + } + }) + } +} diff --git a/swagger/paths.go b/swagger/paths.go new file mode 100644 index 0000000..47e1016 --- /dev/null +++ b/swagger/paths.go @@ -0,0 +1,194 @@ +package swagger + +import ( + "context" + "strings" + + "github.com/speakeasy-api/openapi/extensions" + "github.com/speakeasy-api/openapi/internal/interfaces" + "github.com/speakeasy-api/openapi/marshaller" + "github.com/speakeasy-api/openapi/sequencedmap" + "github.com/speakeasy-api/openapi/swagger/core" + "github.com/speakeasy-api/openapi/validation" +) + +// Paths holds the relative paths to the individual endpoints. 
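+// Keys are relative paths that must begin with a slash and map to PathItem
+// objects. An illustrative Swagger 2.0 fragment (not taken from this change):
+//
+//	paths:
+//	  /pets:
+//	    get:
+//	      responses:
+//	        "200":
+//	          description: OK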
+type Paths struct { + marshaller.Model[core.Paths] + *sequencedmap.Map[string, *PathItem] + + // Extensions provides a list of extensions to the Paths object. + Extensions *extensions.Extensions +} + +var _ interfaces.Model[core.Paths] = (*Paths)(nil) + +// NewPaths creates a new Paths object with an initialized map. +func NewPaths() *Paths { + return &Paths{ + Map: sequencedmap.New[string, *PathItem](), + } +} + +// GetExtensions returns the value of the Extensions field. Returns an empty extensions map if not set. +func (p *Paths) GetExtensions() *extensions.Extensions { + if p == nil || p.Extensions == nil { + return extensions.New() + } + return p.Extensions +} + +// Validate validates the Paths object according to the Swagger Specification. +func (p *Paths) Validate(ctx context.Context, opts ...validation.Option) []error { + c := p.GetCore() + errs := []error{} + + // Validate that path keys start with a slash + for path, pathItem := range p.All() { + if !strings.HasPrefix(path, "/") { + pathKeyNode := c.GetMapKeyNodeOrRoot(path, c.RootNode) + errs = append(errs, validation.NewValidationError( + validation.NewValueValidationError("path '%s' must begin with a slash '/'", path), + pathKeyNode)) + } + errs = append(errs, pathItem.Validate(ctx, opts...)...) + } + + p.Valid = len(errs) == 0 + + return errs +} + +// HTTPMethod is an enum representing the HTTP methods available in the Swagger specification. +type HTTPMethod string + +const ( + // HTTPMethodGet represents the HTTP GET method. + HTTPMethodGet HTTPMethod = "get" + // HTTPMethodPut represents the HTTP PUT method. + HTTPMethodPut HTTPMethod = "put" + // HTTPMethodPost represents the HTTP POST method. + HTTPMethodPost HTTPMethod = "post" + // HTTPMethodDelete represents the HTTP DELETE method. + HTTPMethodDelete HTTPMethod = "delete" + // HTTPMethodOptions represents the HTTP OPTIONS method. + HTTPMethodOptions HTTPMethod = "options" + // HTTPMethodHead represents the HTTP HEAD method. + HTTPMethodHead HTTPMethod = "head" + // HTTPMethodPatch represents the HTTP PATCH method. + HTTPMethodPatch HTTPMethod = "patch" +) + +// PathItem describes the operations available on a single path. +type PathItem struct { + marshaller.Model[core.PathItem] + *sequencedmap.Map[HTTPMethod, *Operation] + + // Ref allows for an external definition of this path item. + Ref *string + // Parameters is a list of parameters that are applicable for all operations in this path. + Parameters []*ReferencedParameter + // Extensions provides a list of extensions to the PathItem object. + Extensions *extensions.Extensions +} + +var _ interfaces.Model[core.PathItem] = (*PathItem)(nil) + +// NewPathItem creates a new PathItem object with an initialized map. +func NewPathItem() *PathItem { + return &PathItem{ + Map: sequencedmap.New[HTTPMethod, *Operation](), + } +} + +// GetRef returns the value of the Ref field. Returns empty string if not set. +func (p *PathItem) GetRef() string { + if p == nil || p.Ref == nil { + return "" + } + return *p.Ref +} + +// GetParameters returns the value of the Parameters field. Returns nil if not set. +func (p *PathItem) GetParameters() []*ReferencedParameter { + if p == nil { + return nil + } + return p.Parameters +} + +// GetExtensions returns the value of the Extensions field. Returns an empty extensions map if not set. 
+func (p *PathItem) GetExtensions() *extensions.Extensions { + if p == nil || p.Extensions == nil { + return extensions.New() + } + return p.Extensions +} + +// GetOperation returns the operation for the specified HTTP method. +func (p *PathItem) GetOperation(method HTTPMethod) *Operation { + if p == nil || !p.IsInitialized() { + return nil + } + + op, ok := p.Map.Get(method) + if !ok { + return nil + } + + return op +} + +// Get returns the GET operation for this path item. +func (p *PathItem) Get() *Operation { + return p.GetOperation(HTTPMethodGet) +} + +// Put returns the PUT operation for this path item. +func (p *PathItem) Put() *Operation { + return p.GetOperation(HTTPMethodPut) +} + +// Post returns the POST operation for this path item. +func (p *PathItem) Post() *Operation { + return p.GetOperation(HTTPMethodPost) +} + +// Delete returns the DELETE operation for this path item. +func (p *PathItem) Delete() *Operation { + return p.GetOperation(HTTPMethodDelete) +} + +// Options returns the OPTIONS operation for this path item. +func (p *PathItem) Options() *Operation { + return p.GetOperation(HTTPMethodOptions) +} + +// Head returns the HEAD operation for this path item. +func (p *PathItem) Head() *Operation { + return p.GetOperation(HTTPMethodHead) +} + +// Patch returns the PATCH operation for this path item. +func (p *PathItem) Patch() *Operation { + return p.GetOperation(HTTPMethodPatch) +} + +// Validate validates the PathItem object according to the Swagger Specification. +func (p *PathItem) Validate(ctx context.Context, opts ...validation.Option) []error { + c := p.GetCore() + errs := []error{} + + // TODO allow validation of parameter uniqueness and body parameter count, this isn't done at the moment as we would need to resolve references + for _, parameter := range p.Parameters { + errs = append(errs, parameter.Validate(ctx, opts...)...) + } + + for _, op := range p.All() { + errs = append(errs, op.Validate(ctx, opts...)...) + } + + p.Valid = len(errs) == 0 && c.GetValid() + + return errs +} diff --git a/swagger/reference.go b/swagger/reference.go new file mode 100644 index 0000000..5a29232 --- /dev/null +++ b/swagger/reference.go @@ -0,0 +1,142 @@ +package swagger + +import ( + "context" + "errors" + "fmt" + + "github.com/speakeasy-api/openapi/internal/interfaces" + "github.com/speakeasy-api/openapi/marshaller" + "github.com/speakeasy-api/openapi/pointer" + "github.com/speakeasy-api/openapi/references" + "github.com/speakeasy-api/openapi/swagger/core" + "github.com/speakeasy-api/openapi/validation" +) + +type ( + // ReferencedParameter represents a parameter that can either be referenced from elsewhere or declared inline. + ReferencedParameter = Reference[Parameter, *Parameter, *core.Parameter] + // ReferencedResponse represents a response that can either be referenced from elsewhere or declared inline. + ReferencedResponse = Reference[Response, *Response, *core.Response] +) + +// NewReferencedParameterFromRef creates a new ReferencedParameter from a reference string. +func NewReferencedParameterFromRef(ref references.Reference) *ReferencedParameter { + return &ReferencedParameter{ + Reference: &ref, + } +} + +// NewReferencedParameterFromParameter creates a new ReferencedParameter from a Parameter. +func NewReferencedParameterFromParameter(parameter *Parameter) *ReferencedParameter { + return &ReferencedParameter{ + Object: parameter, + } +} + +// NewReferencedResponseFromRef creates a new ReferencedResponse from a reference string. 
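+// An illustrative call, assuming references.Reference is constructed from a
+// string as in Populate below (the ref target is hypothetical):
+//
+//	notFound := NewReferencedResponseFromRef(references.Reference("#/responses/NotFound"))
+//	_ = notFound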
+func NewReferencedResponseFromRef(ref references.Reference) *ReferencedResponse { + return &ReferencedResponse{ + Reference: &ref, + } +} + +// NewReferencedResponseFromResponse creates a new ReferencedResponse from a Response. +func NewReferencedResponseFromResponse(response *Response) *ReferencedResponse { + return &ReferencedResponse{ + Object: response, + } +} + +// Reference represents an object that can either be referenced from elsewhere or declared inline. +type Reference[T any, V interfaces.Validator[T], C marshaller.CoreModeler] struct { + marshaller.Model[core.Reference[C]] + + // Reference is the reference string ($ref). + Reference *references.Reference + + // If this was an inline object instead of a reference this will contain that object. + Object *T +} + +var _ interfaces.Model[core.Reference[*core.Parameter]] = (*Reference[Parameter, *Parameter, *core.Parameter])(nil) + +// IsReference returns true if the reference is a reference (via $ref) to an object as opposed to an inline object. +func (r *Reference[T, V, C]) IsReference() bool { + if r == nil { + return false + } + return r.Reference != nil +} + +// GetReference returns the value of the Reference field. Returns empty string if not set. +func (r *Reference[T, V, C]) GetReference() references.Reference { + if r == nil || r.Reference == nil { + return "" + } + return *r.Reference +} + +// GetObject returns the referenced object. If this is a reference, this will return nil. +func (r *Reference[T, V, C]) GetObject() *T { + if r == nil { + return nil + } + + if r.IsReference() { + return nil + } + + return r.Object +} + +// Validate validates the Reference object against the Swagger Specification. +func (r *Reference[T, V, C]) Validate(ctx context.Context, opts ...validation.Option) []error { + if r == nil { + return []error{errors.New("reference is nil")} + } + + c := r.GetCore() + if c == nil { + return []error{errors.New("reference core is nil")} + } + + errs := []error{} + + if c.Reference.Present && r.Object != nil { + // Use the validator interface V to validate the object + var validator V + if v, ok := any(r.Object).(V); ok { + validator = v + errs = append(errs, validator.Validate(ctx, opts...)...) + } + } + + r.Valid = len(errs) == 0 && c.GetValid() + + return errs +} + +func (r *Reference[T, V, C]) Populate(source any) error { + var s *core.Reference[C] + switch src := source.(type) { + case *core.Reference[C]: + s = src + case core.Reference[C]: + s = &src + default: + return fmt.Errorf("expected *core.Reference[C] or core.Reference[C], got %T", source) + } + + if s.Reference.Present { + r.Reference = pointer.From(references.Reference(*s.Reference.Value)) + } else { + if err := marshaller.Populate(s.Object, &r.Object); err != nil { + return err + } + } + + r.SetCore(s) + + return nil +} diff --git a/swagger/response.go b/swagger/response.go new file mode 100644 index 0000000..baea119 --- /dev/null +++ b/swagger/response.go @@ -0,0 +1,259 @@ +package swagger + +import ( + "context" + + "github.com/speakeasy-api/openapi/extensions" + "github.com/speakeasy-api/openapi/internal/interfaces" + "github.com/speakeasy-api/openapi/jsonschema/oas3" + "github.com/speakeasy-api/openapi/marshaller" + "github.com/speakeasy-api/openapi/sequencedmap" + "github.com/speakeasy-api/openapi/swagger/core" + "github.com/speakeasy-api/openapi/validation" + "github.com/speakeasy-api/openapi/values" +) + +// Responses is a container for the expected responses of an operation. 
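+// An illustrative Swagger 2.0 fragment (not taken from this change):
+//
+//	responses:
+//	  "200":
+//	    description: OK
+//	  default:
+//	    description: Unexpected error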
+type Responses struct { + marshaller.Model[core.Responses] + *sequencedmap.Map[string, *ReferencedResponse] + + // Default is the documentation of responses other than the ones declared for specific HTTP response codes. + Default *ReferencedResponse + // Extensions provides a list of extensions to the Responses object. + Extensions *extensions.Extensions +} + +var _ interfaces.Model[core.Responses] = (*Responses)(nil) + +// NewResponses creates a new Responses object with an initialized map. +func NewResponses() *Responses { + return &Responses{ + Map: sequencedmap.New[string, *ReferencedResponse](), + } +} + +// GetDefault returns the value of the Default field. Returns nil if not set. +func (r *Responses) GetDefault() *ReferencedResponse { + if r == nil { + return nil + } + return r.Default +} + +// GetExtensions returns the value of the Extensions field. Returns an empty extensions map if not set. +func (r *Responses) GetExtensions() *extensions.Extensions { + if r == nil || r.Extensions == nil { + return extensions.New() + } + return r.Extensions +} + +// Validate validates the Responses object against the Swagger Specification. +func (r *Responses) Validate(ctx context.Context, opts ...validation.Option) []error { + c := r.GetCore() + errs := []error{} + + // Responses object must contain at least one response code + hasResponse := (c.Default.Present && r.Default != nil) || (r.Map != nil && r.Len() > 0) + if !hasResponse { + errs = append(errs, validation.NewValueError( + validation.NewMissingValueError("responses must contain at least one response code or default"), + c, c.Default)) + } + + if c.Default.Present && r.Default != nil { + errs = append(errs, r.Default.Validate(ctx, opts...)...) + } + + for _, response := range r.All() { + errs = append(errs, response.Validate(ctx, opts...)...) + } + + r.Valid = len(errs) == 0 && c.GetValid() + + return errs +} + +// Response describes a single response from an API operation. +type Response struct { + marshaller.Model[core.Response] + + // Description is a short description of the response. + Description string + // Schema is a definition of the response structure. + Schema *oas3.JSONSchema[oas3.Referenceable] + // Headers is a list of headers that are sent with the response. + Headers *sequencedmap.Map[string, *Header] + // Examples is an example of the response message. + Examples *sequencedmap.Map[string, values.Value] + // Extensions provides a list of extensions to the Response object. + Extensions *extensions.Extensions +} + +var _ interfaces.Model[core.Response] = (*Response)(nil) + +// GetDescription returns the value of the Description field. Returns empty string if not set. +func (r *Response) GetDescription() string { + if r == nil { + return "" + } + return r.Description +} + +// GetSchema returns the value of the Schema field. Returns nil if not set. +func (r *Response) GetSchema() *oas3.JSONSchema[oas3.Referenceable] { + if r == nil { + return nil + } + return r.Schema +} + +// GetHeaders returns the value of the Headers field. Returns nil if not set. +func (r *Response) GetHeaders() *sequencedmap.Map[string, *Header] { + if r == nil { + return nil + } + return r.Headers +} + +// GetExamples returns the value of the Examples field. Returns nil if not set. +func (r *Response) GetExamples() *sequencedmap.Map[string, values.Value] { + if r == nil { + return nil + } + return r.Examples +} + +// GetExtensions returns the value of the Extensions field. Returns an empty extensions map if not set. 
+func (r *Response) GetExtensions() *extensions.Extensions { + if r == nil || r.Extensions == nil { + return extensions.New() + } + return r.Extensions +} + +// Validate validates the Response object against the Swagger Specification. +func (r *Response) Validate(ctx context.Context, opts ...validation.Option) []error { + c := r.GetCore() + errs := []error{} + + if c.Description.Present && r.Description == "" { + errs = append(errs, validation.NewValueError(validation.NewMissingValueError("response.description is required"), c, c.Description)) + } + + for _, header := range r.Headers.All() { + errs = append(errs, header.Validate(ctx, opts...)...) + } + + r.Valid = len(errs) == 0 && c.GetValid() + + return errs +} + +// Header describes a single header in a response. +type Header struct { + marshaller.Model[core.Header] + + // Description is a short description of the header. + Description *string + // Type is the type of the object. + Type string + // Format is the extending format for the type. + Format *string + // Items describes the type of items in the array (if type is array). + Items *Items + // CollectionFormat determines the format of the array. + CollectionFormat *CollectionFormat + // Default declares the value the server will use if none is provided. + Default values.Value + // Maximum specifies the maximum value. + Maximum *float64 + // ExclusiveMaximum specifies if maximum is exclusive. + ExclusiveMaximum *bool + // Minimum specifies the minimum value. + Minimum *float64 + // ExclusiveMinimum specifies if minimum is exclusive. + ExclusiveMinimum *bool + // MaxLength specifies the maximum length. + MaxLength *int64 + // MinLength specifies the minimum length. + MinLength *int64 + // Pattern specifies a regex pattern the string must match. + Pattern *string + // MaxItems specifies the maximum number of items in an array. + MaxItems *int64 + // MinItems specifies the minimum number of items in an array. + MinItems *int64 + // UniqueItems specifies if all items must be unique. + UniqueItems *bool + // Enum specifies a list of allowed values. + Enum []values.Value + // MultipleOf specifies the value must be a multiple of this number. + MultipleOf *float64 + + // Extensions provides a list of extensions to the Header object. + Extensions *extensions.Extensions +} + +var _ interfaces.Model[core.Header] = (*Header)(nil) + +// GetDescription returns the value of the Description field. Returns empty string if not set. +func (h *Header) GetDescription() string { + if h == nil || h.Description == nil { + return "" + } + return *h.Description +} + +// GetType returns the value of the Type field. Returns empty string if not set. +func (h *Header) GetType() string { + if h == nil { + return "" + } + return h.Type +} + +// GetExtensions returns the value of the Extensions field. Returns an empty extensions map if not set. +func (h *Header) GetExtensions() *extensions.Extensions { + if h == nil || h.Extensions == nil { + return extensions.New() + } + return h.Extensions +} + +// Validate validates the Header object against the Swagger Specification. 
+func (h *Header) Validate(ctx context.Context, opts ...validation.Option) []error { + c := h.GetCore() + errs := []error{} + + if c.Type.Present && h.Type == "" { + errs = append(errs, validation.NewValueError(validation.NewMissingValueError("header.type is required"), c, c.Type)) + } else if c.Type.Present { + validTypes := []string{"string", "number", "integer", "boolean", "array"} + valid := false + for _, t := range validTypes { + if h.Type == t { + valid = true + break + } + } + if !valid { + errs = append(errs, validation.NewValueError(validation.NewValueValidationError("header.type must be one of [string, number, integer, boolean, array]"), c, c.Type)) + } + + // Array type requires items + if h.Type == "array" && !c.Items.Present { + errs = append(errs, validation.NewValueError(validation.NewMissingValueError("header.items is required when type=array"), c, c.Items)) + } + } + + // Validate items if present + if c.Items.Present && h.Items != nil { + errs = append(errs, h.Items.Validate(ctx, opts...)...) + } + + h.Valid = len(errs) == 0 && c.GetValid() + + return errs +} diff --git a/swagger/response_validate_test.go b/swagger/response_validate_test.go new file mode 100644 index 0000000..5832d92 --- /dev/null +++ b/swagger/response_validate_test.go @@ -0,0 +1,211 @@ +package swagger_test + +import ( + "bytes" + "strings" + "testing" + + "github.com/speakeasy-api/openapi/marshaller" + "github.com/speakeasy-api/openapi/swagger" + "github.com/stretchr/testify/require" +) + +func TestResponse_Validate_Success(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + yml string + }{ + { + name: "minimal_response", + yml: `description: Success`, + }, + { + name: "response_with_schema", + yml: `description: User response +schema: + type: object + properties: + id: + type: integer + name: + type: string`, + }, + { + name: "response_with_headers", + yml: `description: Success +headers: + X-Rate-Limit: + type: integer + description: Rate limit`, + }, + { + name: "response_with_examples", + yml: `description: Success +schema: + type: string +examples: + application/json: "example value"`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var response swagger.Response + + validationErrs, err := marshaller.Unmarshal(t.Context(), bytes.NewBufferString(tt.yml), &response) + require.NoError(t, err) + require.Empty(t, validationErrs) + + errs := response.Validate(t.Context()) + require.Empty(t, errs, "Expected no validation errors") + }) + } +} + +func TestResponse_Validate_Error(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + yml string + wantErrs []string + }{ + { + name: "missing_description", + yml: `schema: {type: object}`, + wantErrs: []string{"response.description is missing"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var response swagger.Response + + var allErrors []error + validationErrs, err := marshaller.Unmarshal(t.Context(), bytes.NewBufferString(tt.yml), &response) + require.NoError(t, err) + allErrors = append(allErrors, validationErrs...) + + validateErrs := response.Validate(t.Context()) + allErrors = append(allErrors, validateErrs...) 
+ + require.NotEmpty(t, allErrors, "Expected validation errors") + + for _, wantErr := range tt.wantErrs { + found := false + for _, gotErr := range allErrors { + if gotErr != nil && strings.Contains(gotErr.Error(), wantErr) { + found = true + break + } + } + require.True(t, found, "Expected error containing '%s' not found in: %v", wantErr, allErrors) + } + }) + } +} + +func TestHeader_Validate_Success(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + yml string + }{ + { + name: "valid_header_integer", + yml: `type: integer +description: Rate limit`, + }, + { + name: "valid_header_string", + yml: `type: string +description: Request ID`, + }, + { + name: "valid_header_array", + yml: `type: array +items: + type: string +description: Multiple values`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var header swagger.Header + + validationErrs, err := marshaller.Unmarshal(t.Context(), bytes.NewBufferString(tt.yml), &header) + require.NoError(t, err) + require.Empty(t, validationErrs) + + errs := header.Validate(t.Context()) + require.Empty(t, errs, "Expected no validation errors") + }) + } +} + +func TestHeader_Validate_Error(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + yml string + wantErrs []string + }{ + { + name: "missing_type", + yml: `description: Some header`, + wantErrs: []string{"header.type is missing"}, + }, + { + name: "invalid_type", + yml: `type: object +description: Invalid type`, + wantErrs: []string{"header.type must be one of"}, + }, + { + name: "array_without_items", + yml: `type: array +description: Array header`, + wantErrs: []string{"header.items is required when type=array"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var header swagger.Header + + var allErrors []error + validationErrs, err := marshaller.Unmarshal(t.Context(), bytes.NewBufferString(tt.yml), &header) + require.NoError(t, err) + allErrors = append(allErrors, validationErrs...) + + validateErrs := header.Validate(t.Context()) + allErrors = append(allErrors, validateErrs...) 
+ + require.NotEmpty(t, allErrors, "Expected validation errors") + + for _, wantErr := range tt.wantErrs { + found := false + for _, gotErr := range allErrors { + if gotErr != nil && strings.Contains(gotErr.Error(), wantErr) { + found = true + break + } + } + require.True(t, found, "Expected error containing '%s' not found in: %v", wantErr, allErrors) + } + }) + } +} diff --git a/swagger/roundtrip_test.go b/swagger/roundtrip_test.go new file mode 100644 index 0000000..f7cbc4f --- /dev/null +++ b/swagger/roundtrip_test.go @@ -0,0 +1,125 @@ +package swagger_test + +import ( + "bytes" + "io" + "os" + "strings" + "testing" + + "github.com/speakeasy-api/openapi/internal/testutils" + "github.com/speakeasy-api/openapi/json" + "github.com/speakeasy-api/openapi/swagger" + "github.com/speakeasy-api/openapi/yml" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gopkg.in/yaml.v3" +) + +type roundTripTest struct { + name string + location string + skipRoundTrip bool + needsSanitization bool // If true we will put the input document through the go marshallers as well to reduce whitespace diffs (as the YAML library doesn't preserve whitespace well) +} + +var roundTripTests = []roundTripTest{ + { + name: "Comprehensive Test Swagger", + location: "testdata/test.swagger.json", + }, + { + name: "Swagger Petstore", + location: "https://raw.githubusercontent.com/swagger-api/swagger-ui/04224150734be88f70a0bbd3f61bbe444606b657/test/unit/core/plugins/spec/assets/petstore.json", + }, + { + name: "Twilio API", + location: "https://github.com/dreamfactorysoftware/df-service-apis/raw/0dfd7df7ae217041c642bd045461cf5ed35a548b/twilio/twilio.json", + }, + { + name: "eBay Key Management API", + location: "https://developer.ebay.com/api-docs/master/developer/key-management/openapi/2/developer_key_management_v1_oas2.yaml", + }, + { + name: "Docker Engine API", + location: "https://github.com/docker-archive/engine/raw/25381123d3483eacfa02e989381bd36939b02d1d/api/swagger.yaml", + needsSanitization: true, + }, + { + name: "DocuSign Admin API", + location: "https://github.com/docusign/OpenAPI-Specifications/raw/10056bd7c07c2f8e41f5cb382be85d31e233f1ec/admin.rest.swagger-v2.1.json", + }, +} + +func TestSwagger_RoundTrip(t *testing.T) { + t.Parallel() + for _, tt := range roundTripTests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if tt.skipRoundTrip { + t.SkipNow() + } + + ctx := t.Context() + + var r io.ReadCloser + if strings.HasPrefix(tt.location, "testdata/") { + var err error + r, err = os.Open(tt.location) + require.NoError(t, err) + } else { + var err error + r, err = testutils.DownloadFile(tt.location, "SWAGGER_CACHE_DIR", "speakeasy-api_swagger") + require.NoError(t, err) + } + defer r.Close() + + inBuf := bytes.NewBuffer([]byte{}) + tee := io.TeeReader(r, inBuf) + + s, validationErrs, err := swagger.Unmarshal(ctx, tee, swagger.WithSkipValidation()) + require.NoError(t, err) + assert.Empty(t, validationErrs) + + outBuf := bytes.NewBuffer([]byte{}) + + err = swagger.Marshal(ctx, s, outBuf) + require.NoError(t, err) + + if tt.needsSanitization { + sanitizedData, err := sanitize(inBuf.Bytes()) + require.NoError(t, err) + inBuf = bytes.NewBuffer(sanitizedData) + } + + assert.Equal(t, inBuf.String(), outBuf.String()) + }) + } +} + +func sanitize(data []byte) ([]byte, error) { + var node yaml.Node + + if err := yaml.Unmarshal(data, &node); err != nil { + return nil, err + } + + cfg := yml.GetConfigFromDoc(data, &node) + + b := bytes.NewBuffer([]byte{}) + + if cfg.OriginalFormat == 
yml.OutputFormatYAML { + enc := yaml.NewEncoder(b) + + enc.SetIndent(cfg.Indentation) + if err := enc.Encode(&node); err != nil { + return nil, err + } + } else { + if err := json.YAMLToJSONWithConfig(&node, cfg.IndentationStyle.ToIndent(), cfg.Indentation, true, b); err != nil { + return nil, err + } + } + + return b.Bytes(), nil +} diff --git a/swagger/security.go b/swagger/security.go new file mode 100644 index 0000000..56fa886 --- /dev/null +++ b/swagger/security.go @@ -0,0 +1,292 @@ +package swagger + +import ( + "context" + "net/url" + "strings" + + "github.com/speakeasy-api/openapi/extensions" + "github.com/speakeasy-api/openapi/internal/interfaces" + "github.com/speakeasy-api/openapi/marshaller" + "github.com/speakeasy-api/openapi/sequencedmap" + "github.com/speakeasy-api/openapi/swagger/core" + "github.com/speakeasy-api/openapi/validation" +) + +// SecuritySchemeType represents the type of security scheme. +type SecuritySchemeType string + +const ( + // SecuritySchemeTypeBasic represents basic authentication. + SecuritySchemeTypeBasic SecuritySchemeType = "basic" + // SecuritySchemeTypeAPIKey represents API key authentication. + SecuritySchemeTypeAPIKey SecuritySchemeType = "apiKey" + // SecuritySchemeTypeOAuth2 represents OAuth2 authentication. + SecuritySchemeTypeOAuth2 SecuritySchemeType = "oauth2" +) + +// SecuritySchemeIn represents the location of the API key. +type SecuritySchemeIn string + +const ( + // SecuritySchemeInQuery represents an API key in the query string. + SecuritySchemeInQuery SecuritySchemeIn = "query" + // SecuritySchemeInHeader represents an API key in the header. + SecuritySchemeInHeader SecuritySchemeIn = "header" +) + +// OAuth2Flow represents the flow type for OAuth2. +type OAuth2Flow string + +const ( + // OAuth2FlowImplicit represents the implicit flow. + OAuth2FlowImplicit OAuth2Flow = "implicit" + // OAuth2FlowPassword represents the password flow. + OAuth2FlowPassword OAuth2Flow = "password" + // OAuth2FlowApplication represents the application flow. + OAuth2FlowApplication OAuth2Flow = "application" + // OAuth2FlowAccessCode represents the access code flow. + OAuth2FlowAccessCode OAuth2Flow = "accessCode" +) + +// SecurityScheme defines a security scheme that can be used by the operations. +type SecurityScheme struct { + marshaller.Model[core.SecurityScheme] + + // Type is the type of the security scheme. Valid values are "basic", "apiKey" or "oauth2". + Type SecuritySchemeType + // Description is a short description for security scheme. + Description *string + // Name is the name of the header or query parameter to be used (apiKey only). + Name *string + // In is the location of the API key. Valid values are "query" or "header" (apiKey only). + In *SecuritySchemeIn + // Flow is the flow used by the OAuth2 security scheme. Valid values are "implicit", "password", "application" or "accessCode" (oauth2 only). + Flow *OAuth2Flow + // AuthorizationURL is the authorization URL to be used for this flow (oauth2 "implicit" and "accessCode" only). + AuthorizationURL *string + // TokenURL is the token URL to be used for this flow (oauth2 "password", "application" and "accessCode" only). + TokenURL *string + // Scopes lists the available scopes for the OAuth2 security scheme (oauth2 only). + Scopes *sequencedmap.Map[string, string] + // Extensions provides a list of extensions to the SecurityScheme object. 
+ Extensions *extensions.Extensions +} + +var _ interfaces.Model[core.SecurityScheme] = (*SecurityScheme)(nil) + +// GetType returns the value of the Type field. +func (s *SecurityScheme) GetType() SecuritySchemeType { + if s == nil { + return "" + } + return s.Type +} + +// GetDescription returns the value of the Description field. Returns empty string if not set. +func (s *SecurityScheme) GetDescription() string { + if s == nil || s.Description == nil { + return "" + } + return *s.Description +} + +// GetName returns the value of the Name field. Returns empty string if not set. +func (s *SecurityScheme) GetName() string { + if s == nil || s.Name == nil { + return "" + } + return *s.Name +} + +// GetIn returns the value of the In field. Returns empty string if not set. +func (s *SecurityScheme) GetIn() SecuritySchemeIn { + if s == nil || s.In == nil { + return "" + } + return *s.In +} + +// GetFlow returns the value of the Flow field. Returns empty string if not set. +func (s *SecurityScheme) GetFlow() OAuth2Flow { + if s == nil || s.Flow == nil { + return "" + } + return *s.Flow +} + +// GetAuthorizationURL returns the value of the AuthorizationURL field. Returns empty string if not set. +func (s *SecurityScheme) GetAuthorizationURL() string { + if s == nil || s.AuthorizationURL == nil { + return "" + } + return *s.AuthorizationURL +} + +// GetTokenURL returns the value of the TokenURL field. Returns empty string if not set. +func (s *SecurityScheme) GetTokenURL() string { + if s == nil || s.TokenURL == nil { + return "" + } + return *s.TokenURL +} + +// GetScopes returns the value of the Scopes field. Returns nil if not set. +func (s *SecurityScheme) GetScopes() *sequencedmap.Map[string, string] { + if s == nil { + return nil + } + return s.Scopes +} + +// GetExtensions returns the value of the Extensions field. Returns an empty extensions map if not set. +func (s *SecurityScheme) GetExtensions() *extensions.Extensions { + if s == nil || s.Extensions == nil { + return extensions.New() + } + return s.Extensions +} + +// Validate validates the SecurityScheme object against the Swagger Specification. 
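+// Type-specific requirements are also checked: apiKey schemes need name and in, and
+// oauth2 schemes need flow, the URLs required by that flow, and scopes. All issues
+// are collected and returned together rather than stopping at the first failure.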
+func (s *SecurityScheme) Validate(ctx context.Context, opts ...validation.Option) []error { + c := s.GetCore() + errs := []error{} + + if c.Type.Present && s.Type == "" { + errs = append(errs, validation.NewValueError(validation.NewMissingValueError("securityScheme.type is required"), c, c.Type)) + } else { + validTypes := []SecuritySchemeType{SecuritySchemeTypeBasic, SecuritySchemeTypeAPIKey, SecuritySchemeTypeOAuth2} + valid := false + for _, t := range validTypes { + if s.Type == t { + valid = true + break + } + } + if !valid { + errs = append(errs, validation.NewValueError(validation.NewValueValidationError("securityScheme.type must be one of [%s]", strings.Join([]string{string(SecuritySchemeTypeBasic), string(SecuritySchemeTypeAPIKey), string(SecuritySchemeTypeOAuth2)}, ", ")), c, c.Type)) + } + } + + // Validate apiKey specific fields + if s.Type == SecuritySchemeTypeAPIKey { + if !c.Name.Present || s.Name == nil || *s.Name == "" { + errs = append(errs, validation.NewValueError(validation.NewMissingValueError("securityScheme.name is required for type=apiKey"), c, c.Name)) + } + if !c.In.Present || s.In == nil { + errs = append(errs, validation.NewValueError(validation.NewMissingValueError("securityScheme.in is required for type=apiKey"), c, c.In)) + } else if *s.In != SecuritySchemeInQuery && *s.In != SecuritySchemeInHeader { + errs = append(errs, validation.NewValueError(validation.NewValueValidationError("securityScheme.in must be one of [%s] for type=apiKey", strings.Join([]string{string(SecuritySchemeInQuery), string(SecuritySchemeInHeader)}, ", ")), c, c.In)) + } + } + + // Validate oauth2 specific fields + if s.Type == SecuritySchemeTypeOAuth2 { + if !c.Flow.Present || s.Flow == nil { + errs = append(errs, validation.NewValueError(validation.NewMissingValueError("securityScheme.flow is required for type=oauth2"), c, c.Flow)) + } else { + validFlows := []OAuth2Flow{OAuth2FlowImplicit, OAuth2FlowPassword, OAuth2FlowApplication, OAuth2FlowAccessCode} + valid := false + for _, f := range validFlows { + if *s.Flow == f { + valid = true + break + } + } + if !valid { + errs = append(errs, validation.NewValueError(validation.NewValueValidationError("securityScheme.flow must be one of [%s] for type=oauth2", strings.Join([]string{string(OAuth2FlowImplicit), string(OAuth2FlowPassword), string(OAuth2FlowApplication), string(OAuth2FlowAccessCode)}, ", ")), c, c.Flow)) + } + + if s.Flow != nil { + // authorizationUrl required for implicit and accessCode flows + if (*s.Flow == OAuth2FlowImplicit || *s.Flow == OAuth2FlowAccessCode) && (!c.AuthorizationURL.Present || s.AuthorizationURL == nil || *s.AuthorizationURL == "") { + errs = append(errs, validation.NewValueError(validation.NewMissingValueError("securityScheme.authorizationUrl is required for flow=%s", *s.Flow), c, c.AuthorizationURL)) + } + + // tokenUrl required for password, application and accessCode flows + if (*s.Flow == OAuth2FlowPassword || *s.Flow == OAuth2FlowApplication || *s.Flow == OAuth2FlowAccessCode) && (!c.TokenURL.Present || s.TokenURL == nil || *s.TokenURL == "") { + errs = append(errs, validation.NewValueError(validation.NewMissingValueError("securityScheme.tokenUrl is required for flow=%s", *s.Flow), c, c.TokenURL)) + } + } + } + + if !c.Scopes.Present { + errs = append(errs, validation.NewValueError(validation.NewMissingValueError("securityScheme.scopes is required for type=oauth2"), c, c.Scopes)) + } + } + + // Validate URLs + if c.AuthorizationURL.Present && s.AuthorizationURL != nil && *s.AuthorizationURL != "" 
{ + if _, err := url.Parse(*s.AuthorizationURL); err != nil { + errs = append(errs, validation.NewValueError(validation.NewValueValidationError("securityScheme.authorizationUrl is not a valid uri: %s", err), c, c.AuthorizationURL)) + } + } + + if c.TokenURL.Present && s.TokenURL != nil && *s.TokenURL != "" { + if _, err := url.Parse(*s.TokenURL); err != nil { + errs = append(errs, validation.NewValueError(validation.NewValueValidationError("securityScheme.tokenUrl is not a valid uri: %s", err), c, c.TokenURL)) + } + } + + s.Valid = len(errs) == 0 && c.GetValid() + + return errs +} + +// SecurityRequirement lists the required security schemes to execute an operation. +// The keys are the names of security schemes and the values are lists of scope names. +// For non-oauth2 security schemes, the array MUST be empty. +type SecurityRequirement struct { + marshaller.Model[core.SecurityRequirement] + *sequencedmap.Map[string, []string] +} + +var _ interfaces.Model[core.SecurityRequirement] = (*SecurityRequirement)(nil) + +// NewSecurityRequirement creates a new SecurityRequirement with an initialized map. +func NewSecurityRequirement() *SecurityRequirement { + return &SecurityRequirement{ + Map: sequencedmap.New[string, []string](), + } +} + +// Validate validates the SecurityRequirement object against the Swagger Specification. +func (s *SecurityRequirement) Validate(ctx context.Context, opts ...validation.Option) []error { + c := s.GetCore() + errs := []error{} + + // Get Swagger context to access security definitions + validationOpts := validation.NewOptions(opts...) + swagger := validation.GetContextObject[Swagger](validationOpts) + + if swagger == nil || s.Map == nil { + s.Valid = c.GetValid() + return errs + } + + // Validate each security requirement + for name := range s.Keys() { + scopes, _ := s.Get(name) + + // Check that the security scheme name exists in securityDefinitions + secScheme, exists := swagger.SecurityDefinitions.Get(name) + if !exists { + errs = append(errs, validation.NewValidationError( + validation.NewValueValidationError("security requirement '%s' does not match any security scheme in securityDefinitions", name), + c.RootNode)) + continue + } + + // For non-oauth2 security schemes, the array MUST be empty + if secScheme.Type != SecuritySchemeTypeOAuth2 { + if len(scopes) > 0 { + errs = append(errs, validation.NewValidationError( + validation.NewValueValidationError("security requirement '%s' must have empty scopes array for non-oauth2 security scheme (type=%s)", name, secScheme.Type), + c.RootNode)) + } + } + } + + s.Valid = len(errs) == 0 && c.GetValid() + return errs +} diff --git a/swagger/security_validate_test.go b/swagger/security_validate_test.go new file mode 100644 index 0000000..45d2405 --- /dev/null +++ b/swagger/security_validate_test.go @@ -0,0 +1,215 @@ +package swagger_test + +import ( + "bytes" + "strings" + "testing" + + "github.com/speakeasy-api/openapi/marshaller" + "github.com/speakeasy-api/openapi/swagger" + "github.com/stretchr/testify/require" +) + +func TestSecurityScheme_Validate_Success(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + yml string + }{ + { + name: "valid_basic_auth", + yml: `type: basic +description: Basic authentication`, + }, + { + name: "valid_apiKey_header", + yml: `type: apiKey +name: X-API-Key +in: header +description: API key authentication`, + }, + { + name: "valid_apiKey_query", + yml: `type: apiKey +name: api_key +in: query`, + }, + { + name: "valid_oauth2_implicit", + yml: `type: oauth2 +flow: 
implicit +authorizationUrl: https://example.com/oauth/authorize +scopes: + read: Read access + write: Write access`, + }, + { + name: "valid_oauth2_password", + yml: `type: oauth2 +flow: password +tokenUrl: https://example.com/oauth/token +scopes: + admin: Admin access`, + }, + { + name: "valid_oauth2_application", + yml: `type: oauth2 +flow: application +tokenUrl: https://example.com/oauth/token +scopes: + api: API access`, + }, + { + name: "valid_oauth2_accessCode", + yml: `type: oauth2 +flow: accessCode +authorizationUrl: https://example.com/oauth/authorize +tokenUrl: https://example.com/oauth/token +scopes: + read: Read access + write: Write access`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var securityScheme swagger.SecurityScheme + + validationErrs, err := marshaller.Unmarshal(t.Context(), bytes.NewBufferString(tt.yml), &securityScheme) + require.NoError(t, err) + require.Empty(t, validationErrs) + + errs := securityScheme.Validate(t.Context()) + require.Empty(t, errs, "Expected no validation errors") + }) + } +} + +func TestSecurityScheme_Validate_Error(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + yml string + wantErrs []string + }{ + { + name: "missing_type", + yml: `description: Some security scheme`, + wantErrs: []string{"securityScheme.type is missing"}, + }, + { + name: "invalid_type", + yml: `type: invalid +description: Test`, + wantErrs: []string{"securityScheme.type must be one of"}, + }, + { + name: "apiKey_missing_name", + yml: `type: apiKey +in: header`, + wantErrs: []string{"securityScheme.name is required for type=apiKey"}, + }, + { + name: "apiKey_missing_in", + yml: `type: apiKey +name: X-API-Key`, + wantErrs: []string{"securityScheme.in is required for type=apiKey"}, + }, + { + name: "apiKey_invalid_in", + yml: `type: apiKey +name: X-API-Key +in: invalid`, + wantErrs: []string{"securityScheme.in must be one of"}, + }, + { + name: "oauth2_missing_flow", + yml: `type: oauth2 +scopes: + read: Read access`, + wantErrs: []string{"securityScheme.flow is required for type=oauth2"}, + }, + { + name: "oauth2_invalid_flow", + yml: `type: oauth2 +flow: invalid +scopes: + read: Read access`, + wantErrs: []string{"securityScheme.flow must be one of"}, + }, + { + name: "oauth2_implicit_missing_authorizationUrl", + yml: `type: oauth2 +flow: implicit +scopes: + read: Read access`, + wantErrs: []string{"securityScheme.authorizationUrl is required for flow=implicit"}, + }, + { + name: "oauth2_password_missing_tokenUrl", + yml: `type: oauth2 +flow: password +scopes: + read: Read access`, + wantErrs: []string{"securityScheme.tokenUrl is required for flow=password"}, + }, + { + name: "oauth2_accessCode_missing_authorizationUrl", + yml: `type: oauth2 +flow: accessCode +tokenUrl: https://example.com/token +scopes: + read: Read access`, + wantErrs: []string{"securityScheme.authorizationUrl is required for flow=accessCode"}, + }, + { + name: "oauth2_accessCode_missing_tokenUrl", + yml: `type: oauth2 +flow: accessCode +authorizationUrl: https://example.com/authorize +scopes: + read: Read access`, + wantErrs: []string{"securityScheme.tokenUrl is required for flow=accessCode"}, + }, + { + name: "oauth2_missing_scopes", + yml: `type: oauth2 +flow: password +tokenUrl: https://example.com/token`, + wantErrs: []string{"securityScheme.scopes is required for type=oauth2"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var securityScheme swagger.SecurityScheme + + var 
allErrors []error + validationErrs, err := marshaller.Unmarshal(t.Context(), bytes.NewBufferString(tt.yml), &securityScheme) + require.NoError(t, err) + allErrors = append(allErrors, validationErrs...) + + validateErrs := securityScheme.Validate(t.Context()) + allErrors = append(allErrors, validateErrs...) + + require.NotEmpty(t, allErrors, "Expected validation errors") + + for _, wantErr := range tt.wantErrs { + found := false + for _, gotErr := range allErrors { + if gotErr != nil && strings.Contains(gotErr.Error(), wantErr) { + found = true + break + } + } + require.True(t, found, "Expected error containing '%s' not found in: %v", wantErr, allErrors) + } + }) + } +} diff --git a/swagger/swagger.go b/swagger/swagger.go new file mode 100644 index 0000000..9e1f86d --- /dev/null +++ b/swagger/swagger.go @@ -0,0 +1,331 @@ +package swagger + +import ( + "context" + "mime" + "strings" + + "github.com/speakeasy-api/openapi/extensions" + "github.com/speakeasy-api/openapi/internal/interfaces" + "github.com/speakeasy-api/openapi/jsonschema/oas3" + "github.com/speakeasy-api/openapi/marshaller" + "github.com/speakeasy-api/openapi/sequencedmap" + "github.com/speakeasy-api/openapi/swagger/core" + "github.com/speakeasy-api/openapi/validation" +) + +// Version is the Swagger specification version supported by this package. +const Version = "2.0" + +// Swagger is the root document object for the API specification. +type Swagger struct { + marshaller.Model[core.Swagger] + + // Swagger is the version of the Swagger specification that this document uses. + Swagger string + // Info provides metadata about the API. + Info Info + // Host is the host (name or ip) serving the API. + Host *string + // BasePath is the base path on which the API is served. + BasePath *string + // Schemes is the transfer protocol of the API. + Schemes []string + // Consumes is a list of MIME types the APIs can consume. + Consumes []string + // Produces is a list of MIME types the APIs can produce. + Produces []string + // Paths is the available paths and operations for the API. + Paths *Paths + // Definitions is an object to hold data types produced and consumed by operations. + Definitions *sequencedmap.Map[string, *oas3.JSONSchema[oas3.Concrete]] + // Parameters is an object to hold parameters that can be used across operations. + Parameters *sequencedmap.Map[string, *Parameter] + // Responses is an object to hold responses that can be used across operations. + Responses *sequencedmap.Map[string, *Response] + // SecurityDefinitions are security scheme definitions that can be used across the specification. + SecurityDefinitions *sequencedmap.Map[string, *SecurityScheme] + // Security is a declaration of which security schemes are applied for the API as a whole. + Security []*SecurityRequirement + // Tags is a list of tags used by the specification with additional metadata. + Tags []*Tag + // ExternalDocs is additional external documentation. + ExternalDocs *ExternalDocumentation + // Extensions provides a list of extensions to the Swagger object. + Extensions *extensions.Extensions +} + +var _ interfaces.Model[core.Swagger] = (*Swagger)(nil) + +// GetSwagger returns the value of the Swagger field. Returns empty string if not set. +func (s *Swagger) GetSwagger() string { + if s == nil { + return "" + } + return s.Swagger +} + +// GetInfo returns the value of the Info field. +func (s *Swagger) GetInfo() *Info { + if s == nil { + return nil + } + return &s.Info +} + +// GetHost returns the value of the Host field. 
Returns empty string if not set. +func (s *Swagger) GetHost() string { + if s == nil || s.Host == nil { + return "" + } + return *s.Host +} + +// GetBasePath returns the value of the BasePath field. Returns empty string if not set. +func (s *Swagger) GetBasePath() string { + if s == nil || s.BasePath == nil { + return "" + } + return *s.BasePath +} + +// GetSchemes returns the value of the Schemes field. Returns nil if not set. +func (s *Swagger) GetSchemes() []string { + if s == nil { + return nil + } + return s.Schemes +} + +// GetConsumes returns the value of the Consumes field. Returns nil if not set. +func (s *Swagger) GetConsumes() []string { + if s == nil { + return nil + } + return s.Consumes +} + +// GetProduces returns the value of the Produces field. Returns nil if not set. +func (s *Swagger) GetProduces() []string { + if s == nil { + return nil + } + return s.Produces +} + +// GetPaths returns the value of the Paths field. Returns nil if not set. +func (s *Swagger) GetPaths() *Paths { + if s == nil { + return nil + } + return s.Paths +} + +// GetDefinitions returns the value of the Definitions field. Returns nil if not set. +func (s *Swagger) GetDefinitions() *sequencedmap.Map[string, *oas3.JSONSchema[oas3.Concrete]] { + if s == nil { + return nil + } + return s.Definitions +} + +// GetParameters returns the value of the Parameters field. Returns nil if not set. +func (s *Swagger) GetParameters() *sequencedmap.Map[string, *Parameter] { + if s == nil { + return nil + } + return s.Parameters +} + +// GetResponses returns the value of the Responses field. Returns nil if not set. +func (s *Swagger) GetResponses() *sequencedmap.Map[string, *Response] { + if s == nil { + return nil + } + return s.Responses +} + +// GetSecurityDefinitions returns the value of the SecurityDefinitions field. Returns nil if not set. +func (s *Swagger) GetSecurityDefinitions() *sequencedmap.Map[string, *SecurityScheme] { + if s == nil { + return nil + } + return s.SecurityDefinitions +} + +// GetSecurity returns the value of the Security field. Returns nil if not set. +func (s *Swagger) GetSecurity() []*SecurityRequirement { + if s == nil { + return nil + } + return s.Security +} + +// GetTags returns the value of the Tags field. Returns nil if not set. +func (s *Swagger) GetTags() []*Tag { + if s == nil { + return nil + } + return s.Tags +} + +// GetExternalDocs returns the value of the ExternalDocs field. Returns nil if not set. +func (s *Swagger) GetExternalDocs() *ExternalDocumentation { + if s == nil { + return nil + } + return s.ExternalDocs +} + +// GetExtensions returns the value of the Extensions field. Returns an empty extensions map if not set. +func (s *Swagger) GetExtensions() *extensions.Extensions { + if s == nil || s.Extensions == nil { + return extensions.New() + } + return s.Extensions +} + +// Validate validates the Swagger object against the Swagger Specification. +func (s *Swagger) Validate(ctx context.Context, opts ...validation.Option) []error { + c := s.GetCore() + errs := []error{} + + if c.Swagger.Present && s.Swagger == "" { + errs = append(errs, validation.NewValueError(validation.NewMissingValueError("swagger is required"), c, c.Swagger)) + } else if c.Swagger.Present && s.Swagger != "2.0" { + errs = append(errs, validation.NewValueError(validation.NewValueValidationError("swagger must be '2.0'"), c, c.Swagger)) + } + + if c.Info.Present { + errs = append(errs, s.Info.Validate(ctx, opts...)...) 
+	}
+
+	// Validate basePath starts with leading slash
+	if c.BasePath.Present && s.BasePath != nil && *s.BasePath != "" {
+		if !strings.HasPrefix(*s.BasePath, "/") {
+			errs = append(errs, validation.NewValueError(
+				validation.NewValueValidationError("basePath must start with a leading slash '/'"),
+				c, c.BasePath))
+		}
+	}
+
+	// Validate schemes if present
+	if c.Schemes.Present {
+		validSchemes := []string{"http", "https", "ws", "wss"}
+		for _, scheme := range s.Schemes {
+			valid := false
+			for _, vs := range validSchemes {
+				if scheme == vs {
+					valid = true
+					break
+				}
+			}
+			if !valid {
+				errs = append(errs, validation.NewValueError(
+					validation.NewValueValidationError("scheme must be one of [http, https, ws, wss], got '%s'", scheme),
+					c, c.Schemes))
+			}
+		}
+	}
+
+	// Validate consumes MIME types
+	if c.Consumes.Present {
+		for _, mimeType := range s.Consumes {
+			if _, _, err := mime.ParseMediaType(mimeType); err != nil {
+				errs = append(errs, validation.NewValueError(
+					validation.NewValueValidationError("consumes contains invalid MIME type '%s': %s", mimeType, err),
+					c, c.Consumes))
+			}
+		}
+	}
+
+	// Validate produces MIME types
+	if c.Produces.Present {
+		for _, mimeType := range s.Produces {
+			if _, _, err := mime.ParseMediaType(mimeType); err != nil {
+				errs = append(errs, validation.NewValueError(
+					validation.NewValueValidationError("produces contains invalid MIME type '%s': %s", mimeType, err),
+					c, c.Produces))
+			}
+		}
+	}
+
+	// Pass Swagger as context for nested validation (operations, security requirements)
+	if c.Paths.Present && s.Paths != nil {
+		errs = append(errs, s.Paths.Validate(ctx, append(opts, validation.WithContextObject(s))...)...)
+	}
+
+	// Validate tag names are unique
+	tagNames := make(map[string]bool)
+	for _, tag := range s.Tags {
+		if tag == nil {
+			continue
+		}
+		if tag.Name != "" {
+			if tagNames[tag.Name] {
+				errs = append(errs, validation.NewValueError(
+					validation.NewValueValidationError("tag name '%s' must be unique", tag.Name),
+					c, c.Tags))
+			}
+			tagNames[tag.Name] = true
+		}
+		errs = append(errs, tag.Validate(ctx, opts...)...)
+	}
+
+	if c.ExternalDocs.Present && s.ExternalDocs != nil {
+		errs = append(errs, s.ExternalDocs.Validate(ctx, opts...)...)
+	}
+
+	for _, param := range s.Parameters.All() {
+		errs = append(errs, param.Validate(ctx, opts...)...)
+	}
+
+	for _, resp := range s.Responses.All() {
+		errs = append(errs, resp.Validate(ctx, opts...)...)
+	}
+
+	for _, secScheme := range s.SecurityDefinitions.All() {
+		errs = append(errs, secScheme.Validate(ctx, opts...)...)
+	}
+
+	// Pass Swagger as context for security requirement validation
+	for _, secReq := range s.Security {
+		errs = append(errs, secReq.Validate(ctx, append(opts, validation.WithContextObject(s))...)...)
+	}
+
+	// Validate operationId uniqueness across all operations
+	errs = append(errs, s.validateOperationIDUniqueness(c)...)
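+	// The Valid flag set below reflects both this Validate pass and the
+	// unmarshal-time (core) validation of the document.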
+ + s.Valid = len(errs) == 0 && c.GetValid() + + return errs +} + +// validateOperationIDUniqueness validates that all operationIds are unique across the document +func (s *Swagger) validateOperationIDUniqueness(c *core.Swagger) []error { + errs := []error{} + operationIDs := make(map[string]bool) + + if s.Paths == nil { + return errs + } + + for _, pathItem := range s.Paths.All() { + if pathItem == nil { + continue + } + + for _, operation := range pathItem.All() { + if operation == nil || operation.OperationID == nil || *operation.OperationID == "" { + continue + } + + opID := *operation.OperationID + if operationIDs[opID] { + errs = append(errs, validation.NewValueError( + validation.NewValueValidationError("operationId '%s' must be unique among all operations", opID), + c, c.Paths)) + } + operationIDs[opID] = true + } + } + + return errs +} diff --git a/swagger/swagger_test.go b/swagger/swagger_test.go new file mode 100644 index 0000000..59e8965 --- /dev/null +++ b/swagger/swagger_test.go @@ -0,0 +1,173 @@ +package swagger_test + +import ( + "strings" + "testing" + + "github.com/speakeasy-api/openapi/swagger" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestUnmarshal_Success(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + yaml string + }{ + { + name: "minimal valid swagger document", + yaml: `swagger: "2.0" +info: + title: Test API + version: 1.0.0 +paths: {}`, + }, + { + name: "swagger with host and basePath", + yaml: `swagger: "2.0" +info: + title: Test API + version: 1.0.0 +host: api.example.com +basePath: /v1 +paths: {}`, + }, + { + name: "swagger with schemes and consumes/produces", + yaml: `swagger: "2.0" +info: + title: Test API + version: 1.0.0 +schemes: + - https + - http +consumes: + - application/json +produces: + - application/json +paths: {}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + ctx := t.Context() + + doc, validationErrs, err := swagger.Unmarshal(ctx, strings.NewReader(tt.yaml)) + require.NoError(t, err, "unmarshal should succeed") + require.Empty(t, validationErrs, "should have no validation errors") + require.NotNil(t, doc, "document should not be nil") + assert.Equal(t, "2.0", doc.Swagger, "swagger version should be 2.0") + }) + } +} + +func TestUnmarshal_ValidationErrors(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + yaml string + expectedError string + }{ + { + name: "missing swagger field", + yaml: `info: + title: Test API + version: 1.0.0 +paths: {}`, + expectedError: "swagger is missing", + }, + { + name: "missing info field", + yaml: `swagger: "2.0" +paths: {}`, + expectedError: "info is missing", + }, + { + name: "missing paths field", + yaml: `swagger: "2.0" +info: + title: Test API + version: 1.0.0`, + expectedError: "paths is missing", + }, + { + name: "missing info.title", + yaml: `swagger: "2.0" +info: + version: 1.0.0 +paths: {}`, + expectedError: "info.title is missing", + }, + { + name: "missing info.version", + yaml: `swagger: "2.0" +info: + title: Test API +paths: {}`, + expectedError: "info.version is missing", + }, + { + name: "invalid swagger version", + yaml: `swagger: "3.0" +info: + title: Test API + version: 1.0.0 +paths: {}`, + expectedError: "swagger must be '2.0'", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + ctx := t.Context() + + doc, validationErrs, err := swagger.Unmarshal(ctx, strings.NewReader(tt.yaml)) + require.NoError(t, err, "unmarshal 
should not return error") + require.NotNil(t, doc, "document should not be nil") + require.NotEmpty(t, validationErrs, "should have validation errors") + + found := false + var allErrors []string + for _, verr := range validationErrs { + allErrors = append(allErrors, verr.Error()) + if strings.Contains(verr.Error(), tt.expectedError) { + found = true + break + } + } + assert.True(t, found, "should contain expected error: %s\nGot errors: %v", tt.expectedError, allErrors) + }) + } +} + +func TestMarshal_Success(t *testing.T) { + t.Parallel() + ctx := t.Context() + + doc := &swagger.Swagger{ + Swagger: swagger.Version, + Info: swagger.Info{ + Title: "Test API", + Version: "1.0.0", + }, + Paths: swagger.NewPaths(), + } + + var buf strings.Builder + err := swagger.Marshal(ctx, doc, &buf) + require.NoError(t, err, "marshal should succeed") + + expected := `swagger: "2.0" +info: + title: Test API + version: 1.0.0 +paths: {} +` + assert.Equal(t, expected, buf.String(), "marshaled output should match expected YAML") +} diff --git a/swagger/swagger_validate_test.go b/swagger/swagger_validate_test.go new file mode 100644 index 0000000..7a4b0f8 --- /dev/null +++ b/swagger/swagger_validate_test.go @@ -0,0 +1,1155 @@ +package swagger_test + +import ( + "bytes" + "strings" + "testing" + + "github.com/speakeasy-api/openapi/marshaller" + "github.com/speakeasy-api/openapi/sequencedmap" + "github.com/speakeasy-api/openapi/swagger" + "github.com/speakeasy-api/openapi/validation" + "github.com/stretchr/testify/require" +) + +func TestSwagger_Validate_BasePath_Success(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + yml string + }{ + { + name: "valid_basePath_with_leading_slash", + yml: `swagger: "2.0" +info: + title: Test API + version: 1.0.0 +basePath: /v1 +paths: {}`, + }, + { + name: "valid_basePath_with_just_slash", + yml: `swagger: "2.0" +info: + title: Test API + version: 1.0.0 +basePath: / +paths: {}`, + }, + { + name: "valid_basePath_with_multiple_segments", + yml: `swagger: "2.0" +info: + title: Test API + version: 1.0.0 +basePath: /api/v1 +paths: {}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var doc swagger.Swagger + + validationErrs, err := marshaller.Unmarshal(t.Context(), bytes.NewBufferString(tt.yml), &doc) + require.NoError(t, err) + require.Empty(t, validationErrs) + + errs := doc.Validate(t.Context()) + require.Empty(t, errs, "Expected no validation errors") + }) + } +} + +func TestSwagger_Validate_BasePath_Error(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + yml string + wantErrs []string + }{ + { + name: "basePath_without_leading_slash", + yml: `swagger: "2.0" +info: + title: Test API + version: 1.0.0 +basePath: v1 +paths: {}`, + wantErrs: []string{"basePath must start with a leading slash"}, + }, + { + name: "basePath_with_only_text", + yml: `swagger: "2.0" +info: + title: Test API + version: 1.0.0 +basePath: api +paths: {}`, + wantErrs: []string{"basePath must start with a leading slash"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var doc swagger.Swagger + + var allErrors []error + validationErrs, err := marshaller.Unmarshal(t.Context(), bytes.NewBufferString(tt.yml), &doc) + require.NoError(t, err) + allErrors = append(allErrors, validationErrs...) + + validateErrs := doc.Validate(t.Context()) + allErrors = append(allErrors, validateErrs...) 
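+			// The leading-slash rule is enforced only by Validate; unmarshalling a
+			// bare-string basePath succeeds, so the expected error comes from validateErrs.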
+ + require.NotEmpty(t, allErrors, "Expected validation errors") + + for _, wantErr := range tt.wantErrs { + found := false + for _, gotErr := range allErrors { + if gotErr != nil && strings.Contains(gotErr.Error(), wantErr) { + found = true + break + } + } + require.True(t, found, "Expected error containing '%s' not found in: %v", wantErr, allErrors) + } + }) + } +} + +func TestPaths_Validate_PathKeys_Success(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + yml string + }{ + { + name: "valid_path_with_leading_slash", + yml: `/users: + get: + responses: + 200: + description: Success`, + }, + { + name: "valid_path_with_parameters", + yml: `/users/{id}: + get: + responses: + 200: + description: Success`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var paths swagger.Paths + + validationErrs, err := marshaller.Unmarshal(t.Context(), bytes.NewBufferString(tt.yml), &paths) + require.NoError(t, err) + require.Empty(t, validationErrs) + + errs := paths.Validate(t.Context()) + require.Empty(t, errs, "Expected no validation errors") + }) + } +} + +func TestPaths_Validate_PathKeys_Error(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + yml string + wantErrs []string + }{ + { + name: "path_without_leading_slash", + yml: `users: + get: + responses: + 200: + description: Success`, + wantErrs: []string{"must begin with a slash"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var paths swagger.Paths + + var allErrors []error + validationErrs, err := marshaller.Unmarshal(t.Context(), bytes.NewBufferString(tt.yml), &paths) + require.NoError(t, err) + allErrors = append(allErrors, validationErrs...) + + validateErrs := paths.Validate(t.Context()) + allErrors = append(allErrors, validateErrs...) 
+ + require.NotEmpty(t, allErrors, "Expected validation errors") + + for _, wantErr := range tt.wantErrs { + found := false + for _, gotErr := range allErrors { + if gotErr != nil && strings.Contains(gotErr.Error(), wantErr) { + found = true + break + } + } + require.True(t, found, "Expected error containing '%s' not found in: %v", wantErr, allErrors) + } + }) + } +} + +func TestSwagger_Validate_Schemes_Success(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + yml string + }{ + { + name: "valid_http_scheme", + yml: `swagger: "2.0" +info: + title: Test API + version: 1.0.0 +schemes: + - http +paths: {}`, + }, + { + name: "valid_https_scheme", + yml: `swagger: "2.0" +info: + title: Test API + version: 1.0.0 +schemes: + - https +paths: {}`, + }, + { + name: "valid_multiple_schemes", + yml: `swagger: "2.0" +info: + title: Test API + version: 1.0.0 +schemes: + - http + - https + - ws + - wss +paths: {}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var doc swagger.Swagger + + validationErrs, err := marshaller.Unmarshal(t.Context(), bytes.NewBufferString(tt.yml), &doc) + require.NoError(t, err) + require.Empty(t, validationErrs) + + errs := doc.Validate(t.Context()) + require.Empty(t, errs, "Expected no validation errors") + }) + } +} + +func TestSwagger_Validate_Schemes_Error(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + yml string + wantErrs []string + }{ + { + name: "invalid_scheme_ftp", + yml: `swagger: "2.0" +info: + title: Test API + version: 1.0.0 +schemes: + - ftp +paths: {}`, + wantErrs: []string{"scheme must be one of"}, + }, + { + name: "invalid_scheme_mixed", + yml: `swagger: "2.0" +info: + title: Test API + version: 1.0.0 +schemes: + - https + - invalid +paths: {}`, + wantErrs: []string{"scheme must be one of"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var doc swagger.Swagger + + var allErrors []error + validationErrs, err := marshaller.Unmarshal(t.Context(), bytes.NewBufferString(tt.yml), &doc) + require.NoError(t, err) + allErrors = append(allErrors, validationErrs...) + + validateErrs := doc.Validate(t.Context()) + allErrors = append(allErrors, validateErrs...) 
+ + require.NotEmpty(t, allErrors, "Expected validation errors") + + for _, wantErr := range tt.wantErrs { + found := false + for _, gotErr := range allErrors { + if gotErr != nil && strings.Contains(gotErr.Error(), wantErr) { + found = true + break + } + } + require.True(t, found, "Expected error containing '%s' not found in: %v", wantErr, allErrors) + } + }) + } +} + +func TestOperation_Validate_Schemes_Success(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + yml string + }{ + { + name: "valid_operation_schemes", + yml: `schemes: + - https +responses: + 200: + description: Success`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var op swagger.Operation + + validationErrs, err := marshaller.Unmarshal(t.Context(), bytes.NewBufferString(tt.yml), &op) + require.NoError(t, err) + require.Empty(t, validationErrs) + + errs := op.Validate(t.Context()) + require.Empty(t, errs, "Expected no validation errors") + }) + } +} + +func TestOperation_Validate_Schemes_Error(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + yml string + wantErrs []string + }{ + { + name: "invalid_operation_scheme", + yml: `schemes: + - invalid +responses: + 200: + description: Success`, + wantErrs: []string{"operation.scheme must be one of"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var op swagger.Operation + + var allErrors []error + validationErrs, err := marshaller.Unmarshal(t.Context(), bytes.NewBufferString(tt.yml), &op) + require.NoError(t, err) + allErrors = append(allErrors, validationErrs...) + + validateErrs := op.Validate(t.Context()) + allErrors = append(allErrors, validateErrs...) + + require.NotEmpty(t, allErrors, "Expected validation errors") + + for _, wantErr := range tt.wantErrs { + found := false + for _, gotErr := range allErrors { + if gotErr != nil && strings.Contains(gotErr.Error(), wantErr) { + found = true + break + } + } + require.True(t, found, "Expected error containing '%s' not found in: %v", wantErr, allErrors) + } + }) + } +} + +func TestSwagger_Validate_MIMETypes_Success(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + yml string + }{ + { + name: "valid_consumes", + yml: `swagger: "2.0" +info: + title: Test API + version: 1.0.0 +consumes: + - application/json + - application/xml +paths: {}`, + }, + { + name: "valid_produces", + yml: `swagger: "2.0" +info: + title: Test API + version: 1.0.0 +produces: + - application/json + - text/plain +paths: {}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var doc swagger.Swagger + + validationErrs, err := marshaller.Unmarshal(t.Context(), bytes.NewBufferString(tt.yml), &doc) + require.NoError(t, err) + require.Empty(t, validationErrs) + + errs := doc.Validate(t.Context()) + require.Empty(t, errs, "Expected no validation errors") + }) + } +} + +func TestSwagger_Validate_MIMETypes_Error(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + yml string + wantErrs []string + }{ + { + name: "invalid_consumes_MIME_type", + yml: `swagger: "2.0" +info: + title: Test API + version: 1.0.0 +consumes: + - "invalid mime" +paths: {}`, + wantErrs: []string{"consumes contains invalid MIME type"}, + }, + { + name: "invalid_produces_MIME_type", + yml: `swagger: "2.0" +info: + title: Test API + version: 1.0.0 +produces: + - invalid//mime +paths: {}`, + wantErrs: []string{"produces contains invalid MIME type"}, + }, + } + + for _, tt := 
range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var doc swagger.Swagger + + var allErrors []error + validationErrs, err := marshaller.Unmarshal(t.Context(), bytes.NewBufferString(tt.yml), &doc) + require.NoError(t, err) + allErrors = append(allErrors, validationErrs...) + + validateErrs := doc.Validate(t.Context()) + allErrors = append(allErrors, validateErrs...) + + require.NotEmpty(t, allErrors, "Expected validation errors") + + for _, wantErr := range tt.wantErrs { + found := false + for _, gotErr := range allErrors { + if gotErr != nil && strings.Contains(gotErr.Error(), wantErr) { + found = true + break + } + } + require.True(t, found, "Expected error containing '%s' not found in: %v", wantErr, allErrors) + } + }) + } +} + +func TestSwagger_Validate_TagNameUniqueness_Success(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + yml string + }{ + { + name: "unique_tag_names", + yml: `swagger: "2.0" +info: + title: Test API + version: 1.0.0 +tags: + - name: users + description: User operations + - name: posts + description: Post operations +paths: {}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var doc swagger.Swagger + + validationErrs, err := marshaller.Unmarshal(t.Context(), bytes.NewBufferString(tt.yml), &doc) + require.NoError(t, err) + require.Empty(t, validationErrs) + + errs := doc.Validate(t.Context()) + require.Empty(t, errs, "Expected no validation errors") + }) + } +} + +func TestSwagger_Validate_TagNameUniqueness_Error(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + yml string + wantErrs []string + }{ + { + name: "duplicate_tag_names", + yml: `swagger: "2.0" +info: + title: Test API + version: 1.0.0 +tags: + - name: users + description: User operations + - name: users + description: Duplicate tag +paths: {}`, + wantErrs: []string{"tag name 'users' must be unique"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var doc swagger.Swagger + + var allErrors []error + validationErrs, err := marshaller.Unmarshal(t.Context(), bytes.NewBufferString(tt.yml), &doc) + require.NoError(t, err) + allErrors = append(allErrors, validationErrs...) + + validateErrs := doc.Validate(t.Context()) + allErrors = append(allErrors, validateErrs...) 
+ + require.NotEmpty(t, allErrors, "Expected validation errors") + + for _, wantErr := range tt.wantErrs { + found := false + for _, gotErr := range allErrors { + if gotErr != nil && strings.Contains(gotErr.Error(), wantErr) { + found = true + break + } + } + require.True(t, found, "Expected error containing '%s' not found in: %v", wantErr, allErrors) + } + }) + } +} + +func TestSwagger_Validate_OperationIdUniqueness_Success(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + yml string + }{ + { + name: "unique_operation_IDs", + yml: `swagger: "2.0" +info: + title: Test API + version: 1.0.0 +paths: + /users: + get: + operationId: getUsers + responses: + 200: + description: Success + /posts: + get: + operationId: getPosts + responses: + 200: + description: Success`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var doc swagger.Swagger + + validationErrs, err := marshaller.Unmarshal(t.Context(), bytes.NewBufferString(tt.yml), &doc) + require.NoError(t, err) + require.Empty(t, validationErrs) + + errs := doc.Validate(t.Context()) + require.Empty(t, errs, "Expected no validation errors") + }) + } +} + +func TestSwagger_Validate_OperationIdUniqueness_Error(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + yml string + wantErrs []string + }{ + { + name: "duplicate_operation_IDs", + yml: `swagger: "2.0" +info: + title: Test API + version: 1.0.0 +paths: + /users: + get: + operationId: getItems + responses: + 200: + description: Success + /posts: + get: + operationId: getItems + responses: + 200: + description: Success`, + wantErrs: []string{"operationId 'getItems' must be unique"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var doc swagger.Swagger + + var allErrors []error + validationErrs, err := marshaller.Unmarshal(t.Context(), bytes.NewBufferString(tt.yml), &doc) + require.NoError(t, err) + allErrors = append(allErrors, validationErrs...) + + validateErrs := doc.Validate(t.Context()) + allErrors = append(allErrors, validateErrs...) 
+ + require.NotEmpty(t, allErrors, "Expected validation errors") + + for _, wantErr := range tt.wantErrs { + found := false + for _, gotErr := range allErrors { + if gotErr != nil && strings.Contains(gotErr.Error(), wantErr) { + found = true + break + } + } + require.True(t, found, "Expected error containing '%s' not found in: %v", wantErr, allErrors) + } + }) + } +} + +func TestResponses_Validate_Success(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + yml string + }{ + { + name: "responses_with_200", + yml: `200: + description: Success`, + }, + { + name: "responses_with_default", + yml: `default: + description: Default response`, + }, + { + name: "responses_with_multiple", + yml: `200: + description: Success +404: + description: Not found`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var responses swagger.Responses + + validationErrs, err := marshaller.Unmarshal(t.Context(), bytes.NewBufferString(tt.yml), &responses) + require.NoError(t, err) + require.Empty(t, validationErrs) + + errs := responses.Validate(t.Context()) + require.Empty(t, errs, "Expected no validation errors") + }) + } +} + +func TestResponses_Validate_Error(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + yml string + wantErrs []string + }{ + { + name: "empty_responses", + yml: `{}`, + wantErrs: []string{"responses must contain at least one response code"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var responses swagger.Responses + + var allErrors []error + validationErrs, err := marshaller.Unmarshal(t.Context(), bytes.NewBufferString(tt.yml), &responses) + require.NoError(t, err) + allErrors = append(allErrors, validationErrs...) + + validateErrs := responses.Validate(t.Context()) + allErrors = append(allErrors, validateErrs...) 
+ + require.NotEmpty(t, allErrors, "Expected validation errors") + + for _, wantErr := range tt.wantErrs { + found := false + for _, gotErr := range allErrors { + if gotErr != nil && strings.Contains(gotErr.Error(), wantErr) { + found = true + break + } + } + require.True(t, found, "Expected error containing '%s' not found in: %v", wantErr, allErrors) + } + }) + } +} + +func TestParameter_Validate_CollectionFormat_Success(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + yml string + }{ + { + name: "collectionFormat_multi_with_query", + yml: `name: ids +in: query +type: array +items: + type: string +collectionFormat: multi`, + }, + { + name: "collectionFormat_multi_with_formData", + yml: `name: tags +in: formData +type: array +items: + type: string +collectionFormat: multi`, + }, + { + name: "collectionFormat_csv_with_path", + yml: `name: ids +in: path +required: true +type: array +items: + type: string +collectionFormat: csv`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var param swagger.Parameter + + validationErrs, err := marshaller.Unmarshal(t.Context(), bytes.NewBufferString(tt.yml), ¶m) + require.NoError(t, err) + require.Empty(t, validationErrs) + + errs := param.Validate(t.Context()) + require.Empty(t, errs, "Expected no validation errors") + }) + } +} + +func TestParameter_Validate_CollectionFormat_Error(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + yml string + wantErrs []string + }{ + { + name: "collectionFormat_multi_with_path", + yml: `name: ids +in: path +required: true +type: array +items: + type: string +collectionFormat: multi`, + wantErrs: []string{"collectionFormat='multi' is only valid for in=query or in=formData"}, + }, + { + name: "collectionFormat_multi_with_header", + yml: `name: ids +in: header +type: array +items: + type: string +collectionFormat: multi`, + wantErrs: []string{"collectionFormat='multi' is only valid for in=query or in=formData"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var param swagger.Parameter + + var allErrors []error + validationErrs, err := marshaller.Unmarshal(t.Context(), bytes.NewBufferString(tt.yml), ¶m) + require.NoError(t, err) + allErrors = append(allErrors, validationErrs...) + + validateErrs := param.Validate(t.Context()) + allErrors = append(allErrors, validateErrs...) 
+ + require.NotEmpty(t, allErrors, "Expected validation errors") + + for _, wantErr := range tt.wantErrs { + found := false + for _, gotErr := range allErrors { + if gotErr != nil && strings.Contains(gotErr.Error(), wantErr) { + found = true + break + } + } + require.True(t, found, "Expected error containing '%s' not found in: %v", wantErr, allErrors) + } + }) + } +} + +func TestParameter_Validate_FileTypeConsumes_Success(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + yml string + consumes []string + }{ + { + name: "file_parameter_with_multipart_form_data", + yml: `name: file +in: formData +type: file`, + consumes: []string{"multipart/form-data"}, + }, + { + name: "file_parameter_with_urlencoded", + yml: `name: file +in: formData +type: file`, + consumes: []string{"application/x-www-form-urlencoded"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var param swagger.Parameter + + validationErrs, err := marshaller.Unmarshal(t.Context(), bytes.NewBufferString(tt.yml), ¶m) + require.NoError(t, err) + require.Empty(t, validationErrs) + + // Create operation context with appropriate consumes + var operation swagger.Operation + _, err = marshaller.Unmarshal(t.Context(), bytes.NewBufferString(`consumes: + - `+tt.consumes[0]+` +responses: + 200: + description: OK`), &operation) + require.NoError(t, err) + + errs := param.Validate(t.Context(), validation.WithContextObject(&operation)) + require.Empty(t, errs, "Expected no validation errors") + }) + } +} + +func TestParameter_Validate_FileTypeConsumes_Error(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + yml string + opYml string + wantErrs []string + }{ + { + name: "file_parameter_without_consumes", + yml: `name: file +in: formData +type: file`, + opYml: `responses: + 200: + description: OK`, + wantErrs: []string{"parameter with type=file requires operation to have consumes defined"}, + }, + { + name: "file_parameter_with_invalid_consumes", + yml: `name: file +in: formData +type: file`, + opYml: `consumes: + - application/json +responses: + 200: + description: OK`, + wantErrs: []string{"parameter with type=file requires operation consumes to be 'multipart/form-data' or 'application/x-www-form-urlencoded'"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var param swagger.Parameter + + validationErrs, err := marshaller.Unmarshal(t.Context(), bytes.NewBufferString(tt.yml), ¶m) + require.NoError(t, err) + require.Empty(t, validationErrs) + + // Create operation context + var operation swagger.Operation + _, err = marshaller.Unmarshal(t.Context(), bytes.NewBufferString(tt.opYml), &operation) + require.NoError(t, err) + + var allErrors []error + validateErrs := param.Validate(t.Context(), validation.WithContextObject(&operation)) + allErrors = append(allErrors, validateErrs...) 
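+			// type=file constraints depend on the surrounding operation's consumes list,
+			// so these errors come from Validate with the operation passed as context.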
+ + require.NotEmpty(t, allErrors, "Expected validation errors") + + for _, wantErr := range tt.wantErrs { + found := false + for _, gotErr := range allErrors { + if gotErr != nil && strings.Contains(gotErr.Error(), wantErr) { + found = true + break + } + } + require.True(t, found, "Expected error containing '%s' not found in: %v", wantErr, allErrors) + } + }) + } +} + +func TestSecurityRequirement_Validate_Success(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + yml string + }{ + { + name: "valid_oauth2_with_scopes", + yml: `oauth: + - read + - write`, + }, + { + name: "valid_apiKey_empty_scopes", + yml: `apiKey: []`, + }, + { + name: "valid_basic_empty_scopes", + yml: `basic: []`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var secReq swagger.SecurityRequirement + + validationErrs, err := marshaller.Unmarshal(t.Context(), bytes.NewBufferString(tt.yml), &secReq) + require.NoError(t, err) + require.Empty(t, validationErrs) + + // Create Swagger context with security definitions + doc := &swagger.Swagger{ + SecurityDefinitions: sequencedmap.New( + sequencedmap.NewElem("oauth", &swagger.SecurityScheme{Type: swagger.SecuritySchemeTypeOAuth2}), + sequencedmap.NewElem("apiKey", &swagger.SecurityScheme{Type: swagger.SecuritySchemeTypeAPIKey}), + sequencedmap.NewElem("basic", &swagger.SecurityScheme{Type: swagger.SecuritySchemeTypeBasic}), + ), + } + + errs := secReq.Validate(t.Context(), validation.WithContextObject(doc)) + require.Empty(t, errs, "Expected no validation errors") + }) + } +} + +func TestSecurityRequirement_Validate_Error(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + yml string + wantErrs []string + }{ + { + name: "undefined_security_scheme", + yml: `undefined: []`, + wantErrs: []string{"security requirement 'undefined' does not match any security scheme"}, + }, + { + name: "apiKey_with_non_empty_scopes", + yml: `apiKey: ["some_scope"]`, + wantErrs: []string{"security requirement 'apiKey' must have empty scopes array for non-oauth2"}, + }, + { + name: "basic_with_non_empty_scopes", + yml: `basic: ["some_scope"]`, + wantErrs: []string{"security requirement 'basic' must have empty scopes array for non-oauth2"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var secReq swagger.SecurityRequirement + + validationErrs, err := marshaller.Unmarshal(t.Context(), bytes.NewBufferString(tt.yml), &secReq) + require.NoError(t, err) + require.Empty(t, validationErrs) + + // Create Swagger context with security definitions + doc := &swagger.Swagger{ + SecurityDefinitions: sequencedmap.New( + sequencedmap.NewElem("oauth", &swagger.SecurityScheme{Type: swagger.SecuritySchemeTypeOAuth2}), + sequencedmap.NewElem("apiKey", &swagger.SecurityScheme{Type: swagger.SecuritySchemeTypeAPIKey}), + sequencedmap.NewElem("basic", &swagger.SecurityScheme{Type: swagger.SecuritySchemeTypeBasic}), + ), + } + + var allErrors []error + validateErrs := secReq.Validate(t.Context(), validation.WithContextObject(doc)) + allErrors = append(allErrors, validateErrs...) 
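+			// The requirement itself unmarshals cleanly; these errors come from Validate
+			// cross-checking scheme names and scopes against the document's securityDefinitions.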
+ + require.NotEmpty(t, allErrors, "Expected validation errors") + + for _, wantErr := range tt.wantErrs { + found := false + for _, gotErr := range allErrors { + if gotErr != nil && strings.Contains(gotErr.Error(), wantErr) { + found = true + break + } + } + require.True(t, found, "Expected error containing '%s' not found in: %v", wantErr, allErrors) + } + }) + } +} diff --git a/swagger/tag.go b/swagger/tag.go new file mode 100644 index 0000000..60b6aba --- /dev/null +++ b/swagger/tag.go @@ -0,0 +1,77 @@ +package swagger + +import ( + "context" + + "github.com/speakeasy-api/openapi/extensions" + "github.com/speakeasy-api/openapi/internal/interfaces" + "github.com/speakeasy-api/openapi/marshaller" + "github.com/speakeasy-api/openapi/swagger/core" + "github.com/speakeasy-api/openapi/validation" +) + +// Tag allows adding metadata to a single tag that is used by operations. +type Tag struct { + marshaller.Model[core.Tag] + + // Name is the name of the tag. + Name string + // Description is a short description for the tag. GFM syntax can be used for rich text representation. + Description *string + // ExternalDocs is additional external documentation for this tag. + ExternalDocs *ExternalDocumentation + // Extensions provides a list of extensions to the Tag object. + Extensions *extensions.Extensions +} + +var _ interfaces.Model[core.Tag] = (*Tag)(nil) + +// GetName returns the value of the Name field. Returns empty string if not set. +func (t *Tag) GetName() string { + if t == nil { + return "" + } + return t.Name +} + +// GetDescription returns the value of the Description field. Returns empty string if not set. +func (t *Tag) GetDescription() string { + if t == nil || t.Description == nil { + return "" + } + return *t.Description +} + +// GetExternalDocs returns the value of the ExternalDocs field. Returns nil if not set. +func (t *Tag) GetExternalDocs() *ExternalDocumentation { + if t == nil { + return nil + } + return t.ExternalDocs +} + +// GetExtensions returns the value of the Extensions field. Returns an empty extensions map if not set. +func (t *Tag) GetExtensions() *extensions.Extensions { + if t == nil || t.Extensions == nil { + return extensions.New() + } + return t.Extensions +} + +// Validate validates the Tag object against the Swagger Specification. +func (t *Tag) Validate(ctx context.Context, opts ...validation.Option) []error { + c := t.GetCore() + errs := []error{} + + if c.Name.Present && t.Name == "" { + errs = append(errs, validation.NewValueError(validation.NewMissingValueError("tag.name is required"), c, c.Name)) + } + + if c.ExternalDocs.Present { + errs = append(errs, t.ExternalDocs.Validate(ctx, opts...)...) 
+ } + + t.Valid = len(errs) == 0 && c.GetValid() + + return errs +} diff --git a/swagger/tag_validate_test.go b/swagger/tag_validate_test.go new file mode 100644 index 0000000..9a85a3f --- /dev/null +++ b/swagger/tag_validate_test.go @@ -0,0 +1,177 @@ +package swagger_test + +import ( + "bytes" + "strings" + "testing" + + "github.com/speakeasy-api/openapi/marshaller" + "github.com/speakeasy-api/openapi/swagger" + "github.com/stretchr/testify/require" +) + +func TestTag_Validate_Success(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + yml string + }{ + { + name: "valid_tag_with_name_only", + yml: `name: users +description: User operations`, + }, + { + name: "valid_tag_with_external_docs", + yml: `name: pets +description: Pet operations +externalDocs: + description: Find more info here + url: https://example.com/docs`, + }, + { + name: "valid_tag_minimal", + yml: `name: minimal`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var tag swagger.Tag + + validationErrs, err := marshaller.Unmarshal(t.Context(), bytes.NewBufferString(tt.yml), &tag) + require.NoError(t, err) + require.Empty(t, validationErrs) + + errs := tag.Validate(t.Context()) + require.Empty(t, errs, "Expected no validation errors") + }) + } +} + +func TestTag_Validate_Error(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + yml string + wantErrs []string + }{ + { + name: "missing_name", + yml: `description: Some description`, + wantErrs: []string{"tag.name is missing"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var tag swagger.Tag + + var allErrors []error + validationErrs, err := marshaller.Unmarshal(t.Context(), bytes.NewBufferString(tt.yml), &tag) + require.NoError(t, err) + allErrors = append(allErrors, validationErrs...) + + validateErrs := tag.Validate(t.Context()) + allErrors = append(allErrors, validateErrs...) 
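+			// Errors from both Unmarshal and Validate are collected, since
+			// required-field issues such as a missing name may be reported at
+			// either stage.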
+ + require.NotEmpty(t, allErrors, "Expected validation errors") + + for _, wantErr := range tt.wantErrs { + found := false + for _, gotErr := range allErrors { + if gotErr != nil && strings.Contains(gotErr.Error(), wantErr) { + found = true + break + } + } + require.True(t, found, "Expected error containing '%s' not found in: %v", wantErr, allErrors) + } + }) + } +} + +func TestExternalDocumentation_Validate_Success(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + yml string + }{ + { + name: "valid_external_docs_with_description", + yml: `description: Find more info here +url: https://example.com/docs`, + }, + { + name: "valid_external_docs_minimal", + yml: `url: https://example.com`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var externalDocs swagger.ExternalDocumentation + + validationErrs, err := marshaller.Unmarshal(t.Context(), bytes.NewBufferString(tt.yml), &externalDocs) + require.NoError(t, err) + require.Empty(t, validationErrs) + + errs := externalDocs.Validate(t.Context()) + require.Empty(t, errs, "Expected no validation errors") + }) + } +} + +func TestExternalDocumentation_Validate_Error(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + yml string + wantErrs []string + }{ + { + name: "missing_url", + yml: `description: Some description`, + wantErrs: []string{"externalDocumentation.url is missing"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var externalDocs swagger.ExternalDocumentation + + var allErrors []error + validationErrs, err := marshaller.Unmarshal(t.Context(), bytes.NewBufferString(tt.yml), &externalDocs) + require.NoError(t, err) + allErrors = append(allErrors, validationErrs...) + + validateErrs := externalDocs.Validate(t.Context()) + allErrors = append(allErrors, validateErrs...) 
+ + require.NotEmpty(t, allErrors, "Expected validation errors") + + for _, wantErr := range tt.wantErrs { + found := false + for _, gotErr := range allErrors { + if gotErr != nil && strings.Contains(gotErr.Error(), wantErr) { + found = true + break + } + } + require.True(t, found, "Expected error containing '%s' not found in: %v", wantErr, allErrors) + } + }) + } +} diff --git a/swagger/testdata/test.swagger.json b/swagger/testdata/test.swagger.json new file mode 100644 index 0000000..074ff7a --- /dev/null +++ b/swagger/testdata/test.swagger.json @@ -0,0 +1,684 @@ +{ + "swagger": "2.0", + "info": { + "title": "Comprehensive Swagger Test API", + "description": "A comprehensive test document that exercises all Swagger 2.0 specification features", + "version": "1.0.0", + "termsOfService": "https://example.com/terms", + "contact": { + "name": "API Support", + "url": "https://example.com/support", + "email": "support@example.com" + }, + "license": { + "name": "Apache 2.0", + "url": "https://www.apache.org/licenses/LICENSE-2.0.html" + }, + "x-custom-info": "custom info extension" + }, + "host": "api.example.com", + "basePath": "/v1", + "schemes": ["https", "http"], + "consumes": ["application/json", "application/xml"], + "produces": ["application/json", "application/xml"], + "securityDefinitions": { + "basic_auth": { + "type": "basic", + "description": "Basic HTTP authentication" + }, + "api_key": { + "type": "apiKey", + "name": "X-API-Key", + "in": "header", + "description": "API key authentication" + }, + "api_key_query": { + "type": "apiKey", + "name": "api_key", + "in": "query", + "description": "API key in query parameter" + }, + "oauth2_implicit": { + "type": "oauth2", + "flow": "implicit", + "authorizationUrl": "https://example.com/oauth/authorize", + "scopes": { + "read:items": "Read access to items", + "write:items": "Write access to items" + }, + "description": "OAuth2 implicit flow" + }, + "oauth2_password": { + "type": "oauth2", + "flow": "password", + "tokenUrl": "https://example.com/oauth/token", + "scopes": { + "admin": "Admin access" + } + }, + "oauth2_application": { + "type": "oauth2", + "flow": "application", + "tokenUrl": "https://example.com/oauth/token", + "scopes": { + "read:public": "Read public data" + } + }, + "oauth2_accessCode": { + "type": "oauth2", + "flow": "accessCode", + "authorizationUrl": "https://example.com/oauth/authorize", + "tokenUrl": "https://example.com/oauth/token", + "scopes": { + "read:all": "Read all data", + "write:all": "Write all data" + } + } + }, + "security": [ + { + "api_key": [] + }, + { + "oauth2_implicit": ["read:items"] + } + ], + "tags": [ + { + "name": "items", + "description": "Operations on items", + "externalDocs": { + "description": "Find out more about items", + "url": "https://example.com/docs/items" + } + }, + { + "name": "users", + "description": "User management operations" + } + ], + "externalDocs": { + "description": "Find out more about our API", + "url": "https://example.com/docs" + }, + "paths": { + "/items": { + "get": { + "tags": ["items"], + "summary": "List all items", + "description": "Returns a list of items with optional filtering", + "operationId": "listItems", + "consumes": ["application/json"], + "produces": ["application/json", "application/xml"], + "parameters": [ + { + "name": "limit", + "in": "query", + "description": "Maximum number of items to return", + "required": false, + "type": "integer", + "format": "int32", + "default": 20, + "minimum": 1, + "maximum": 100 + }, + { + "name": "offset", + "in": 
"query", + "description": "Number of items to skip", + "type": "integer", + "format": "int32", + "default": 0, + "minimum": 0 + }, + { + "name": "tags", + "in": "query", + "description": "Filter by tags", + "type": "array", + "items": { + "type": "string" + }, + "collectionFormat": "csv" + }, + { + "name": "X-Request-ID", + "in": "header", + "description": "Unique request identifier", + "type": "string", + "pattern": "^[a-f0-9-]+$" + } + ], + "responses": { + "200": { + "description": "Successful operation", + "schema": { + "type": "array", + "items": { + "$ref": "#/definitions/Item" + } + }, + "headers": { + "X-Rate-Limit": { + "type": "integer", + "description": "Requests per hour allowed" + }, + "X-Rate-Limit-Remaining": { + "type": "integer", + "description": "Requests remaining in the current period" + } + }, + "examples": { + "application/json": [ + { + "id": 1, + "name": "Item 1" + } + ] + } + }, + "400": { + "$ref": "#/responses/BadRequest" + }, + "401": { + "$ref": "#/responses/Unauthorized" + }, + "default": { + "description": "Unexpected error", + "schema": { + "$ref": "#/definitions/Error" + } + } + }, + "security": [ + { + "api_key": [] + }, + { + "oauth2_implicit": ["read:items"] + } + ], + "x-code-samples": [ + { + "lang": "curl", + "source": "curl -X GET https://api.example.com/v1/items" + } + ] + }, + "post": { + "tags": ["items"], + "summary": "Create a new item", + "description": "Creates a new item with the provided data", + "operationId": "createItem", + "parameters": [ + { + "name": "body", + "in": "body", + "description": "Item to create", + "required": true, + "schema": { + "$ref": "#/definitions/NewItem" + } + } + ], + "responses": { + "201": { + "description": "Item created successfully", + "schema": { + "$ref": "#/definitions/Item" + } + }, + "400": { + "$ref": "#/responses/BadRequest" + }, + "401": { + "$ref": "#/responses/Unauthorized" + } + }, + "security": [ + { + "oauth2_implicit": ["write:items"] + } + ] + } + }, + "/items/{itemId}": { + "parameters": [ + { + "$ref": "#/parameters/itemId" + } + ], + "get": { + "tags": ["items"], + "summary": "Get item by ID", + "description": "Returns a single item", + "operationId": "getItemById", + "produces": ["application/json"], + "responses": { + "200": { + "description": "Successful operation", + "schema": { + "$ref": "#/definitions/Item" + } + }, + "404": { + "description": "Item not found" + } + } + }, + "put": { + "tags": ["items"], + "summary": "Update an item", + "description": "Updates an existing item", + "operationId": "updateItem", + "parameters": [ + { + "name": "body", + "in": "body", + "description": "Updated item data", + "required": true, + "schema": { + "$ref": "#/definitions/Item" + } + } + ], + "responses": { + "200": { + "description": "Item updated successfully", + "schema": { + "$ref": "#/definitions/Item" + } + }, + "404": { + "description": "Item not found" + } + } + }, + "delete": { + "tags": ["items"], + "summary": "Delete an item", + "description": "Deletes an item", + "operationId": "deleteItem", + "responses": { + "204": { + "description": "Item deleted successfully" + }, + "404": { + "description": "Item not found" + } + } + } + }, + "/items/{itemId}/upload": { + "post": { + "tags": ["items"], + "summary": "Upload file for item", + "description": "Upload a file associated with an item", + "operationId": "uploadFile", + "consumes": ["multipart/form-data"], + "parameters": [ + { + "$ref": "#/parameters/itemId" + }, + { + "name": "file", + "in": "formData", + "description": "File to upload", 
+ "required": true, + "type": "file" + }, + { + "name": "description", + "in": "formData", + "description": "File description", + "type": "string" + } + ], + "responses": { + "200": { + "description": "File uploaded successfully" + } + } + } + }, + "/users/{username}": { + "get": { + "tags": ["users"], + "summary": "Get user by username", + "operationId": "getUserByName", + "parameters": [ + { + "name": "username", + "in": "path", + "description": "The username", + "required": true, + "type": "string" + } + ], + "responses": { + "200": { + "description": "Successful operation", + "schema": { + "$ref": "#/definitions/User" + } + }, + "404": { + "description": "User not found" + } + } + } + }, + "/items/search": { + "get": { + "tags": ["items"], + "summary": "Search items", + "operationId": "searchItems", + "parameters": [ + { + "name": "q", + "in": "query", + "description": "Search query", + "required": true, + "type": "string", + "minLength": 1, + "maxLength": 100 + }, + { + "name": "category", + "in": "query", + "description": "Category filter", + "type": "string", + "enum": ["electronics", "clothing", "books"] + }, + { + "name": "price_range", + "in": "query", + "description": "Price range filter", + "type": "array", + "items": { + "type": "number", + "format": "double" + }, + "collectionFormat": "pipes", + "minItems": 2, + "maxItems": 2 + }, + { + "name": "in_stock", + "in": "query", + "description": "Filter by stock availability", + "type": "boolean" + } + ], + "responses": { + "200": { + "description": "Search results", + "schema": { + "type": "object", + "properties": { + "results": { + "type": "array", + "items": { + "$ref": "#/definitions/Item" + } + }, + "total": { + "type": "integer" + } + } + } + } + }, + "deprecated": true + } + }, + "/items/batch": { + "post": { + "tags": ["items"], + "summary": "Create multiple items", + "operationId": "createBatchItems", + "parameters": [ + { + "name": "items", + "in": "body", + "required": true, + "schema": { + "type": "array", + "items": { + "$ref": "#/definitions/NewItem" + }, + "minItems": 1, + "maxItems": 100 + } + } + ], + "responses": { + "201": { + "description": "Items created" + } + } + } + }, + "/auth/basic": { + "get": { + "tags": ["users"], + "summary": "Test basic auth", + "operationId": "testBasicAuth", + "responses": { + "200": { + "description": "Authentication successful" + } + }, + "security": [ + { + "basic_auth": [] + } + ] + } + } + }, + "definitions": { + "Item": { + "type": "object", + "required": ["id", "name"], + "properties": { + "id": { + "type": "integer", + "format": "int64", + "description": "Unique identifier" + }, + "name": { + "type": "string", + "description": "Item name", + "minLength": 1, + "maxLength": 255 + }, + "description": { + "type": "string", + "description": "Item description" + }, + "price": { + "type": "number", + "format": "double", + "minimum": 0, + "exclusiveMinimum": true + }, + "quantity": { + "type": "integer", + "default": 0, + "minimum": 0 + }, + "tags": { + "type": "array", + "items": { + "type": "string" + }, + "uniqueItems": true + }, + "status": { + "type": "string", + "enum": ["available", "pending", "sold"], + "default": "available" + }, + "metadata": { + "type": "object", + "additionalProperties": { + "type": "string" + } + }, + "createdAt": { + "type": "string", + "format": "date-time" + } + }, + "xml": { + "name": "Item" + }, + "example": { + "id": 1, + "name": "Example Item", + "price": 19.99 + } + }, + "NewItem": { + "type": "object", + "required": ["name"], + "properties": { 
+ "name": { + "type": "string", + "minLength": 1 + }, + "description": { + "type": "string" + }, + "price": { + "type": "number", + "format": "double", + "minimum": 0 + }, + "tags": { + "type": "array", + "items": { + "type": "string" + } + } + } + }, + "User": { + "type": "object", + "required": ["id", "username"], + "properties": { + "id": { + "type": "integer", + "format": "int64" + }, + "username": { + "type": "string", + "pattern": "^[a-zA-Z0-9_]+$" + }, + "email": { + "type": "string", + "format": "email" + }, + "firstName": { + "type": "string" + }, + "lastName": { + "type": "string" + }, + "phone": { + "type": "string", + "pattern": "^\\+?[1-9]\\d{1,14}$" + }, + "userStatus": { + "type": "integer", + "format": "int32", + "description": "User Status" + } + } + }, + "Error": { + "type": "object", + "required": ["code", "message"], + "properties": { + "code": { + "type": "integer", + "format": "int32" + }, + "message": { + "type": "string" + }, + "details": { + "type": "array", + "items": { + "type": "object", + "properties": { + "field": { + "type": "string" + }, + "error": { + "type": "string" + } + } + } + } + } + } + }, + "parameters": { + "itemId": { + "name": "itemId", + "in": "path", + "description": "ID of the item", + "required": true, + "type": "integer", + "format": "int64", + "minimum": 1 + }, + "pageLimit": { + "name": "limit", + "in": "query", + "description": "Number of items to return", + "type": "integer", + "format": "int32", + "default": 20, + "minimum": 1, + "maximum": 100 + } + }, + "responses": { + "BadRequest": { + "description": "Invalid request", + "schema": { + "$ref": "#/definitions/Error" + }, + "examples": { + "application/json": { + "code": 400, + "message": "Bad Request" + } + } + }, + "Unauthorized": { + "description": "Authentication required", + "schema": { + "$ref": "#/definitions/Error" + }, + "headers": { + "WWW-Authenticate": { + "type": "string", + "description": "Authentication challenge" + } + } + }, + "NotFound": { + "description": "Resource not found", + "schema": { + "$ref": "#/definitions/Error" + } + } + }, + "x-custom-root": { + "custom-property": "value" + } +} diff --git a/swagger/testdata/walk.swagger.json b/swagger/testdata/walk.swagger.json new file mode 100644 index 0000000..f3ea77a --- /dev/null +++ b/swagger/testdata/walk.swagger.json @@ -0,0 +1,216 @@ +{ + "swagger": "2.0", + "info": { + "title": "Comprehensive Swagger API", + "version": "1.0.0", + "description": "A comprehensive Swagger API for testing walk functionality", + "termsOfService": "https://example.com/terms", + "contact": { + "name": "API Team", + "url": "https://example.com/contact", + "email": "api@example.com", + "x-contact-custom": "contact-extension" + }, + "license": { + "name": "MIT", + "url": "https://opensource.org/licenses/MIT", + "x-license-custom": "license-extension" + }, + "x-info-custom": "info-extension" + }, + "host": "api.example.com", + "basePath": "/v1", + "schemes": ["https", "http"], + "consumes": ["application/json"], + "produces": ["application/json"], + "tags": [ + { + "name": "users", + "description": "User operations", + "externalDocs": { + "url": "https://example.com/users", + "description": "User documentation" + } + }, + { + "name": "pets", + "description": "Pet operations" + } + ], + "externalDocs": { + "url": "https://example.com/docs", + "description": "Additional documentation" + }, + "security": [ + { + "apiKey": [] + } + ], + "paths": { + "/users/{id}": { + "parameters": [ + { + "name": "id", + "in": "path", + "required": true, + 
"type": "integer", + "format": "int64", + "description": "User ID" + } + ], + "get": { + "tags": ["users"], + "summary": "Get user by ID", + "description": "Retrieve a user by their ID", + "operationId": "getUser", + "parameters": [ + { + "name": "expand", + "in": "query", + "type": "string", + "description": "Expand related resources" + } + ], + "responses": { + "200": { + "description": "Successful response", + "schema": { + "$ref": "#/definitions/User" + }, + "headers": { + "X-Rate-Limit": { + "type": "integer", + "description": "Rate limit remaining" + } + }, + "examples": { + "application/json": { + "id": 123, + "name": "John Doe" + } + } + }, + "404": { + "description": "User not found" + }, + "default": { + "description": "Error response" + } + }, + "security": [ + { + "apiKey": [] + } + ] + } + }, + "/pets": { + "get": { + "tags": ["pets"], + "summary": "List pets", + "operationId": "listPets", + "parameters": [ + { + "name": "tags", + "in": "query", + "type": "array", + "items": { + "type": "string" + }, + "collectionFormat": "csv" + } + ], + "responses": { + "200": { + "description": "Pet list", + "schema": { + "type": "array", + "items": { + "$ref": "#/definitions/Pet" + } + } + } + } + } + } + }, + "definitions": { + "User": { + "type": "object", + "description": "User object", + "required": ["id", "name"], + "properties": { + "id": { + "type": "integer", + "format": "int64" + }, + "name": { + "type": "string" + }, + "email": { + "type": "string", + "format": "email" + } + } + }, + "Pet": { + "type": "object", + "required": ["id", "name"], + "properties": { + "id": { + "type": "integer", + "format": "int64" + }, + "name": { + "type": "string" + }, + "tag": { + "type": "string" + } + } + } + }, + "parameters": { + "PageParam": { + "name": "page", + "in": "query", + "type": "integer", + "description": "Page number" + } + }, + "responses": { + "ErrorResponse": { + "description": "Error response", + "schema": { + "type": "object", + "properties": { + "message": { + "type": "string" + }, + "code": { + "type": "integer" + } + } + } + } + }, + "securityDefinitions": { + "apiKey": { + "type": "apiKey", + "name": "X-API-Key", + "in": "header", + "description": "API key authentication" + }, + "oauth2": { + "type": "oauth2", + "flow": "accessCode", + "authorizationUrl": "https://example.com/oauth/authorize", + "tokenUrl": "https://example.com/oauth/token", + "scopes": { + "read": "Read access", + "write": "Write access" + } + } + }, + "x-root-custom": "root-extension" +} diff --git a/swagger/upgrade.go b/swagger/upgrade.go new file mode 100644 index 0000000..d33f780 --- /dev/null +++ b/swagger/upgrade.go @@ -0,0 +1,867 @@ +package swagger + +import ( + "context" + "fmt" + "sort" + "strings" + + "github.com/speakeasy-api/openapi/extensions" + "github.com/speakeasy-api/openapi/jsonschema/oas3" + "github.com/speakeasy-api/openapi/openapi" + "github.com/speakeasy-api/openapi/pointer" + "github.com/speakeasy-api/openapi/references" + "github.com/speakeasy-api/openapi/sequencedmap" + "github.com/speakeasy-api/openapi/values" +) + +// Upgrade converts a Swagger 2.0 document into an OpenAPI 3.0 document. 
+// +// The conversion performs the following major transformations: +// - swagger: "2.0" -> openapi: "3.0.0" +// - host/basePath/schemes -> servers +// - definitions -> components.schemas +// - parameters (global non-body) -> components.parameters +// - parameters (global body) -> components.requestBodies +// - responses (global) -> components.responses +// - securityDefinitions -> components.securitySchemes +// - operation parameters: +// - in: body -> requestBody with content and schema +// - in: formData -> requestBody with x-www-form-urlencoded or multipart/form-data schema +// - other parameters carried over with schema and style/explode derived from collectionFormat +// +// - responses: schema/examples -> content[mediaType].schema/example +// - Rewrites JSON Schema $ref targets from "#/definitions/..." to "#/components/schemas/..." +func Upgrade(ctx context.Context, src *Swagger) (*openapi.OpenAPI, error) { + if src == nil { + return nil, nil + } + + dst := &openapi.OpenAPI{ + OpenAPI: "3.0.0", + Info: convertInfo(src.Info), + Tags: convertTags(src.Tags), + } + + // Servers + dst.Servers = buildServers(src) + + // Paths + dst.Paths = convertPaths(src) + + // Components (only set when any sub-section is non-nil) + compSchemas := convertDefinitions(src.Definitions) + compParams := convertGlobalParameters(src) + compReqBodies := convertGlobalRequestBodies(src) + compResponses := convertGlobalResponses(src) + compSecSchemes := convertSecuritySchemes(src) + + if compSchemas != nil || compParams != nil || compReqBodies != nil || compResponses != nil || compSecSchemes != nil { + dst.Components = &openapi.Components{ + Schemas: compSchemas, + Parameters: compParams, + RequestBodies: compReqBodies, + Responses: compResponses, + SecuritySchemes: compSecSchemes, + } + } + + // Security requirements + dst.Security = convertSecurityRequirements(src) + + // External docs (root) + // Swagger ExternalDocs type differs from OpenAPI's; skip explicit convert at root + // (Operation-level externalDocs handled similarly) + + // Rewrite schema $refs from "#/definitions/" -> "#/components/schemas/" + rewriteRefTargets(ctx, dst) + + return dst, nil +} + +func convertInfo(src Info) openapi.Info { + return openapi.Info{ + Title: src.Title, + Version: src.Version, + Description: src.Description, + TermsOfService: func() *string { + if src.TermsOfService == nil || *src.TermsOfService == "" { + return nil + } + return src.TermsOfService + }(), + Contact: convertInfoContact(src.Contact), + License: convertInfoLicense(src.License), + Extensions: func() *extensions.Extensions { + if src.Extensions == nil { + return nil + } + ext := extensions.New() + _ = ext.Populate(src.Extensions) + return ext + }(), + } +} + +func convertInfoContact(src *Contact) *openapi.Contact { + if src == nil { + return nil + } + return &openapi.Contact{ + Name: src.Name, + URL: src.URL, + Email: src.Email, + Extensions: copyExtensions(src.Extensions), + } +} + +func convertInfoLicense(src *License) *openapi.License { + if src == nil { + return nil + } + return &openapi.License{ + Name: src.Name, + URL: src.URL, + Extensions: copyExtensions(src.Extensions), + } +} + +func copyExtensions(src *extensions.Extensions) *extensions.Extensions { + if src == nil { + return nil + } + dst := extensions.New() + _ = dst.Populate(src) + return dst +} + +func convertTags(src []*Tag) []*openapi.Tag { + if len(src) == 0 { + return nil + } + out := make([]*openapi.Tag, 0, len(src)) + for _, t := range src { + if t == nil { + continue + } + out = 
append(out, &openapi.Tag{ + Name: t.Name, + Description: t.Description, + ExternalDocs: func() *oas3.ExternalDocumentation { + // Swagger Tag has ExternalDocs type swagger.ExternalDocumentation; omit mapping for now + return nil + }(), + Extensions: copyExtensions(t.Extensions), + }) + } + return out +} + +func buildServers(src *Swagger) []*openapi.Server { + host := src.GetHost() + basePath := src.GetBasePath() + schemes := src.GetSchemes() + + if host == "" { + // No absolute server configured; rely on default "/" server (OpenAPI.GetServers fallback) + return nil + } + + pathsuffix := basePath + if pathsuffix == "" { + pathsuffix = "/" + } + pathsuffix = ensureLeadingSlash(pathsuffix) + + if len(schemes) == 0 { + // Default to https if host present + schemes = []string{"https"} + } + + var servers []*openapi.Server + seen := map[string]struct{}{} + for _, sch := range schemes { + url := fmt.Sprintf("%s://%s%s", sch, host, pathsuffix) + if _, ok := seen[url]; ok { + continue + } + seen[url] = struct{}{} + servers = append(servers, &openapi.Server{URL: url}) + } + + return servers +} + +func ensureLeadingSlash(s string) string { + if s == "" { + return "/" + } + if !strings.HasPrefix(s, "/") { + return "/" + s + } + return s +} + +func convertDefinitions(defs *sequencedmap.Map[string, *oas3.JSONSchema[oas3.Concrete]]) *sequencedmap.Map[string, *oas3.JSONSchema[oas3.Referenceable]] { + if defs == nil || defs.Len() == 0 { + return nil + } + out := sequencedmap.New[string, *oas3.JSONSchema[oas3.Referenceable]]() + for name, schema := range defs.All() { + if schema == nil { + continue + } + out.Set(name, oas3.ConcreteToReferenceable(schema)) + } + return out +} + +func convertSecuritySchemes(src *Swagger) *sequencedmap.Map[string, *openapi.ReferencedSecurityScheme] { + if src.SecurityDefinitions == nil || src.SecurityDefinitions.Len() == 0 { + return nil + } + out := sequencedmap.New[string, *openapi.ReferencedSecurityScheme]() + for name, s := range src.SecurityDefinitions.All() { + if s == nil { + continue + } + dst := &openapi.SecurityScheme{ + Extensions: copyExtensions(s.Extensions), + } + switch s.Type { + case SecuritySchemeTypeBasic: + dst.Type = openapi.SecuritySchemeTypeHTTP + dst.Scheme = pointer.From("basic") + case SecuritySchemeTypeAPIKey: + dst.Type = openapi.SecuritySchemeTypeAPIKey + dst.Name = s.Name + if s.In != nil { + switch *s.In { + case SecuritySchemeInHeader: + in := openapi.SecuritySchemeInHeader + dst.In = &in + case SecuritySchemeInQuery: + in := openapi.SecuritySchemeInQuery + dst.In = &in + default: + // Swagger 2.0 doesn't support cookie for apiKey + } + } + case SecuritySchemeTypeOAuth2: + dst.Type = openapi.SecuritySchemeTypeOAuth2 + dst.Flows = convertOAuth2Flows(s) + default: + // unsupported; copy as apiKey header by default to keep spec valid minimally + dst.Type = openapi.SecuritySchemeTypeAPIKey + n := pointer.From("Authorization") + dst.Name = n + in := openapi.SecuritySchemeInHeader + dst.In = &in + } + out.Set(name, openapi.NewReferencedSecuritySchemeFromSecurityScheme(dst)) + } + return out +} + +func convertOAuth2Flows(s *SecurityScheme) *openapi.OAuthFlows { + if s == nil { + return nil + } + flows := &openapi.OAuthFlows{ + Extensions: copyExtensions(s.Extensions), + } + if s.Flow == nil { + return flows + } + switch *s.Flow { + case OAuth2FlowImplicit: + flows.Implicit = &openapi.OAuthFlow{ + AuthorizationURL: s.AuthorizationURL, + TokenURL: nil, + RefreshURL: nil, + Scopes: cloneStringMap(s.Scopes), + } + case OAuth2FlowPassword: + flows.Password 
= &openapi.OAuthFlow{ + TokenURL: s.TokenURL, + Scopes: cloneStringMap(s.Scopes), + Extensions: nil, + } + case OAuth2FlowApplication: + flows.ClientCredentials = &openapi.OAuthFlow{ + TokenURL: s.TokenURL, + Scopes: cloneStringMap(s.Scopes), + Extensions: nil, + } + case OAuth2FlowAccessCode: + flows.AuthorizationCode = &openapi.OAuthFlow{ + AuthorizationURL: s.AuthorizationURL, + TokenURL: s.TokenURL, + Scopes: cloneStringMap(s.Scopes), + } + } + return flows +} + +func cloneStringMap(m *sequencedmap.Map[string, string]) *sequencedmap.Map[string, string] { + if m == nil || m.Len() == 0 { + return sequencedmap.New[string, string]() + } + out := sequencedmap.New[string, string]() + for k, v := range m.All() { + out.Set(k, v) + } + return out +} + +func convertSecurityRequirements(src *Swagger) []*openapi.SecurityRequirement { + if len(src.Security) == 0 { + return nil + } + var out []*openapi.SecurityRequirement + for _, req := range src.Security { + if req == nil { + continue + } + dst := openapi.NewSecurityRequirement() + for k, v := range req.All() { + dst.Set(k, v) + } + out = append(out, dst) + } + return out +} + +func convertPaths(src *Swagger) *openapi.Paths { + if src.Paths == nil || src.Paths.Len() == 0 { + return openapi.NewPaths() + } + + dst := openapi.NewPaths() + // Stable order for deterministic output + paths := make([]string, 0, src.Paths.Len()) + for p := range src.Paths.Keys() { + paths = append(paths, p) + } + sort.Strings(paths) + + for _, p := range paths { + pathItem, _ := src.Paths.Get(p) + if pathItem == nil { + continue + } + dst.Set(p, openapi.NewReferencedPathItemFromPathItem(convertPathItem(src, pathItem))) + } + + return dst +} + +func convertPathItem(root *Swagger, src *PathItem) *openapi.PathItem { + dst := openapi.NewPathItem() + // Path-level parameters (non-body only in OAS3) + for _, rp := range src.Parameters { + if rp == nil { + continue + } + if rp.IsReference() { + // Resolve reference name to decide if body or not + name := localComponentName(rp.GetReference()) + if name == "" { + continue + } + if root.Parameters != nil { + if gp, ok := root.Parameters.Get(name); ok && gp != nil { + if gp.In == ParameterInBody { + // skip; cannot put body parameter at path level; OAS3 has no path-level requestBody + continue + } + } + } + // Non-body parameter reference -> components.parameters + ref := references.Reference("#/components/parameters/" + name) + if dst.Parameters == nil { + dst.Parameters = []*openapi.ReferencedParameter{} + } + dst.Parameters = append(dst.Parameters, openapi.NewReferencedParameterFromRef(ref)) + continue + } + // Inline + if srcp := rp.GetObject(); srcp != nil && srcp.In != ParameterInBody { + if dst.Parameters == nil { + dst.Parameters = []*openapi.ReferencedParameter{} + } + dst.Parameters = append(dst.Parameters, openapi.NewReferencedParameterFromParameter(convertParameter(srcp))) + } + } + + // Operations + for method, op := range src.All() { + if op == nil { + continue + } + dst.Set(openapi.HTTPMethod(strings.ToLower(string(method))), convertOperation(root, op)) + } + + return dst +} + +func convertOperation(root *Swagger, src *Operation) *openapi.Operation { + dst := &openapi.Operation{ + OperationID: src.OperationID, + Summary: src.Summary, + Description: src.Description, + Deprecated: src.Deprecated, + Extensions: copyExtensions(src.Extensions), + Responses: openapi.NewResponses(), + } + // Only set tags if present to avoid emitting empty arrays + if len(src.Tags) > 0 { + dst.Tags = append([]string{}, src.Tags...) 
+ } + + // Determine consumes/produces for this operation + consumes := src.Consumes + if len(consumes) == 0 { + consumes = root.Consumes + } + produces := src.Produces + if len(produces) == 0 { + produces = root.Produces + } + if len(produces) == 0 { + produces = []string{"application/json"} + } + if len(consumes) == 0 { + consumes = []string{"application/json"} + } + + // Parameters -> Parameters + RequestBody + formParams := []*Parameter{} + var bodyParam *Parameter + + for _, rp := range src.Parameters { + if rp == nil { + continue + } + if rp.IsReference() { + // Reference to global parameter + name := localComponentName(rp.GetReference()) + if name == "" { + continue + } + if root.Parameters != nil { + if gp, ok := root.Parameters.Get(name); ok && gp != nil { + switch gp.In { + case ParameterInBody: + // Use requestBodies reference + dst.RequestBody = openapi.NewReferencedRequestBodyFromRef(references.Reference("#/components/requestBodies/" + name)) + case ParameterInFormData: + formParams = append(formParams, gp) + default: + // Carry as parameter reference + if dst.Parameters == nil { + dst.Parameters = []*openapi.ReferencedParameter{} + } + dst.Parameters = append(dst.Parameters, openapi.NewReferencedParameterFromRef(references.Reference("#/components/parameters/"+name))) + } + continue + } + } + // Fallback: treat as parameter ref + if dst.Parameters == nil { + dst.Parameters = []*openapi.ReferencedParameter{} + } + dst.Parameters = append(dst.Parameters, openapi.NewReferencedParameterFromRef(references.Reference("#/components/parameters/"+name))) + continue + } + + // Inline parameter + p := rp.GetObject() + if p == nil { + continue + } + switch p.In { + case ParameterInBody: + bodyParam = p + case ParameterInFormData: + formParams = append(formParams, p) + default: + if dst.Parameters == nil { + dst.Parameters = []*openapi.ReferencedParameter{} + } + dst.Parameters = append(dst.Parameters, openapi.NewReferencedParameterFromParameter(convertParameter(p))) + } + } + + // Build requestBody from body parameter if present + if dst.RequestBody == nil && bodyParam != nil { + rb := &openapi.RequestBody{ + Description: bodyParam.Description, + Required: bodyParam.Required, + Content: sequencedmap.New[string, *openapi.MediaType](), + Extensions: nil, + } + // Create media types from consumes + for _, mt := range consumes { + mt = strings.TrimSpace(mt) + if mt == "" { + continue + } + rb.Content.Set(mt, &openapi.MediaType{ + Schema: bodyParam.Schema, + }) + } + dst.RequestBody = openapi.NewReferencedRequestBodyFromRequestBody(rb) + } + + // Build requestBody from formData if any + if dst.RequestBody == nil && len(formParams) > 0 { + mediaType := "application/x-www-form-urlencoded" + for _, fp := range formParams { + if fp.Type != nil && *fp.Type == "file" { + mediaType = "multipart/form-data" + break + } + } + + obj := &oas3.Schema{ + Type: oas3.NewTypeFromString(oas3.SchemaType("object")), + Properties: sequencedmap.New[string, *oas3.JSONSchema[oas3.Referenceable]](), + } + // required list is optional; omitted for minimal conversion + for _, fp := range formParams { + propSchema := schemaForSwaggerParamType(fp) + obj.Properties.Set(fp.Name, propSchema) + } + + rb := &openapi.RequestBody{ + Required: pointer.From(anyRequired(formParams)), + Content: sequencedmap.New(sequencedmap.NewElem(mediaType, &openapi.MediaType{Schema: oas3.NewJSONSchemaFromSchema[oas3.Referenceable](obj)})), + } + dst.RequestBody = openapi.NewReferencedRequestBodyFromRequestBody(rb) + } + + // Responses + if 
src.Responses != nil { + // Default + if src.Responses.Default != nil { + dst.Responses.Default = convertReferencedResponse(src.Responses.Default, produces) + } + // Codes + for code, rr := range src.Responses.All() { + dst.Responses.Set(code, convertReferencedResponse(rr, produces)) + } + } + + return dst +} + +func anyRequired(params []*Parameter) bool { + for _, p := range params { + if p.GetRequired() { + return true + } + } + return false +} + +func schemaForSwaggerParamType(p *Parameter) *oas3.JSONSchema[oas3.Referenceable] { + if p == nil { + return nil + } + switch { + case p.Type != nil && *p.Type == "array": + items := &oas3.Schema{Type: oas3.NewTypeFromString(oas3.SchemaType("string"))} + if p.Items != nil && p.Items.Type != "" { + items.Type = oas3.NewTypeFromString(oas3.SchemaType(strings.ToLower(p.Items.Type))) + } + return oas3.NewJSONSchemaFromSchema[oas3.Referenceable](&oas3.Schema{ + Type: oas3.NewTypeFromString(oas3.SchemaType("array")), + Items: oas3.NewJSONSchemaFromSchema[oas3.Referenceable](items), + }) + case p.Type != nil && *p.Type == "file": + return oas3.NewJSONSchemaFromSchema[oas3.Referenceable](&oas3.Schema{ + Type: oas3.NewTypeFromString(oas3.SchemaType("string")), + Format: pointer.From("binary"), + }) + case p.Type != nil && *p.Type != "": + return oas3.NewJSONSchemaFromSchema[oas3.Referenceable](&oas3.Schema{ + Type: oas3.NewTypeFromString(oas3.SchemaType(strings.ToLower(*p.Type))), + }) + default: + // Body parameter case should not call this; fall back to string + return oas3.NewJSONSchemaFromSchema[oas3.Referenceable](&oas3.Schema{ + Type: oas3.NewTypeFromString(oas3.SchemaType("string")), + }) + } +} + +func convertParameter(p *Parameter) *openapi.Parameter { + if p == nil { + return nil + } + dst := &openapi.Parameter{ + Name: p.Name, + Description: p.Description, + Required: p.Required, + Deprecated: nil, // Swagger 2.0 parameter doesn't have deprecated + Schema: nil, + Content: nil, + Extensions: copyExtensions(p.Extensions), + } + + // in + switch p.In { + case ParameterInQuery: + dst.In = openapi.ParameterInQuery + case ParameterInHeader: + dst.In = openapi.ParameterInHeader + case ParameterInPath: + dst.In = openapi.ParameterInPath + default: + // Cookie not in Swagger 2.0 + dst.In = openapi.ParameterInQuery + } + + // schema from type/format/items (non-body only) + if p.In != ParameterInBody { + dst.Schema = schemaForSwaggerParamType(p) + // collectionFormat -> style/explode + if p.CollectionFormat != nil { + switch *p.CollectionFormat { + case CollectionFormatMulti: + // style=form explode=true (default for query) + style := openapi.SerializationStyleForm + dst.Style = &style + dst.Explode = pointer.From(true) + case CollectionFormatCSV: + style := openapi.SerializationStyleForm + dst.Style = &style + dst.Explode = pointer.From(false) + case CollectionFormatSSV: + style := openapi.SerializationStyleSpaceDelimited + dst.Style = &style + dst.Explode = pointer.From(false) + case CollectionFormatPipes: + style := openapi.SerializationStylePipeDelimited + dst.Style = &style + dst.Explode = pointer.From(false) + default: + // tsv or unknown -> default form + explode=false + style := openapi.SerializationStyleForm + dst.Style = &style + dst.Explode = pointer.From(false) + } + } + } + + return dst +} + +func convertReferencedResponse(rr *ReferencedResponse, produces []string) *openapi.ReferencedResponse { + if rr == nil { + return nil + } + if rr.IsReference() { + // Global response reference -> components.responses + name := 
localComponentName(rr.GetReference()) + return openapi.NewReferencedResponseFromRef(references.Reference("#/components/responses/" + name)) + } + src := rr.GetObject() + if src == nil { + return nil + } + + dst := &openapi.Response{ + Description: src.Description, + Headers: convertResponseHeaders(src.Headers), + // Content will be created only when there is a schema to describe + Content: nil, + Extensions: copyExtensions(src.Extensions), + } + + if src.Schema != nil { + // Build content entries for each produces + if len(produces) == 0 { + produces = []string{"application/json"} + } + // Initialize content map before setting entries + if dst.Content == nil { + dst.Content = sequencedmap.New[string, *openapi.MediaType]() + } + for _, mt := range produces { + mt = strings.TrimSpace(mt) + if mt == "" { + continue + } + dst.Content.Set(mt, &openapi.MediaType{ + Schema: src.Schema, + Example: exampleForMediaType(mt, src), + }) + } + } + + return openapi.NewReferencedResponseFromResponse(dst) +} + +func exampleForMediaType(mt string, src *Response) values.Value { + if src == nil || src.Examples == nil || src.Examples.Len() == 0 { + return nil + } + // Match exact or wildcard examples + if v, ok := src.Examples.Get(mt); ok { + return v + } + // Common defaults + for _, cand := range []string{"application/json", "application/xml", "text/plain"} { + if v, ok := src.Examples.Get(cand); ok { + return v + } + } + return nil +} + +func convertResponseHeaders(hdrs *sequencedmap.Map[string, *Header]) *sequencedmap.Map[string, *openapi.ReferencedHeader] { + if hdrs == nil || hdrs.Len() == 0 { + return nil + } + out := sequencedmap.New[string, *openapi.ReferencedHeader]() + for name, h := range hdrs.All() { + if h == nil { + continue + } + dst := &openapi.Header{ + Description: h.Description, + Schema: func() *oas3.JSONSchema[oas3.Referenceable] { + // Convert simple header types + if h.Type == "" { + return nil + } + switch strings.ToLower(h.Type) { + case "array": + items := &oas3.Schema{Type: oas3.NewTypeFromString(oas3.SchemaType("string"))} + if h.Items != nil && h.Items.Type != "" { + items.Type = oas3.NewTypeFromString(oas3.SchemaType(strings.ToLower(h.Items.Type))) + } + return oas3.NewJSONSchemaFromSchema[oas3.Referenceable](&oas3.Schema{ + Type: oas3.NewTypeFromString(oas3.SchemaType("array")), + Items: oas3.NewJSONSchemaFromSchema[oas3.Referenceable](items), + }) + default: + return oas3.NewJSONSchemaFromSchema[oas3.Referenceable](&oas3.Schema{ + Type: oas3.NewTypeFromString(oas3.SchemaType(strings.ToLower(h.Type))), + Format: h.Format, + }) + } + }(), + Extensions: copyExtensions(h.Extensions), + } + out.Set(name, openapi.NewReferencedHeaderFromHeader(dst)) + } + return out +} + +func convertGlobalParameters(src *Swagger) *sequencedmap.Map[string, *openapi.ReferencedParameter] { + if src.Parameters == nil || src.Parameters.Len() == 0 { + return nil + } + out := sequencedmap.New[string, *openapi.ReferencedParameter]() + for name, p := range src.Parameters.All() { + if p == nil { + continue + } + if p.In == ParameterInBody { + // Skip; handled as requestBodies + continue + } + out.Set(name, openapi.NewReferencedParameterFromParameter(convertParameter(p))) + } + return out +} + +func convertGlobalRequestBodies(src *Swagger) *sequencedmap.Map[string, *openapi.ReferencedRequestBody] { + if src.Parameters == nil || src.Parameters.Len() == 0 { + return nil + } + var count int + out := sequencedmap.New[string, *openapi.ReferencedRequestBody]() + for name, p := range src.Parameters.All() { + if p == 
nil || p.In != ParameterInBody { + continue + } + count++ + rb := &openapi.RequestBody{ + Description: p.Description, + Required: p.Required, + Content: sequencedmap.New[string, *openapi.MediaType](), + } + // Use global consumes or default + consumes := src.Consumes + if len(consumes) == 0 { + consumes = []string{"application/json"} + } + for _, mt := range consumes { + mt = strings.TrimSpace(mt) + if mt == "" { + continue + } + rb.Content.Set(mt, &openapi.MediaType{ + Schema: p.Schema, + }) + } + out.Set(name, openapi.NewReferencedRequestBodyFromRequestBody(rb)) + } + if count == 0 { + return nil + } + return out +} + +func convertGlobalResponses(src *Swagger) *sequencedmap.Map[string, *openapi.ReferencedResponse] { + if src.Responses == nil || src.Responses.Len() == 0 { + return nil + } + out := sequencedmap.New[string, *openapi.ReferencedResponse]() + // Determine fallback produces + produces := src.Produces + if len(produces) == 0 { + produces = []string{"application/json"} + } + for name, r := range src.Responses.All() { + out.Set(name, convertReferencedResponse(&ReferencedResponse{Object: r}, produces)) + } + return out +} + +func localComponentName(ref references.Reference) string { + s := string(ref) + if s == "" { + return "" + } + parts := strings.Split(s, "/") + if len(parts) == 0 { + return "" + } + return parts[len(parts)-1] +} + +func rewriteRefTargets(ctx context.Context, doc *openapi.OpenAPI) { + if doc == nil { + return + } + for item := range openapi.Walk(ctx, doc) { + _ = item.Match(openapi.Matcher{ + Schema: func(js *oas3.JSONSchema[oas3.Referenceable]) error { + if js == nil || !js.IsReference() { + return nil + } + ref := string(js.GetReference()) + if strings.HasPrefix(ref, "#/definitions/") { + newRef := references.Reference(strings.Replace(ref, "#/definitions/", "#/components/schemas/", 1)) + *js = *oas3.NewJSONSchemaFromReference(newRef) + } + return nil + }, + }) + } +} diff --git a/swagger/upgrade_test.go b/swagger/upgrade_test.go new file mode 100644 index 0000000..7f82fbd --- /dev/null +++ b/swagger/upgrade_test.go @@ -0,0 +1,1506 @@ +package swagger + +import ( + "bytes" + "strings" + "testing" + + "github.com/speakeasy-api/openapi/marshaller" + "github.com/stretchr/testify/require" +) + +func TestUpgrade_MinimalSwaggerJSON_Success(t *testing.T) { + t.Parallel() + ctx := t.Context() + + inputSwaggerJSON := `{ + "swagger": "2.0", + "info": { + "title": "Minimal API", + "version": "1.0.0" + }, + "paths": { + "/ping": { + "get": { + "responses": { + "200": { + "description": "ok" + } + } + } + } + } +} +` + + // Unmarshal Swagger 2.0 (JSON) + swDoc, validationErrs, err := Unmarshal(ctx, strings.NewReader(inputSwaggerJSON)) + require.NoError(t, err, "unmarshal should succeed") + require.Empty(t, validationErrs, "swagger should be valid") + + // Upgrade to OpenAPI 3.0 + oaDoc, err := Upgrade(ctx, swDoc) + require.NoError(t, err, "upgrade should succeed") + require.NotNil(t, oaDoc, "openapi document should not be nil") + + // Marshal OpenAPI as JSON (match input format) + cfg := swDoc.GetCore().GetConfig() + oaDoc.GetCore().SetConfig(cfg) + + var buf bytes.Buffer + err = marshaller.Marshal(ctx, oaDoc, &buf) + require.NoError(t, err, "marshal should succeed") + + actualJSON := buf.String() + + expectedJSON := `{ + "openapi": "3.0.0", + "info": { + "title": "Minimal API", + "version": "1.0.0" + }, + "paths": { + "/ping": { + "get": { + "responses": { + "200": { + "description": "ok" + } + } + } + } + } +} +` + + require.Equal(t, expectedJSON, actualJSON, "upgraded 
OpenAPI JSON should match expected") +} + +func TestUpgrade_BodyParameter_To_RequestBody_JSON_Success(t *testing.T) { + t.Parallel() + ctx := t.Context() + + inputSwaggerJSON := `{ + "swagger": "2.0", + "info": { + "title": "Body Param API", + "version": "1.0.0" + }, + "paths": { + "/users": { + "post": { + "consumes": ["application/json"], + "parameters": [ + { + "in": "body", + "name": "body", + "required": true, + "schema": { + "type": "object", + "properties": { + "name": { "type": "string" } + } + } + } + ], + "responses": { + "201": { + "description": "created" + } + } + } + } + } +} +` + + // Unmarshal Swagger 2.0 (JSON) + swDoc, validationErrs, err := Unmarshal(ctx, strings.NewReader(inputSwaggerJSON)) + require.NoError(t, err, "unmarshal should succeed") + require.Empty(t, validationErrs, "swagger should be valid") + + // Upgrade to OpenAPI 3.0 + oaDoc, err := Upgrade(ctx, swDoc) + require.NoError(t, err, "upgrade should succeed") + require.NotNil(t, oaDoc, "openapi document should not be nil") + + // Preserve input format (JSON) + cfg := swDoc.GetCore().GetConfig() + oaDoc.GetCore().SetConfig(cfg) + + var buf bytes.Buffer + err = marshaller.Marshal(ctx, oaDoc, &buf) + require.NoError(t, err, "marshal should succeed") + + actualJSON := buf.String() + + expectedJSON := `{ + "openapi": "3.0.0", + "info": { + "title": "Body Param API", + "version": "1.0.0" + }, + "paths": { + "/users": { + "post": { + "requestBody": { + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "name": { + "type": "string" + } + } + } + } + }, + "required": true + }, + "responses": { + "201": { + "description": "created" + } + } + } + } + } +} +` + + require.Equal(t, expectedJSON, actualJSON, "upgraded OpenAPI JSON should map body parameter to requestBody") +} + +func TestUpgrade_Servers_FromHostBasePathSchemes_Success(t *testing.T) { + t.Parallel() + ctx := t.Context() + + inputSwaggerJSON := `{ + "swagger": "2.0", + "info": { + "title": "Server API", + "version": "1.0.0" + }, + "host": "api.example.com", + "basePath": "/v1", + "schemes": ["http", "https"], + "paths": { + "/ping": { + "get": { + "responses": { + "200": { + "description": "ok" + } + } + } + } + } +} +` + + swDoc, validationErrs, err := Unmarshal(ctx, strings.NewReader(inputSwaggerJSON)) + require.NoError(t, err) + require.Empty(t, validationErrs) + + oaDoc, err := Upgrade(ctx, swDoc) + require.NoError(t, err) + require.NotNil(t, oaDoc) + + cfg := swDoc.GetCore().GetConfig() + oaDoc.GetCore().SetConfig(cfg) + + var buf bytes.Buffer + err = marshaller.Marshal(ctx, oaDoc, &buf) + require.NoError(t, err) + + actual := buf.String() + + expected := `{ + "openapi": "3.0.0", + "info": { + "title": "Server API", + "version": "1.0.0" + }, + "servers": [ + { + "url": "http://api.example.com/v1" + }, + { + "url": "https://api.example.com/v1" + } + ], + "paths": { + "/ping": { + "get": { + "responses": { + "200": { + "description": "ok" + } + } + } + } + } +} +` + + require.Equal(t, expected, actual, "servers should be constructed from host/basePath/schemes") +} + +func TestUpgrade_ResponseSchema_To_Content_WithProduces_Success(t *testing.T) { + t.Parallel() + ctx := t.Context() + + input := `{ + "swagger": "2.0", + "info": { + "title": "Produces API", + "version": "1.0.0" + }, + "produces": ["application/json", "application/xml"], + "paths": { + "/things": { + "get": { + "responses": { + "200": { + "description": "ok", + "schema": { "type": "object" } + } + } + } + } + } +} +` + + swDoc, validationErrs, err := 
Unmarshal(ctx, strings.NewReader(input)) + require.NoError(t, err) + require.Empty(t, validationErrs) + + oaDoc, err := Upgrade(ctx, swDoc) + require.NoError(t, err) + require.NotNil(t, oaDoc) + + cfg := swDoc.GetCore().GetConfig() + oaDoc.GetCore().SetConfig(cfg) + + var buf bytes.Buffer + err = marshaller.Marshal(ctx, oaDoc, &buf) + require.NoError(t, err) + + actual := buf.String() + + expected := `{ + "openapi": "3.0.0", + "info": { + "title": "Produces API", + "version": "1.0.0" + }, + "paths": { + "/things": { + "get": { + "responses": { + "200": { + "description": "ok", + "content": { + "application/json": { + "schema": { + "type": "object" + } + }, + "application/xml": { + "schema": { + "type": "object" + } + } + } + } + } + } + } + } +} +` + + require.Equal(t, expected, actual, "response schema should be wrapped under content for each produces type") +} + +func TestUpgrade_FormData_To_RequestBody_Multipart_File_Success(t *testing.T) { + t.Parallel() + ctx := t.Context() + + input := `{ + "swagger": "2.0", + "info": { + "title": "Upload API", + "version": "1.0.0" + }, + "paths": { + "/upload": { + "post": { + "parameters": [ + { "in": "formData", "name": "file", "type": "file" }, + { "in": "formData", "name": "title", "type": "string", "required": true } + ], + "responses": { + "200": { "description": "ok" } + } + } + } + } +} +` + + swDoc, validationErrs, err := Unmarshal(ctx, strings.NewReader(input)) + require.NoError(t, err) + require.Empty(t, validationErrs) + + oaDoc, err := Upgrade(ctx, swDoc) + require.NoError(t, err) + require.NotNil(t, oaDoc) + + cfg := swDoc.GetCore().GetConfig() + oaDoc.GetCore().SetConfig(cfg) + + var buf bytes.Buffer + err = marshaller.Marshal(ctx, oaDoc, &buf) + require.NoError(t, err) + + actual := buf.String() + + expected := `{ + "openapi": "3.0.0", + "info": { + "title": "Upload API", + "version": "1.0.0" + }, + "paths": { + "/upload": { + "post": { + "requestBody": { + "content": { + "multipart/form-data": { + "schema": { + "type": "object", + "properties": { + "file": { + "type": "string", + "format": "binary" + }, + "title": { + "type": "string" + } + } + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "ok" + } + } + } + } + } +} +` + + require.Equal(t, expected, actual, "formData with file should become multipart/form-data requestBody and set required=true if any field required") +} + +func TestUpgrade_GlobalDefinitions_RefRewrite_Success(t *testing.T) { + t.Parallel() + ctx := t.Context() + + input := `{ + "swagger": "2.0", + "info": { + "title": "Defs API", + "version": "1.0.0" + }, + "definitions": { + "MyModel": { + "type": "object", + "properties": { + "id": { "type": "string" } + } + } + }, + "paths": { + "/x": { + "get": { + "produces": ["application/json"], + "responses": { + "200": { + "description": "ok", + "schema": { "$ref": "#/definitions/MyModel" } + } + } + } + } + } +} +` + + swDoc, validationErrs, err := Unmarshal(ctx, strings.NewReader(input)) + require.NoError(t, err) + require.Empty(t, validationErrs) + + oaDoc, err := Upgrade(ctx, swDoc) + require.NoError(t, err) + require.NotNil(t, oaDoc) + + cfg := swDoc.GetCore().GetConfig() + oaDoc.GetCore().SetConfig(cfg) + + var buf bytes.Buffer + err = marshaller.Marshal(ctx, oaDoc, &buf) + require.NoError(t, err) + + actual := buf.String() + + expected := `{ + "openapi": "3.0.0", + "info": { + "title": "Defs API", + "version": "1.0.0" + }, + "paths": { + "/x": { + "get": { + "responses": { + "200": { + "description": "ok", + "content": { + 
"application/json": { + "schema": { + "$ref": "#/components/schemas/MyModel" + } + } + } + } + } + } + } + }, + "components": { + "schemas": { + "MyModel": { + "type": "object", + "properties": { + "id": { + "type": "string" + } + } + } + } + } +} +` + + require.Equal(t, expected, actual, "definition refs should be rewritten to components.schemas and schemas moved under components") +} + +func TestUpgrade_GlobalParameters_And_Responses_Success(t *testing.T) { + t.Parallel() + ctx := t.Context() + + input := `{ + "swagger": "2.0", + "info": { + "title": "Globals API", + "version": "1.0.0" + }, + "parameters": { + "PageParam": { "name": "page", "in": "query", "type": "integer" }, + "BodyParam": { + "name": "body", + "in": "body", + "schema": { + "type": "object", + "properties": { + "n": { "type": "string" } + } + } + } + }, + "responses": { + "NotFound": { + "description": "not found", + "schema": { "type": "string" }, + "headers": { + "X-Rate-Limit": { "type": "integer", "format": "int32" } + } + } + }, + "paths": { + "/x": { + "get": { + "parameters": [ + { "$ref": "#/parameters/PageParam" }, + { "$ref": "#/parameters/BodyParam" } + ], + "responses": { + "404": { "$ref": "#/responses/NotFound" } + } + } + } + } +} +` + + swDoc, validationErrs, err := Unmarshal(ctx, strings.NewReader(input)) + require.NoError(t, err) + require.Empty(t, validationErrs) + + oaDoc, err := Upgrade(ctx, swDoc) + require.NoError(t, err) + require.NotNil(t, oaDoc) + + cfg := swDoc.GetCore().GetConfig() + oaDoc.GetCore().SetConfig(cfg) + + var buf bytes.Buffer + err = marshaller.Marshal(ctx, oaDoc, &buf) + require.NoError(t, err) + + actual := buf.String() + + expected := `{ + "openapi": "3.0.0", + "info": { + "title": "Globals API", + "version": "1.0.0" + }, + "paths": { + "/x": { + "get": { + "parameters": [ + { + "$ref": "#/components/parameters/PageParam" + } + ], + "requestBody": { + "$ref": "#/components/requestBodies/BodyParam" + }, + "responses": { + "404": { + "$ref": "#/components/responses/NotFound" + } + } + } + } + }, + "components": { + "responses": { + "NotFound": { + "description": "not found", + "headers": { + "X-Rate-Limit": { + "schema": { + "type": "integer", + "format": "int32" + } + } + }, + "content": { + "application/json": { + "schema": { + "type": "string" + } + } + } + } + }, + "parameters": { + "PageParam": { + "name": "page", + "in": "query", + "schema": { + "type": "integer" + } + } + }, + "requestBodies": { + "BodyParam": { + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "n": { + "type": "string" + } + } + } + } + } + } + } + } +} +` + + require.Equal(t, expected, actual, "globals should map to components and references updated accordingly") +} + +func TestUpgrade_Parameter_CollectionFormat_To_StyleExplode_Success(t *testing.T) { + t.Parallel() + ctx := t.Context() + + input := `{ + "swagger": "2.0", + "info": { + "title": "CollectionFormat API", + "version": "1.0.0" + }, + "paths": { + "/search": { + "get": { + "parameters": [ + { + "name": "tags", + "in": "query", + "type": "array", + "collectionFormat": "csv", + "items": { "type": "string" } + } + ], + "responses": { + "200": { "description": "ok" } + } + } + } + } +} +` + + swDoc, validationErrs, err := Unmarshal(ctx, strings.NewReader(input)) + require.NoError(t, err) + require.Empty(t, validationErrs) + + oaDoc, err := Upgrade(ctx, swDoc) + require.NoError(t, err) + require.NotNil(t, oaDoc) + + cfg := swDoc.GetCore().GetConfig() + oaDoc.GetCore().SetConfig(cfg) + + var buf 
bytes.Buffer + err = marshaller.Marshal(ctx, oaDoc, &buf) + require.NoError(t, err) + + actual := buf.String() + + expected := `{ + "openapi": "3.0.0", + "info": { + "title": "CollectionFormat API", + "version": "1.0.0" + }, + "paths": { + "/search": { + "get": { + "parameters": [ + { + "name": "tags", + "in": "query", + "style": "form", + "explode": false, + "schema": { + "type": "array", + "items": { + "type": "string" + } + } + } + ], + "responses": { + "200": { + "description": "ok" + } + } + } + } + } +} +` + + require.Equal(t, expected, actual, "collectionFormat csv should map to style=form, explode=false with array schema") +} + +func TestUpgrade_SecurityDefinitions_To_SecuritySchemes_JSON_Success(t *testing.T) { + t.Parallel() + ctx := t.Context() + + input := `{ + "swagger": "2.0", + "info": { + "title": "Security API", + "version": "1.0.0" + }, + "securityDefinitions": { + "basicAuth": { + "type": "basic" + }, + "apiKeyHeader": { + "type": "apiKey", + "name": "X-API-Key", + "in": "header" + }, + "oauth2Auth": { + "type": "oauth2", + "flow": "accessCode", + "authorizationUrl": "https://auth.example.com/authorize", + "tokenUrl": "https://auth.example.com/token", + "scopes": { + "read": "Read access", + "write": "Write access" + } + } + }, + "paths": { + "/ping": { + "get": { + "responses": { + "200": { "description": "ok" } + } + } + } + } +} +` + + swDoc, validationErrs, err := Unmarshal(ctx, strings.NewReader(input)) + require.NoError(t, err) + require.Empty(t, validationErrs) + + oaDoc, err := Upgrade(ctx, swDoc) + require.NoError(t, err) + require.NotNil(t, oaDoc) + + cfg := swDoc.GetCore().GetConfig() + oaDoc.GetCore().SetConfig(cfg) + + var buf bytes.Buffer + err = marshaller.Marshal(ctx, oaDoc, &buf) + require.NoError(t, err) + + actual := buf.String() + + // Note: Only asserting components.securitySchemes mapping and structure; path kept minimal + expected := `{ + "openapi": "3.0.0", + "info": { + "title": "Security API", + "version": "1.0.0" + }, + "paths": { + "/ping": { + "get": { + "responses": { + "200": { + "description": "ok" + } + } + } + } + }, + "components": { + "securitySchemes": { + "basicAuth": { + "type": "http", + "scheme": "basic" + }, + "apiKeyHeader": { + "type": "apiKey", + "name": "X-API-Key", + "in": "header" + }, + "oauth2Auth": { + "type": "oauth2", + "flows": { + "authorizationCode": { + "authorizationUrl": "https://auth.example.com/authorize", + "tokenUrl": "https://auth.example.com/token", + "scopes": { + "read": "Read access", + "write": "Write access" + } + } + } + } + } + } +} +` + + require.Equal(t, expected, actual, "securityDefinitions should map to components.securitySchemes with correct types/flows") +} + +func TestUpgrade_Produces_Overrides_Global_YAML_Success(t *testing.T) { + t.Parallel() + ctx := t.Context() + + // Global produces application/json, but operation overrides to text/plain + inputYAML := `swagger: "2.0" +info: + title: Produces Override + version: "1.0.0" +produces: + - application/json +paths: + /data: + get: + produces: + - text/plain + responses: + "200": + description: ok + schema: + type: string +` + + swDoc, validationErrs, err := Unmarshal(ctx, strings.NewReader(inputYAML)) + require.NoError(t, err) + require.Empty(t, validationErrs) + + oaDoc, err := Upgrade(ctx, swDoc) + require.NoError(t, err) + require.NotNil(t, oaDoc) + + // Preserve YAML output + cfg := swDoc.GetCore().GetConfig() + oaDoc.GetCore().SetConfig(cfg) + + var buf bytes.Buffer + err = marshaller.Marshal(ctx, oaDoc, &buf) + require.NoError(t, err) + 
+ actualYAML := buf.String() + + expectedYAML := `openapi: "3.0.0" +info: + title: "Produces Override" + version: "1.0.0" +paths: + /data: + get: + responses: + "200": + description: "ok" + content: + text/plain: + schema: + type: "string" +` + + require.Equal(t, expectedYAML, actualYAML, "operation-level produces should override global produces in response content") +} + +func TestUpgrade_FormData_UrlEncoded_YAML_Success(t *testing.T) { + t.Parallel() + ctx := t.Context() + + // formData without file -> application/x-www-form-urlencoded and aggregate fields + inputYAML := `swagger: "2.0" +info: + title: Submit API + version: "1.0.0" +paths: + /submit: + post: + parameters: + - in: formData + name: a + type: string + required: true + - in: formData + name: b + type: integer + responses: + "200": + description: ok +` + + swDoc, validationErrs, err := Unmarshal(ctx, strings.NewReader(inputYAML)) + require.NoError(t, err) + require.Empty(t, validationErrs) + + oaDoc, err := Upgrade(ctx, swDoc) + require.NoError(t, err) + require.NotNil(t, oaDoc) + + cfg := swDoc.GetCore().GetConfig() + oaDoc.GetCore().SetConfig(cfg) + + var buf bytes.Buffer + err = marshaller.Marshal(ctx, oaDoc, &buf) + require.NoError(t, err) + + actualYAML := buf.String() + + expectedYAML := `openapi: "3.0.0" +info: + title: "Submit API" + version: "1.0.0" +paths: + /submit: + post: + requestBody: + content: + application/x-www-form-urlencoded: + schema: + type: "object" + properties: + a: + type: "string" + b: + type: "integer" + required: true + responses: + "200": + description: "ok" +` + + require.Equal(t, expectedYAML, actualYAML, "formData without file should map to x-www-form-urlencoded and aggregate fields in object schema") +} + +func TestUpgrade_Response_Examples_To_Content_Example_JSON_Success(t *testing.T) { + t.Parallel() + ctx := t.Context() + + input := `{ + "swagger": "2.0", + "info": { + "title": "Examples API", + "version": "1.0.0" + }, + "paths": { + "/ex": { + "get": { + "produces": ["application/json"], + "responses": { + "200": { + "description": "ok", + "schema": { "type": "object", "properties": { "id": { "type": "integer" } } }, + "examples": { + "application/json": { "id": 123 } + } + } + } + } + } + } +} +` + + swDoc, validationErrs, err := Unmarshal(ctx, strings.NewReader(input)) + require.NoError(t, err) + require.Empty(t, validationErrs) + + oaDoc, err := Upgrade(ctx, swDoc) + require.NoError(t, err) + require.NotNil(t, oaDoc) + + cfg := swDoc.GetCore().GetConfig() + oaDoc.GetCore().SetConfig(cfg) + + var buf bytes.Buffer + err = marshaller.Marshal(ctx, oaDoc, &buf) + require.NoError(t, err) + + actual := buf.String() + + expected := `{ + "openapi": "3.0.0", + "info": { + "title": "Examples API", + "version": "1.0.0" + }, + "paths": { + "/ex": { + "get": { + "responses": { + "200": { + "description": "ok", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "id": { + "type": "integer" + } + } + }, + "example": {"id": 123} + } + } + } + } + } + } + } +} +` + + require.Equal(t, expected, actual, "response examples should map to content[mediaType].example") +} + +func TestUpgrade_Info_Contact_License_JSON_Success(t *testing.T) { + t.Parallel() + ctx := t.Context() + + input := `{ + "swagger": "2.0", + "info": { + "title": "Info API", + "description": "API with contact and license", + "termsOfService": "https://example.com/tos", + "contact": { + "name": "Alice", + "url": "https://example.com", + "email": "alice@example.com" + }, + "license": { + "name": 
"MIT", + "url": "https://opensource.org/licenses/MIT" + }, + "version": "1.0.0" + }, + "paths": { + "/ping": { + "get": { + "responses": { + "200": { "description": "ok" } + } + } + } + } +} +` + + swDoc, validationErrs, err := Unmarshal(ctx, strings.NewReader(input)) + require.NoError(t, err) + require.Empty(t, validationErrs) + + oaDoc, err := Upgrade(ctx, swDoc) + require.NoError(t, err) + require.NotNil(t, oaDoc) + + cfg := swDoc.GetCore().GetConfig() + oaDoc.GetCore().SetConfig(cfg) + + var buf bytes.Buffer + err = marshaller.Marshal(ctx, oaDoc, &buf) + require.NoError(t, err) + + actual := buf.String() + + expected := `{ + "openapi": "3.0.0", + "info": { + "title": "Info API", + "version": "1.0.0", + "description": "API with contact and license", + "termsOfService": "https://example.com/tos", + "contact": { + "name": "Alice", + "url": "https://example.com", + "email": "alice@example.com" + }, + "license": { + "name": "MIT", + "url": "https://opensource.org/licenses/MIT" + } + }, + "paths": { + "/ping": { + "get": { + "responses": { + "200": { + "description": "ok" + } + } + } + } + } +} +` + + require.Equal(t, expected, actual, "info.contact/license/termsOfService should be preserved") +} + +func TestUpgrade_Tags_JSON_Success(t *testing.T) { + t.Parallel() + ctx := t.Context() + + input := `{ + "swagger": "2.0", + "info": { "title": "Tags API", "version": "1.0.0" }, + "tags": [ + { "name": "users", "description": "User operations" }, + { "name": "admin", "description": "Admin operations" } + ], + "paths": { + "/ping": { + "get": { + "responses": { + "200": { "description": "ok" } + } + } + } + } +} +` + + swDoc, validationErrs, err := Unmarshal(ctx, strings.NewReader(input)) + require.NoError(t, err) + require.Empty(t, validationErrs) + + oaDoc, err := Upgrade(ctx, swDoc) + require.NoError(t, err) + require.NotNil(t, oaDoc) + + cfg := swDoc.GetCore().GetConfig() + oaDoc.GetCore().SetConfig(cfg) + + var buf bytes.Buffer + err = marshaller.Marshal(ctx, oaDoc, &buf) + require.NoError(t, err) + + actual := buf.String() + + expected := `{ + "openapi": "3.0.0", + "info": { + "title": "Tags API", + "version": "1.0.0" + }, + "tags": [ + { + "name": "users", + "description": "User operations" + }, + { + "name": "admin", + "description": "Admin operations" + } + ], + "paths": { + "/ping": { + "get": { + "responses": { + "200": { + "description": "ok" + } + } + } + } + } +} +` + + require.Equal(t, expected, actual, "root tags should be preserved") +} + +func TestUpgrade_SecurityDefinitions_AllOAuthFlows_JSON_Success(t *testing.T) { + t.Parallel() + ctx := t.Context() + + input := `{ + "swagger": "2.0", + "info": { "title": "OAuth Flows API", "version": "1.0.0" }, + "securityDefinitions": { + "implicitFlow": { + "type": "oauth2", + "flow": "implicit", + "authorizationUrl": "https://auth.example.com/authorize", + "scopes": { "read": "Read access" } + }, + "passwordFlow": { + "type": "oauth2", + "flow": "password", + "tokenUrl": "https://auth.example.com/token", + "scopes": { "write": "Write access" } + }, + "applicationFlow": { + "type": "oauth2", + "flow": "application", + "tokenUrl": "https://auth.example.com/token", + "scopes": { "svc": "Service access" } + }, + "accessCodeFlow": { + "type": "oauth2", + "flow": "accessCode", + "authorizationUrl": "https://auth.example.com/authorize", + "tokenUrl": "https://auth.example.com/token", + "scopes": { "all": "All access" } + } + }, + "paths": { + "/ok": { + "get": { + "responses": { + "200": { "description": "ok" } + } + } + } + } +} +` + + swDoc, 
validationErrs, err := Unmarshal(ctx, strings.NewReader(input)) + require.NoError(t, err) + require.Empty(t, validationErrs) + + oaDoc, err := Upgrade(ctx, swDoc) + require.NoError(t, err) + require.NotNil(t, oaDoc) + + cfg := swDoc.GetCore().GetConfig() + oaDoc.GetCore().SetConfig(cfg) + + var buf bytes.Buffer + err = marshaller.Marshal(ctx, oaDoc, &buf) + require.NoError(t, err) + + actual := buf.String() + + expected := `{ + "openapi": "3.0.0", + "info": { + "title": "OAuth Flows API", + "version": "1.0.0" + }, + "paths": { + "/ok": { + "get": { + "responses": { + "200": { + "description": "ok" + } + } + } + } + }, + "components": { + "securitySchemes": { + "implicitFlow": { + "type": "oauth2", + "flows": { + "implicit": { + "authorizationUrl": "https://auth.example.com/authorize", + "scopes": { + "read": "Read access" + } + } + } + }, + "passwordFlow": { + "type": "oauth2", + "flows": { + "password": { + "tokenUrl": "https://auth.example.com/token", + "scopes": { + "write": "Write access" + } + } + } + }, + "applicationFlow": { + "type": "oauth2", + "flows": { + "clientCredentials": { + "tokenUrl": "https://auth.example.com/token", + "scopes": { + "svc": "Service access" + } + } + } + }, + "accessCodeFlow": { + "type": "oauth2", + "flows": { + "authorizationCode": { + "authorizationUrl": "https://auth.example.com/authorize", + "tokenUrl": "https://auth.example.com/token", + "scopes": { + "all": "All access" + } + } + } + } + } + } +} +` + + require.Equal(t, expected, actual, "all OAuth2 flows should map to OAS3 flows") +} + +func TestUpgrade_GlobalSecurityRequirements_JSON_Success(t *testing.T) { + t.Parallel() + ctx := t.Context() + + input := `{ + "swagger": "2.0", + "info": { "title": "SecurityReq API", "version": "1.0.0" }, + "securityDefinitions": { + "apiKeyHeader": { + "type": "apiKey", + "name": "X-API-Key", + "in": "header" + } + }, + "security": [ + { "apiKeyHeader": [] } + ], + "paths": { + "/ok": { + "get": { + "responses": { + "200": { "description": "ok" } + } + } + } + } +} +` + + swDoc, validationErrs, err := Unmarshal(ctx, strings.NewReader(input)) + require.NoError(t, err) + require.Empty(t, validationErrs) + + oaDoc, err := Upgrade(ctx, swDoc) + require.NoError(t, err) + require.NotNil(t, oaDoc) + + cfg := swDoc.GetCore().GetConfig() + oaDoc.GetCore().SetConfig(cfg) + + var buf bytes.Buffer + err = marshaller.Marshal(ctx, oaDoc, &buf) + require.NoError(t, err) + + actual := buf.String() + + expected := `{ + "openapi": "3.0.0", + "info": { + "title": "SecurityReq API", + "version": "1.0.0" + }, + "security": [ + { + "apiKeyHeader": [] + } + ], + "paths": { + "/ok": { + "get": { + "responses": { + "200": { + "description": "ok" + } + } + } + } + }, + "components": { + "securitySchemes": { + "apiKeyHeader": { + "type": "apiKey", + "name": "X-API-Key", + "in": "header" + } + } + } +} +` + + require.Equal(t, expected, actual, "global security requirements should be preserved") +} + +func TestUpgrade_PathLevelParameters_JSON_Success(t *testing.T) { + t.Parallel() + ctx := t.Context() + + input := `{ + "swagger": "2.0", + "info": { "title": "Path Params API", "version": "1.0.0" }, + "paths": { + "/users/{id}": { + "parameters": [ + { "name": "id", "in": "path", "required": true, "type": "string" }, + { "name": "version", "in": "query", "type": "integer" } + ], + "get": { + "responses": { + "200": { "description": "ok" } + } + } + } + } +} +` + + swDoc, validationErrs, err := Unmarshal(ctx, strings.NewReader(input)) + require.NoError(t, err) + require.Empty(t, 
validationErrs) + + oaDoc, err := Upgrade(ctx, swDoc) + require.NoError(t, err) + require.NotNil(t, oaDoc) + + cfg := swDoc.GetCore().GetConfig() + oaDoc.GetCore().SetConfig(cfg) + + var buf bytes.Buffer + err = marshaller.Marshal(ctx, oaDoc, &buf) + require.NoError(t, err) + + actual := buf.String() + + expected := `{ + "openapi": "3.0.0", + "info": { + "title": "Path Params API", + "version": "1.0.0" + }, + "paths": { + "/users/{id}": { + "get": { + "responses": { + "200": { + "description": "ok" + } + } + }, + "parameters": [ + { + "name": "id", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + }, + { + "name": "version", + "in": "query", + "schema": { + "type": "integer" + } + } + ] + } + } +} +` + + require.Equal(t, expected, actual, "path-level non-body parameters should be preserved and mapped with schema") +} diff --git a/swagger/walk.go b/swagger/walk.go new file mode 100644 index 0000000..ac69796 --- /dev/null +++ b/swagger/walk.go @@ -0,0 +1,685 @@ +package swagger + +import ( + "context" + "iter" + + "github.com/speakeasy-api/openapi/extensions" + "github.com/speakeasy-api/openapi/jsonschema/oas3" + "github.com/speakeasy-api/openapi/pointer" + "github.com/speakeasy-api/openapi/sequencedmap" +) + +// WalkItem represents a single item yielded by the Walk iterator. +type WalkItem struct { + Match MatchFunc + Location Locations + Swagger *Swagger +} + +// Walk returns an iterator that yields MatchFunc items for each model in the provided Swagger model. +// Users can iterate over the results using a for loop and break out at any time. +// When called with *Swagger, it walks the entire document. When called with other types, +// it walks from that specific component. +func Walk[T any](ctx context.Context, start *T) iter.Seq[WalkItem] { + return func(yield func(WalkItem) bool) { + if start == nil { + return + } + walkFrom(ctx, start, yield) + } +} + +// walkFrom handles walking from different starting points using type switching +func walkFrom[T any](ctx context.Context, start *T, yield func(WalkItem) bool) { + switch v := any(start).(type) { + case *Swagger: + walk(ctx, v, yield) + case *oas3.JSONSchema[oas3.Concrete]: + walkSchemaConcrete(v, []LocationContext{}, nil, yield) + case *oas3.JSONSchema[oas3.Referenceable]: + walkSchemaReferenceable(v, []LocationContext{}, nil, yield) + case *ExternalDocumentation: + walkExternalDocs(ctx, v, []LocationContext{}, nil, yield) + case *Info: + walkInfo(ctx, v, []LocationContext{}, nil, yield) + case *Contact: + yield(WalkItem{Match: getMatchFunc(v), Location: []LocationContext{}, Swagger: nil}) + case *License: + yield(WalkItem{Match: getMatchFunc(v), Location: []LocationContext{}, Swagger: nil}) + case *Tag: + walkTag(ctx, v, []LocationContext{}, nil, yield) + case *Paths: + walkPaths(ctx, v, []LocationContext{}, nil, yield) + case *PathItem: + walkPathItem(ctx, v, []LocationContext{}, nil, yield) + case *Operation: + walkOperation(ctx, v, []LocationContext{}, nil, yield) + case *ReferencedParameter: + walkReferencedParameter(ctx, v, []LocationContext{}, nil, yield) + case *Parameter: + walkParameter(ctx, v, []LocationContext{}, nil, yield) + case *Items: + walkItems(v, getMatchFunc(v), []LocationContext{}, nil, yield) + case *ReferencedResponse: + walkReferencedResponse(ctx, v, []LocationContext{}, nil, yield) + case *Response: + walkResponse(ctx, v, []LocationContext{}, nil, yield) + case *Header: + walkHeader(v, []LocationContext{}, nil, yield) + case *SecurityRequirement: + walkSecurityRequirement(ctx, v, 
[]LocationContext{}, nil, yield) + case *SecurityScheme: + walkSecurityScheme(v, []LocationContext{}, nil, yield) + case *extensions.Extensions: + yield(WalkItem{Match: getMatchFunc(v), Location: []LocationContext{}, Swagger: nil}) + default: + yield(WalkItem{Match: getMatchFunc(start), Location: []LocationContext{}, Swagger: nil}) + } +} + +func walk(ctx context.Context, swagger *Swagger, yield func(WalkItem) bool) { + swaggerMatchFunc := getMatchFunc(swagger) + + // Visit the root Swagger document first, location nil to specify the root + if !yield(WalkItem{Match: swaggerMatchFunc, Location: nil, Swagger: swagger}) { + return + } + + // Visit each of the top level fields in turn populating their location context with field and any key/index information + loc := []LocationContext{} + + if !walkInfo(ctx, &swagger.Info, append(loc, LocationContext{ParentMatchFunc: swaggerMatchFunc, ParentField: "info"}), swagger, yield) { + return + } + + if !walkExternalDocs(ctx, swagger.ExternalDocs, append(loc, LocationContext{ParentMatchFunc: swaggerMatchFunc, ParentField: "externalDocs"}), swagger, yield) { + return + } + + if !walkTags(ctx, swagger.Tags, append(loc, LocationContext{ParentMatchFunc: swaggerMatchFunc, ParentField: "tags"}), swagger, yield) { + return + } + + if !walkPaths(ctx, swagger.Paths, append(loc, LocationContext{ParentMatchFunc: swaggerMatchFunc, ParentField: "paths"}), swagger, yield) { + return + } + + if !walkDefinitions(ctx, swagger.Definitions, append(loc, LocationContext{ParentMatchFunc: swaggerMatchFunc, ParentField: "definitions"}), swagger, yield) { + return + } + + if !walkParameters(ctx, swagger.Parameters, append(loc, LocationContext{ParentMatchFunc: swaggerMatchFunc, ParentField: "parameters"}), swagger, yield) { + return + } + + if !walkGlobalResponses(ctx, swagger.Responses, append(loc, LocationContext{ParentMatchFunc: swaggerMatchFunc, ParentField: "responses"}), swagger, yield) { + return + } + + if !walkSecurityDefinitions(ctx, swagger.SecurityDefinitions, append(loc, LocationContext{ParentMatchFunc: swaggerMatchFunc, ParentField: "securityDefinitions"}), swagger, yield) { + return + } + + if !walkSecurity(ctx, swagger.Security, append(loc, LocationContext{ParentMatchFunc: swaggerMatchFunc, ParentField: "security"}), swagger, yield) { + return + } + + // Visit Swagger Extensions + yield(WalkItem{Match: getMatchFunc(swagger.Extensions), Location: append(loc, LocationContext{ParentMatchFunc: swaggerMatchFunc, ParentField: ""}), Swagger: swagger}) +} + +func walkInfo(_ context.Context, info *Info, loc []LocationContext, swagger *Swagger, yield func(WalkItem) bool) bool { + if info == nil { + return true + } + + infoMatchFunc := getMatchFunc(info) + + if !yield(WalkItem{Match: infoMatchFunc, Location: loc, Swagger: swagger}) { + return false + } + + // Visit Contact and its Extensions + if info.Contact != nil { + contactMatchFunc := getMatchFunc(info.Contact) + + contactLoc := loc + contactLoc = append(contactLoc, LocationContext{ParentMatchFunc: infoMatchFunc, ParentField: "contact"}) + + if !yield(WalkItem{Match: contactMatchFunc, Location: contactLoc, Swagger: swagger}) { + return false + } + + if !yield(WalkItem{Match: getMatchFunc(info.Contact.Extensions), Location: append(contactLoc, LocationContext{ParentMatchFunc: contactMatchFunc, ParentField: ""}), Swagger: swagger}) { + return false + } + } + + // Visit License and its Extensions + if info.License != nil { + licenseMatchFunc := getMatchFunc(info.License) + + licenseLoc := loc + licenseLoc = 
append(licenseLoc, LocationContext{ParentMatchFunc: infoMatchFunc, ParentField: "license"}) + + if !yield(WalkItem{Match: licenseMatchFunc, Location: licenseLoc, Swagger: swagger}) { + return false + } + + if !yield(WalkItem{Match: getMatchFunc(info.License.Extensions), Location: append(licenseLoc, LocationContext{ParentMatchFunc: licenseMatchFunc, ParentField: ""}), Swagger: swagger}) { + return false + } + } + + // Visit Info Extensions + return yield(WalkItem{Match: getMatchFunc(info.Extensions), Location: append(loc, LocationContext{ParentMatchFunc: infoMatchFunc, ParentField: ""}), Swagger: swagger}) +} + +func walkExternalDocs(_ context.Context, externalDocs *ExternalDocumentation, loc []LocationContext, swagger *Swagger, yield func(WalkItem) bool) bool { + if externalDocs == nil { + return true + } + + externalDocsMatchFunc := getMatchFunc(externalDocs) + + if !yield(WalkItem{Match: externalDocsMatchFunc, Location: loc, Swagger: swagger}) { + return false + } + + // Visit ExternalDocs Extensions + return yield(WalkItem{Match: getMatchFunc(externalDocs.Extensions), Location: append(loc, LocationContext{ParentMatchFunc: externalDocsMatchFunc, ParentField: ""}), Swagger: swagger}) +} + +func walkTags(ctx context.Context, tags []*Tag, loc []LocationContext, swagger *Swagger, yield func(WalkItem) bool) bool { + if len(tags) == 0 { + return true + } + + // Get the last loc so we can set the parent index + parentLoc := loc[len(loc)-1] + loc = loc[:len(loc)-1] + + for i, tag := range tags { + parentLoc.ParentIndex = pointer.From(i) + + if !walkTag(ctx, tag, append(loc, parentLoc), swagger, yield) { + return false + } + } + return true +} + +func walkTag(ctx context.Context, tag *Tag, loc []LocationContext, swagger *Swagger, yield func(WalkItem) bool) bool { + if tag == nil { + return true + } + + tagMatchFunc := getMatchFunc(tag) + + if !yield(WalkItem{Match: tagMatchFunc, Location: loc, Swagger: swagger}) { + return false + } + + // Walk through external docs + if !walkExternalDocs(ctx, tag.ExternalDocs, append(loc, LocationContext{ParentMatchFunc: tagMatchFunc, ParentField: "externalDocs"}), swagger, yield) { + return false + } + + // Visit Tag Extensions + return yield(WalkItem{Match: getMatchFunc(tag.Extensions), Location: append(loc, LocationContext{ParentMatchFunc: tagMatchFunc, ParentField: ""}), Swagger: swagger}) +} + +// walkPaths walks through the paths object +func walkPaths(ctx context.Context, paths *Paths, loc []LocationContext, swagger *Swagger, yield func(WalkItem) bool) bool { + if paths == nil { + return true + } + + pathsMatchFunc := getMatchFunc(paths) + + if !yield(WalkItem{Match: pathsMatchFunc, Location: loc, Swagger: swagger}) { + return false + } + + for path, pathItem := range paths.All() { + if !walkPathItem(ctx, pathItem, append(loc, LocationContext{ParentMatchFunc: pathsMatchFunc, ParentKey: pointer.From(path)}), swagger, yield) { + return false + } + } + + // Visit Paths Extensions + return yield(WalkItem{Match: getMatchFunc(paths.Extensions), Location: append(loc, LocationContext{ParentMatchFunc: pathsMatchFunc, ParentField: ""}), Swagger: swagger}) +} + +// walkPathItem walks through a path item +func walkPathItem(ctx context.Context, pathItem *PathItem, loc []LocationContext, swagger *Swagger, yield func(WalkItem) bool) bool { + if pathItem == nil { + return true + } + + pathItemMatchFunc := getMatchFunc(pathItem) + + if !yield(WalkItem{Match: pathItemMatchFunc, Location: loc, Swagger: swagger}) { + return false + } + + // Walk through parameters + if 
!walkReferencedParameters(ctx, pathItem.Parameters, append(loc, LocationContext{ParentMatchFunc: pathItemMatchFunc, ParentField: "parameters"}), swagger, yield) { + return false + } + + // Walk through operations + for method, operation := range pathItem.All() { + if !walkOperation(ctx, operation, append(loc, LocationContext{ParentMatchFunc: pathItemMatchFunc, ParentKey: pointer.From(string(method))}), swagger, yield) { + return false + } + } + + // Visit PathItem Extensions + return yield(WalkItem{Match: getMatchFunc(pathItem.Extensions), Location: append(loc, LocationContext{ParentMatchFunc: pathItemMatchFunc, ParentField: ""}), Swagger: swagger}) +} + +// walkOperation walks through an operation +func walkOperation(ctx context.Context, operation *Operation, loc []LocationContext, swagger *Swagger, yield func(WalkItem) bool) bool { + if operation == nil { + return true + } + + operationMatchFunc := getMatchFunc(operation) + + if !yield(WalkItem{Match: operationMatchFunc, Location: loc, Swagger: swagger}) { + return false + } + + // Walk through security + if !walkSecurity(ctx, operation.Security, append(loc, LocationContext{ParentMatchFunc: operationMatchFunc, ParentField: "security"}), swagger, yield) { + return false + } + + // Walk through parameters + if !walkReferencedParameters(ctx, operation.Parameters, append(loc, LocationContext{ParentMatchFunc: operationMatchFunc, ParentField: "parameters"}), swagger, yield) { + return false + } + + // Walk through responses + if !walkOperationResponses(ctx, operation.Responses, append(loc, LocationContext{ParentMatchFunc: operationMatchFunc, ParentField: "responses"}), swagger, yield) { + return false + } + + // Walk through external docs + if !walkExternalDocs(ctx, operation.ExternalDocs, append(loc, LocationContext{ParentMatchFunc: operationMatchFunc, ParentField: "externalDocs"}), swagger, yield) { + return false + } + + // Visit Operation Extensions + return yield(WalkItem{Match: getMatchFunc(operation.Extensions), Location: append(loc, LocationContext{ParentMatchFunc: operationMatchFunc, ParentField: ""}), Swagger: swagger}) +} + +// walkReferencedParameters walks through referenced parameters +func walkReferencedParameters(ctx context.Context, parameters []*ReferencedParameter, loc []LocationContext, swagger *Swagger, yield func(WalkItem) bool) bool { + if len(parameters) == 0 { + return true + } + + // Get the last loc so we can set the parent index + parentLoc := loc[len(loc)-1] + loc = loc[:len(loc)-1] + + for i, parameter := range parameters { + parentLoc.ParentIndex = pointer.From(i) + + if !walkReferencedParameter(ctx, parameter, append(loc, parentLoc), swagger, yield) { + return false + } + } + return true +} + +// walkReferencedParameter walks through a referenced parameter +func walkReferencedParameter(ctx context.Context, parameter *ReferencedParameter, loc []LocationContext, swagger *Swagger, yield func(WalkItem) bool) bool { + if parameter == nil { + return true + } + + referencedParameterMatchFunc := getMatchFunc(parameter) + + if !yield(WalkItem{Match: referencedParameterMatchFunc, Location: loc, Swagger: swagger}) { + return false + } + + // If it's not a reference, walk the actual Parameter + if !parameter.IsReference() && parameter.Object != nil { + return walkParameter(ctx, parameter.Object, loc, swagger, yield) + } + + return true +} + +// walkParameter walks through a parameter +func walkParameter(_ context.Context, parameter *Parameter, loc []LocationContext, swagger *Swagger, yield func(WalkItem) bool) bool { + if 
parameter == nil { + return true + } + + parameterMatchFunc := getMatchFunc(parameter) + + if !yield(WalkItem{Match: parameterMatchFunc, Location: loc, Swagger: swagger}) { + return false + } + + // Walk through schema + if !walkSchemaReferenceable(parameter.Schema, append(loc, LocationContext{ParentMatchFunc: parameterMatchFunc, ParentField: "schema"}), swagger, yield) { + return false + } + + // Walk through items + if !walkItems(parameter.Items, parameterMatchFunc, append(loc, LocationContext{ParentMatchFunc: parameterMatchFunc, ParentField: "items"}), swagger, yield) { + return false + } + + // Visit Parameter Extensions + return yield(WalkItem{Match: getMatchFunc(parameter.Extensions), Location: append(loc, LocationContext{ParentMatchFunc: parameterMatchFunc, ParentField: ""}), Swagger: swagger}) +} + +// walkItems walks through items +func walkItems(items *Items, _ MatchFunc, loc []LocationContext, swagger *Swagger, yield func(WalkItem) bool) bool { + if items == nil { + return true + } + + itemsMatchFunc := getMatchFunc(items) + + if !yield(WalkItem{Match: itemsMatchFunc, Location: loc, Swagger: swagger}) { + return false + } + + // Walk through nested items + if !walkItems(items.Items, itemsMatchFunc, append(loc, LocationContext{ParentMatchFunc: itemsMatchFunc, ParentField: "items"}), swagger, yield) { + return false + } + + // Visit Items Extensions + return yield(WalkItem{Match: getMatchFunc(items.Extensions), Location: append(loc, LocationContext{ParentMatchFunc: itemsMatchFunc, ParentField: ""}), Swagger: swagger}) +} + +// walkOperationResponses walks through operation responses +func walkOperationResponses(ctx context.Context, responses *Responses, loc []LocationContext, swagger *Swagger, yield func(WalkItem) bool) bool { + if responses == nil { + return true + } + + responsesMatchFunc := getMatchFunc(responses) + + if !yield(WalkItem{Match: responsesMatchFunc, Location: loc, Swagger: swagger}) { + return false + } + + // Walk through default response + if !walkReferencedResponse(ctx, responses.Default, append(loc, LocationContext{ParentMatchFunc: responsesMatchFunc, ParentField: "default"}), swagger, yield) { + return false + } + + // Walk through status code responses + for statusCode, response := range responses.All() { + if !walkReferencedResponse(ctx, response, append(loc, LocationContext{ParentMatchFunc: responsesMatchFunc, ParentKey: pointer.From(statusCode)}), swagger, yield) { + return false + } + } + + // Visit Responses Extensions + return yield(WalkItem{Match: getMatchFunc(responses.Extensions), Location: append(loc, LocationContext{ParentMatchFunc: responsesMatchFunc, ParentField: ""}), Swagger: swagger}) +} + +// walkReferencedResponse walks through a referenced response +func walkReferencedResponse(ctx context.Context, response *ReferencedResponse, loc []LocationContext, swagger *Swagger, yield func(WalkItem) bool) bool { + if response == nil { + return true + } + + referencedResponseMatchFunc := getMatchFunc(response) + + if !yield(WalkItem{Match: referencedResponseMatchFunc, Location: loc, Swagger: swagger}) { + return false + } + + // If it's not a reference, walk the actual Response + if !response.IsReference() && response.Object != nil { + return walkResponse(ctx, response.Object, loc, swagger, yield) + } + + return true +} + +// walkResponse walks through a response +func walkResponse(ctx context.Context, response *Response, loc []LocationContext, swagger *Swagger, yield func(WalkItem) bool) bool { + if response == nil { + return true + } + + 
responseMatchFunc := getMatchFunc(response) + + if !yield(WalkItem{Match: responseMatchFunc, Location: loc, Swagger: swagger}) { + return false + } + + // Walk through schema + if !walkSchemaReferenceable(response.Schema, append(loc, LocationContext{ParentMatchFunc: responseMatchFunc, ParentField: "schema"}), swagger, yield) { + return false + } + + // Walk through headers + if !walkHeaders(ctx, response.Headers, append(loc, LocationContext{ParentMatchFunc: responseMatchFunc, ParentField: "headers"}), swagger, yield) { + return false + } + + // Visit Response Extensions + return yield(WalkItem{Match: getMatchFunc(response.Extensions), Location: append(loc, LocationContext{ParentMatchFunc: responseMatchFunc, ParentField: ""}), Swagger: swagger}) +} + +// walkHeaders walks through headers +func walkHeaders(_ context.Context, headers *sequencedmap.Map[string, *Header], loc []LocationContext, swagger *Swagger, yield func(WalkItem) bool) bool { + if headers == nil || headers.Len() == 0 { + return true + } + + // Get the last loc so we can set the parent key + parentLoc := loc[len(loc)-1] + loc = loc[:len(loc)-1] + + for name, header := range headers.All() { + parentLoc.ParentKey = pointer.From(name) + + if !walkHeader(header, append(loc, parentLoc), swagger, yield) { + return false + } + } + return true +} + +// walkHeader walks through a header +func walkHeader(header *Header, loc []LocationContext, swagger *Swagger, yield func(WalkItem) bool) bool { + if header == nil { + return true + } + + headerMatchFunc := getMatchFunc(header) + + if !yield(WalkItem{Match: headerMatchFunc, Location: loc, Swagger: swagger}) { + return false + } + + // Walk through items + if !walkItems(header.Items, headerMatchFunc, append(loc, LocationContext{ParentMatchFunc: headerMatchFunc, ParentField: "items"}), swagger, yield) { + return false + } + + // Visit Header Extensions + return yield(WalkItem{Match: getMatchFunc(header.Extensions), Location: append(loc, LocationContext{ParentMatchFunc: headerMatchFunc, ParentField: ""}), Swagger: swagger}) +} + +// walkDefinitions walks through schema definitions +func walkDefinitions(_ context.Context, definitions *sequencedmap.Map[string, *oas3.JSONSchema[oas3.Concrete]], loc []LocationContext, swagger *Swagger, yield func(WalkItem) bool) bool { + if definitions == nil || definitions.Len() == 0 { + return true + } + + // Get the last loc so we can set the parent key + parentLoc := loc[len(loc)-1] + loc = loc[:len(loc)-1] + + for name, schema := range definitions.All() { + parentLoc.ParentKey = pointer.From(name) + + if !walkSchemaConcrete(schema, append(loc, parentLoc), swagger, yield) { + return false + } + } + return true +} + +// walkSchemaConcrete walks through a concrete schema +func walkSchemaConcrete(schema *oas3.JSONSchema[oas3.Concrete], loc []LocationContext, swagger *Swagger, yield func(WalkItem) bool) bool { + if schema == nil { + return true + } + + schemaMatchFunc := getMatchFunc(schema) + + if !yield(WalkItem{Match: schemaMatchFunc, Location: loc, Swagger: swagger}) { + return false + } + + // For Swagger, we just visit the schema itself without walking nested schemas + // since schema walking is specific to the JSON Schema implementation + return true +} + +// walkSchemaReferenceable walks through a referenceable schema +func walkSchemaReferenceable(schema *oas3.JSONSchema[oas3.Referenceable], loc []LocationContext, swagger *Swagger, yield func(WalkItem) bool) bool { + if schema == nil { + return true + } + + // Convert to match func for referenceable 
schema + // Note: We can't use getMatchFunc directly because we only have oas3.JSONSchema[oas3.Concrete] in the registry + // For referenceable schemas, we just yield without a specific match function + if !yield(WalkItem{Match: func(m Matcher) error { + // No specific matcher for referenceable schemas + if m.Any != nil { + return m.Any(schema) + } + return nil + }, Location: loc, Swagger: swagger}) { + return false + } + + return true +} + +// walkParameters walks through global parameters +func walkParameters(ctx context.Context, parameters *sequencedmap.Map[string, *Parameter], loc []LocationContext, swagger *Swagger, yield func(WalkItem) bool) bool { + if parameters == nil || parameters.Len() == 0 { + return true + } + + // Get the last loc so we can set the parent key + parentLoc := loc[len(loc)-1] + loc = loc[:len(loc)-1] + + for name, parameter := range parameters.All() { + parentLoc.ParentKey = pointer.From(name) + + if !walkParameter(ctx, parameter, append(loc, parentLoc), swagger, yield) { + return false + } + } + return true +} + +// walkGlobalResponses walks through global responses +func walkGlobalResponses(ctx context.Context, responses *sequencedmap.Map[string, *Response], loc []LocationContext, swagger *Swagger, yield func(WalkItem) bool) bool { + if responses == nil || responses.Len() == 0 { + return true + } + + // Get the last loc so we can set the parent key + parentLoc := loc[len(loc)-1] + loc = loc[:len(loc)-1] + + for name, response := range responses.All() { + parentLoc.ParentKey = pointer.From(name) + + if !walkResponse(ctx, response, append(loc, parentLoc), swagger, yield) { + return false + } + } + return true +} + +// walkSecurityDefinitions walks through security definitions +func walkSecurityDefinitions(_ context.Context, securityDefinitions *sequencedmap.Map[string, *SecurityScheme], loc []LocationContext, swagger *Swagger, yield func(WalkItem) bool) bool { + if securityDefinitions == nil || securityDefinitions.Len() == 0 { + return true + } + + // Get the last loc so we can set the parent key + parentLoc := loc[len(loc)-1] + loc = loc[:len(loc)-1] + + for name, securityScheme := range securityDefinitions.All() { + parentLoc.ParentKey = pointer.From(name) + + if !walkSecurityScheme(securityScheme, append(loc, parentLoc), swagger, yield) { + return false + } + } + return true +} + +// walkSecurityScheme walks through a security scheme +func walkSecurityScheme(securityScheme *SecurityScheme, loc []LocationContext, swagger *Swagger, yield func(WalkItem) bool) bool { + if securityScheme == nil { + return true + } + + securitySchemeMatchFunc := getMatchFunc(securityScheme) + + if !yield(WalkItem{Match: securitySchemeMatchFunc, Location: loc, Swagger: swagger}) { + return false + } + + // Visit SecurityScheme Extensions + return yield(WalkItem{Match: getMatchFunc(securityScheme.Extensions), Location: append(loc, LocationContext{ParentMatchFunc: securitySchemeMatchFunc, ParentField: ""}), Swagger: swagger}) +} + +// walkSecurity walks through security requirements +func walkSecurity(ctx context.Context, security []*SecurityRequirement, loc []LocationContext, swagger *Swagger, yield func(WalkItem) bool) bool { + if len(security) == 0 { + return true + } + + // Get the last loc so we can set the parent index + parentLoc := loc[len(loc)-1] + loc = loc[:len(loc)-1] + + for i, secReq := range security { + parentLoc.ParentIndex = pointer.From(i) + + if !walkSecurityRequirement(ctx, secReq, append(loc, parentLoc), swagger, yield) { + return false + } + } + return true +} + 
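+// Illustrative usage (a minimal sketch, not part of the walker itself): callers
+// drive the traversal through the public Walk iterator and a Matcher, mirroring
+// the pattern exercised in walk_test.go. Here doc is assumed to be a previously
+// unmarshalled *Swagger, Matcher is defined in walk_matching.go, and
+// walk.ErrTerminate comes from the shared github.com/speakeasy-api/openapi/walk
+// package used by the tests.
+//
+//	for item := range Walk(ctx, doc) {
+//		err := item.Match(Matcher{
+//			SecurityRequirement: func(sr *SecurityRequirement) error {
+//				// inspect sr here; return walk.ErrTerminate to stop the walk early
+//				return nil
+//			},
+//		})
+//		if errors.Is(err, walk.ErrTerminate) {
+//			break
+//		}
+//	}
+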
+// walkSecurityRequirement walks through a security requirement +func walkSecurityRequirement(_ context.Context, securityRequirement *SecurityRequirement, loc []LocationContext, swagger *Swagger, yield func(WalkItem) bool) bool { + if securityRequirement == nil { + return true + } + + securityRequirementMatchFunc := getMatchFunc(securityRequirement) + + return yield(WalkItem{Match: securityRequirementMatchFunc, Location: loc, Swagger: swagger}) +} diff --git a/swagger/walk_matching.go b/swagger/walk_matching.go new file mode 100644 index 0000000..3fd2aaa --- /dev/null +++ b/swagger/walk_matching.go @@ -0,0 +1,142 @@ +package swagger + +import ( + "fmt" + "reflect" + + "github.com/speakeasy-api/openapi/extensions" + "github.com/speakeasy-api/openapi/jsonschema/oas3" + walkpkg "github.com/speakeasy-api/openapi/walk" +) + +// Matcher is a struct that can be used to match specific nodes in the Swagger document. +type Matcher struct { + Swagger func(*Swagger) error + Info func(*Info) error + Contact func(*Contact) error + License func(*License) error + ExternalDocs func(*ExternalDocumentation) error + Tag func(*Tag) error + Paths func(*Paths) error + PathItem func(*PathItem) error + Operation func(*Operation) error + ReferencedParameter func(*ReferencedParameter) error + Parameter func(*Parameter) error + Schema func(*oas3.JSONSchema[oas3.Concrete]) error + Discriminator func(*oas3.Discriminator) error + XML func(*oas3.XML) error + Responses func(*Responses) error + ReferencedResponse func(*ReferencedResponse) error + Response func(*Response) error + Header func(*Header) error + Items func(*Items) error + SecurityRequirement func(*SecurityRequirement) error + SecurityScheme func(*SecurityScheme) error + Extensions func(*extensions.Extensions) error + Any func(any) error // Any will be called along with the other functions above on a match of a model +} + +// MatchFunc represents a particular model in the Swagger document that can be matched. +// Pass it a Matcher with the appropriate functions populated to match the model type(s) you are interested in. 
+type MatchFunc func(Matcher) error + +// Use the shared walking infrastructure +type ( + LocationContext = walkpkg.LocationContext[MatchFunc] + Locations = walkpkg.Locations[MatchFunc] +) + +type matchHandler[T any] struct { + GetSpecific func(m Matcher) func(*T) error +} + +var matchRegistry = map[reflect.Type]any{ + reflect.TypeOf((*Swagger)(nil)): matchHandler[Swagger]{ + GetSpecific: func(m Matcher) func(*Swagger) error { return m.Swagger }, + }, + reflect.TypeOf((*Info)(nil)): matchHandler[Info]{ + GetSpecific: func(m Matcher) func(*Info) error { return m.Info }, + }, + reflect.TypeOf((*Contact)(nil)): matchHandler[Contact]{ + GetSpecific: func(m Matcher) func(*Contact) error { return m.Contact }, + }, + reflect.TypeOf((*License)(nil)): matchHandler[License]{ + GetSpecific: func(m Matcher) func(*License) error { return m.License }, + }, + reflect.TypeOf((*ExternalDocumentation)(nil)): matchHandler[ExternalDocumentation]{ + GetSpecific: func(m Matcher) func(*ExternalDocumentation) error { return m.ExternalDocs }, + }, + reflect.TypeOf((*Tag)(nil)): matchHandler[Tag]{ + GetSpecific: func(m Matcher) func(*Tag) error { return m.Tag }, + }, + reflect.TypeOf((*Paths)(nil)): matchHandler[Paths]{ + GetSpecific: func(m Matcher) func(*Paths) error { return m.Paths }, + }, + reflect.TypeOf((*PathItem)(nil)): matchHandler[PathItem]{ + GetSpecific: func(m Matcher) func(*PathItem) error { return m.PathItem }, + }, + reflect.TypeOf((*Operation)(nil)): matchHandler[Operation]{ + GetSpecific: func(m Matcher) func(*Operation) error { return m.Operation }, + }, + reflect.TypeOf((*ReferencedParameter)(nil)): matchHandler[ReferencedParameter]{ + GetSpecific: func(m Matcher) func(*ReferencedParameter) error { return m.ReferencedParameter }, + }, + reflect.TypeOf((*Parameter)(nil)): matchHandler[Parameter]{ + GetSpecific: func(m Matcher) func(*Parameter) error { return m.Parameter }, + }, + reflect.TypeOf((*oas3.JSONSchema[oas3.Concrete])(nil)): matchHandler[oas3.JSONSchema[oas3.Concrete]]{ + GetSpecific: func(m Matcher) func(*oas3.JSONSchema[oas3.Concrete]) error { return m.Schema }, + }, + reflect.TypeOf((*oas3.Discriminator)(nil)): matchHandler[oas3.Discriminator]{ + GetSpecific: func(m Matcher) func(*oas3.Discriminator) error { return m.Discriminator }, + }, + reflect.TypeOf((*oas3.XML)(nil)): matchHandler[oas3.XML]{ + GetSpecific: func(m Matcher) func(*oas3.XML) error { return m.XML }, + }, + reflect.TypeOf((*Responses)(nil)): matchHandler[Responses]{ + GetSpecific: func(m Matcher) func(*Responses) error { return m.Responses }, + }, + reflect.TypeOf((*ReferencedResponse)(nil)): matchHandler[ReferencedResponse]{ + GetSpecific: func(m Matcher) func(*ReferencedResponse) error { return m.ReferencedResponse }, + }, + reflect.TypeOf((*Response)(nil)): matchHandler[Response]{ + GetSpecific: func(m Matcher) func(*Response) error { return m.Response }, + }, + reflect.TypeOf((*Header)(nil)): matchHandler[Header]{ + GetSpecific: func(m Matcher) func(*Header) error { return m.Header }, + }, + reflect.TypeOf((*Items)(nil)): matchHandler[Items]{ + GetSpecific: func(m Matcher) func(*Items) error { return m.Items }, + }, + reflect.TypeOf((*SecurityRequirement)(nil)): matchHandler[SecurityRequirement]{ + GetSpecific: func(m Matcher) func(*SecurityRequirement) error { return m.SecurityRequirement }, + }, + reflect.TypeOf((*SecurityScheme)(nil)): matchHandler[SecurityScheme]{ + GetSpecific: func(m Matcher) func(*SecurityScheme) error { return m.SecurityScheme }, + }, + reflect.TypeOf((*extensions.Extensions)(nil)): 
matchHandler[extensions.Extensions]{ + GetSpecific: func(m Matcher) func(*extensions.Extensions) error { return m.Extensions }, + }, +} + +func getMatchFunc[T any](target *T) MatchFunc { + t := reflect.TypeOf(target) + + h, ok := matchRegistry[t] + if !ok { + panic(fmt.Sprintf("no match handler registered for type %v", t)) + } + + handler := h.(matchHandler[T]) + return func(m Matcher) error { + if m.Any != nil { + if err := m.Any(target); err != nil { + return err + } + } + if specific := handler.GetSpecific(m); specific != nil { + return specific(target) + } + return nil + } +} diff --git a/swagger/walk_test.go b/swagger/walk_test.go new file mode 100644 index 0000000..e13ad48 --- /dev/null +++ b/swagger/walk_test.go @@ -0,0 +1,889 @@ +package swagger_test + +import ( + "context" + "errors" + "os" + "testing" + + "github.com/speakeasy-api/openapi/extensions" + "github.com/speakeasy-api/openapi/jsonschema/oas3" + "github.com/speakeasy-api/openapi/swagger" + "github.com/speakeasy-api/openapi/walk" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// loadSwaggerDocument loads a fresh Swagger document for each test to ensure thread safety +func loadSwaggerDocument(ctx context.Context) (*swagger.Swagger, error) { + f, err := os.Open("testdata/walk.swagger.json") + if err != nil { + return nil, err + } + defer f.Close() + + s, validationErrs, err := swagger.Unmarshal(ctx, f) + if err != nil { + return nil, err + } + if len(validationErrs) > 0 { + return nil, errors.Join(validationErrs...) + } + + return s, nil +} + +func TestWalkSwagger_Success(t *testing.T) { + t.Parallel() + + swaggerDoc, err := loadSwaggerDocument(t.Context()) + require.NoError(t, err) + + matchedLocations := []string{} + expectedLoc := "/" + + for item := range swagger.Walk(t.Context(), swaggerDoc) { + err := item.Match(swagger.Matcher{ + Swagger: func(s *swagger.Swagger) error { + swaggerLoc := string(item.Location.ToJSONPointer()) + matchedLocations = append(matchedLocations, swaggerLoc) + + if swaggerLoc == expectedLoc { + assert.Equal(t, "2.0", s.Swagger) + assert.Equal(t, "api.example.com", s.GetHost()) + assert.Equal(t, "/v1", s.GetBasePath()) + + return walk.ErrTerminate + } + + return nil + }, + }) + + if errors.Is(err, walk.ErrTerminate) { + break + } + require.NoError(t, err) + } + + assert.Contains(t, matchedLocations, expectedLoc) +} + +func TestWalkSwagger_Extensions_Success(t *testing.T) { + t.Parallel() + + swaggerDoc, err := loadSwaggerDocument(t.Context()) + require.NoError(t, err) + + matchedLocations := []string{} + expectedLoc := "/" + + for item := range swagger.Walk(t.Context(), swaggerDoc) { + err := item.Match(swagger.Matcher{ + Extensions: func(e *extensions.Extensions) error { + extensionsLoc := string(item.Location.ToJSONPointer()) + matchedLocations = append(matchedLocations, extensionsLoc) + + if extensionsLoc == expectedLoc { + assert.Equal(t, "root-extension", e.GetOrZero("x-root-custom").Value) + + return walk.ErrTerminate + } + + return nil + }, + }) + + if errors.Is(err, walk.ErrTerminate) { + break + } + require.NoError(t, err) + } + + assert.Contains(t, matchedLocations, expectedLoc) +} + +func TestWalkInfo_Success(t *testing.T) { + t.Parallel() + + swaggerDoc, err := loadSwaggerDocument(t.Context()) + require.NoError(t, err) + + matchedLocations := []string{} + expectedLoc := "/info" + + for item := range swagger.Walk(t.Context(), swaggerDoc) { + err := item.Match(swagger.Matcher{ + Info: func(i *swagger.Info) error { + infoLoc := 
string(item.Location.ToJSONPointer()) + matchedLocations = append(matchedLocations, infoLoc) + + if infoLoc == expectedLoc { + assert.Equal(t, "Comprehensive Swagger API", i.GetTitle()) + assert.Equal(t, "1.0.0", i.GetVersion()) + assert.Equal(t, "A comprehensive Swagger API for testing walk functionality", i.GetDescription()) + + return walk.ErrTerminate + } + + return nil + }, + }) + + if errors.Is(err, walk.ErrTerminate) { + break + } + require.NoError(t, err) + } + + assert.Contains(t, matchedLocations, expectedLoc) +} + +func TestWalkContact_Success(t *testing.T) { + t.Parallel() + + swaggerDoc, err := loadSwaggerDocument(t.Context()) + require.NoError(t, err) + + matchedLocations := []string{} + expectedLoc := "/info/contact" + + for item := range swagger.Walk(t.Context(), swaggerDoc) { + err := item.Match(swagger.Matcher{ + Contact: func(c *swagger.Contact) error { + contactLoc := string(item.Location.ToJSONPointer()) + matchedLocations = append(matchedLocations, contactLoc) + + if contactLoc == expectedLoc { + assert.Equal(t, "API Team", c.GetName()) + assert.Equal(t, "api@example.com", c.GetEmail()) + assert.Equal(t, "https://example.com/contact", c.GetURL()) + + return walk.ErrTerminate + } + + return nil + }, + }) + + if errors.Is(err, walk.ErrTerminate) { + break + } + require.NoError(t, err) + } + + assert.Contains(t, matchedLocations, expectedLoc) +} + +func TestWalkLicense_Success(t *testing.T) { + t.Parallel() + + swaggerDoc, err := loadSwaggerDocument(t.Context()) + require.NoError(t, err) + + matchedLocations := []string{} + expectedLoc := "/info/license" + + for item := range swagger.Walk(t.Context(), swaggerDoc) { + err := item.Match(swagger.Matcher{ + License: func(l *swagger.License) error { + licenseLoc := string(item.Location.ToJSONPointer()) + matchedLocations = append(matchedLocations, licenseLoc) + + if licenseLoc == expectedLoc { + assert.Equal(t, "MIT", l.GetName()) + assert.Equal(t, "https://opensource.org/licenses/MIT", l.GetURL()) + + return walk.ErrTerminate + } + + return nil + }, + }) + + if errors.Is(err, walk.ErrTerminate) { + break + } + require.NoError(t, err) + } + + assert.Contains(t, matchedLocations, expectedLoc) +} + +func TestWalkExternalDocs_Success(t *testing.T) { + t.Parallel() + + swaggerDoc, err := loadSwaggerDocument(t.Context()) + require.NoError(t, err) + + matchedLocations := []string{} + expectedAssertions := map[string]func(*swagger.ExternalDocumentation){ + "/externalDocs": func(e *swagger.ExternalDocumentation) { + assert.Equal(t, "https://example.com/docs", e.GetURL()) + assert.Equal(t, "Additional documentation", e.GetDescription()) + }, + "/tags/0/externalDocs": func(e *swagger.ExternalDocumentation) { + assert.Equal(t, "https://example.com/users", e.GetURL()) + assert.Equal(t, "User documentation", e.GetDescription()) + }, + } + + for item := range swagger.Walk(t.Context(), swaggerDoc) { + err := item.Match(swagger.Matcher{ + ExternalDocs: func(e *swagger.ExternalDocumentation) error { + externalDocsLoc := string(item.Location.ToJSONPointer()) + matchedLocations = append(matchedLocations, externalDocsLoc) + + if assertFunc, exists := expectedAssertions[externalDocsLoc]; exists { + assertFunc(e) + } + + return nil + }, + }) + require.NoError(t, err) + } + + for expectedLoc := range expectedAssertions { + assert.Contains(t, matchedLocations, expectedLoc) + } +} + +func TestWalkTag_Success(t *testing.T) { + t.Parallel() + + swaggerDoc, err := loadSwaggerDocument(t.Context()) + require.NoError(t, err) + + matchedLocations := 
[]string{} + expectedAssertions := map[string]func(*swagger.Tag){ + "/tags/0": func(tag *swagger.Tag) { + assert.Equal(t, "users", tag.GetName()) + assert.Equal(t, "User operations", tag.GetDescription()) + }, + "/tags/1": func(tag *swagger.Tag) { + assert.Equal(t, "pets", tag.GetName()) + assert.Equal(t, "Pet operations", tag.GetDescription()) + }, + } + + for item := range swagger.Walk(t.Context(), swaggerDoc) { + err := item.Match(swagger.Matcher{ + Tag: func(tag *swagger.Tag) error { + tagLoc := string(item.Location.ToJSONPointer()) + matchedLocations = append(matchedLocations, tagLoc) + + if assertFunc, exists := expectedAssertions[tagLoc]; exists { + assertFunc(tag) + } + + return nil + }, + }) + require.NoError(t, err) + } + + for expectedLoc := range expectedAssertions { + assert.Contains(t, matchedLocations, expectedLoc) + } +} + +func TestWalkSecurity_Success(t *testing.T) { + t.Parallel() + + swaggerDoc, err := loadSwaggerDocument(t.Context()) + require.NoError(t, err) + + matchedLocations := []string{} + expectedLoc := "/security/0" + + for item := range swagger.Walk(t.Context(), swaggerDoc) { + err := item.Match(swagger.Matcher{ + SecurityRequirement: func(sr *swagger.SecurityRequirement) error { + securityLoc := string(item.Location.ToJSONPointer()) + matchedLocations = append(matchedLocations, securityLoc) + + if securityLoc == expectedLoc { + assert.NotNil(t, sr) + // Security requirement should have apiKey + apiKeyScopes, exists := sr.Get("apiKey") + assert.True(t, exists) + assert.Empty(t, apiKeyScopes) // Empty array for API key + + return walk.ErrTerminate + } + + return nil + }, + }) + + if errors.Is(err, walk.ErrTerminate) { + break + } + require.NoError(t, err) + } + + assert.Contains(t, matchedLocations, expectedLoc) +} + +func TestWalkPaths_Success(t *testing.T) { + t.Parallel() + + swaggerDoc, err := loadSwaggerDocument(t.Context()) + require.NoError(t, err) + + matchedLocations := []string{} + expectedLoc := "/paths" + + for item := range swagger.Walk(t.Context(), swaggerDoc) { + err := item.Match(swagger.Matcher{ + Paths: func(p *swagger.Paths) error { + pathsLoc := string(item.Location.ToJSONPointer()) + matchedLocations = append(matchedLocations, pathsLoc) + + if pathsLoc == expectedLoc { + assert.NotNil(t, p) + // Should contain the /users/{id} path + pathItem, exists := p.Get("/users/{id}") + assert.True(t, exists) + assert.NotNil(t, pathItem) + + return walk.ErrTerminate + } + + return nil + }, + }) + + if errors.Is(err, walk.ErrTerminate) { + break + } + require.NoError(t, err) + } + + assert.Contains(t, matchedLocations, expectedLoc) +} + +func TestWalkPathItem_Success(t *testing.T) { + t.Parallel() + + swaggerDoc, err := loadSwaggerDocument(t.Context()) + require.NoError(t, err) + + matchedLocations := []string{} + expectedLoc := "/paths/~1users~1{id}" + + for item := range swagger.Walk(t.Context(), swaggerDoc) { + err := item.Match(swagger.Matcher{ + PathItem: func(pi *swagger.PathItem) error { + pathItemLoc := string(item.Location.ToJSONPointer()) + matchedLocations = append(matchedLocations, pathItemLoc) + + if pathItemLoc == expectedLoc { + assert.NotNil(t, pi) + assert.NotNil(t, pi.Get()) + assert.Equal(t, "getUser", pi.Get().GetOperationID()) + + return walk.ErrTerminate + } + + return nil + }, + }) + + if errors.Is(err, walk.ErrTerminate) { + break + } + require.NoError(t, err) + } + + assert.Contains(t, matchedLocations, expectedLoc) +} + +func TestWalkOperation_Success(t *testing.T) { + t.Parallel() + + swaggerDoc, err := 
loadSwaggerDocument(t.Context()) + require.NoError(t, err) + + matchedLocations := []string{} + expectedLoc := "/paths/~1users~1{id}/get" + + for item := range swagger.Walk(t.Context(), swaggerDoc) { + err := item.Match(swagger.Matcher{ + Operation: func(op *swagger.Operation) error { + operationLoc := string(item.Location.ToJSONPointer()) + matchedLocations = append(matchedLocations, operationLoc) + + if operationLoc == expectedLoc { + assert.Equal(t, "getUser", op.GetOperationID()) + assert.Equal(t, "Get user by ID", op.GetSummary()) + assert.Equal(t, "Retrieve a user by their ID", op.GetDescription()) + assert.Contains(t, op.GetTags(), "users") + + return walk.ErrTerminate + } + + return nil + }, + }) + + if errors.Is(err, walk.ErrTerminate) { + break + } + require.NoError(t, err) + } + + assert.Contains(t, matchedLocations, expectedLoc) +} + +func TestWalkReferencedParameter_Success(t *testing.T) { + t.Parallel() + + swaggerDoc, err := loadSwaggerDocument(t.Context()) + require.NoError(t, err) + + matchedLocations := []string{} + expectedAssertions := map[string]func(*swagger.ReferencedParameter){ + "/paths/~1users~1{id}/parameters/0": func(rp *swagger.ReferencedParameter) { + assert.False(t, rp.IsReference()) + assert.NotNil(t, rp.Object) + assert.Equal(t, "id", rp.Object.GetName()) + assert.Equal(t, swagger.ParameterInPath, rp.Object.GetIn()) + }, + "/paths/~1users~1{id}/get/parameters/0": func(rp *swagger.ReferencedParameter) { + assert.False(t, rp.IsReference()) + assert.NotNil(t, rp.Object) + assert.Equal(t, "expand", rp.Object.GetName()) + assert.Equal(t, swagger.ParameterInQuery, rp.Object.GetIn()) + }, + } + + for item := range swagger.Walk(t.Context(), swaggerDoc) { + err := item.Match(swagger.Matcher{ + ReferencedParameter: func(rp *swagger.ReferencedParameter) error { + paramLoc := string(item.Location.ToJSONPointer()) + matchedLocations = append(matchedLocations, paramLoc) + + if assertFunc, exists := expectedAssertions[paramLoc]; exists { + assertFunc(rp) + } + + return nil + }, + }) + require.NoError(t, err) + } + + for expectedLoc := range expectedAssertions { + assert.Contains(t, matchedLocations, expectedLoc) + } +} + +func TestWalkParameter_Success(t *testing.T) { + t.Parallel() + + swaggerDoc, err := loadSwaggerDocument(t.Context()) + require.NoError(t, err) + + matchedLocations := []string{} + expectedAssertions := map[string]func(*swagger.Parameter){ + "/paths/~1users~1{id}/parameters/0": func(p *swagger.Parameter) { + assert.Equal(t, "id", p.GetName()) + assert.Equal(t, swagger.ParameterInPath, p.GetIn()) + assert.Equal(t, "integer", p.GetType()) + }, + "/parameters/PageParam": func(p *swagger.Parameter) { + assert.Equal(t, "page", p.GetName()) + assert.Equal(t, swagger.ParameterInQuery, p.GetIn()) + assert.Equal(t, "integer", p.GetType()) + }, + } + + for item := range swagger.Walk(t.Context(), swaggerDoc) { + err := item.Match(swagger.Matcher{ + Parameter: func(p *swagger.Parameter) error { + paramLoc := string(item.Location.ToJSONPointer()) + matchedLocations = append(matchedLocations, paramLoc) + + if assertFunc, exists := expectedAssertions[paramLoc]; exists { + assertFunc(p) + } + + return nil + }, + }) + require.NoError(t, err) + } + + for expectedLoc := range expectedAssertions { + assert.Contains(t, matchedLocations, expectedLoc) + } +} + +func TestWalkSchema_Success(t *testing.T) { + t.Parallel() + + swaggerDoc, err := loadSwaggerDocument(t.Context()) + require.NoError(t, err) + + matchedLocations := []string{} + expectedAssertions := 
map[string]func(any){ + "/definitions/User": func(a any) { + schema, ok := a.(*oas3.JSONSchema[oas3.Concrete]) + assert.True(t, ok) + assert.NotNil(t, schema) + // For concrete schemas, the schema object is in Left + assert.True(t, schema.IsLeft()) + if schema.Left != nil { + schemaType := schema.Left.GetType() + assert.Len(t, schemaType, 1) + assert.Equal(t, oas3.SchemaTypeObject, schemaType[0]) + assert.Equal(t, "User object", schema.Left.GetDescription()) + } + }, + "/paths/~1users~1{id}/get/responses/200/schema": func(a any) { + // Schema reference, just verify it exists + assert.NotNil(t, a) + }, + } + + for item := range swagger.Walk(t.Context(), swaggerDoc) { + err := item.Match(swagger.Matcher{ + Any: func(model any) error { + loc := string(item.Location.ToJSONPointer()) + + if assertFunc, exists := expectedAssertions[loc]; exists { + matchedLocations = append(matchedLocations, loc) + assertFunc(model) + } + + return nil + }, + }) + require.NoError(t, err) + } + + for expectedLoc := range expectedAssertions { + assert.Contains(t, matchedLocations, expectedLoc) + } +} + +func TestWalkResponses_Success(t *testing.T) { + t.Parallel() + + swaggerDoc, err := loadSwaggerDocument(t.Context()) + require.NoError(t, err) + + matchedLocations := []string{} + expectedLoc := "/paths/~1users~1{id}/get/responses" + + for item := range swagger.Walk(t.Context(), swaggerDoc) { + err := item.Match(swagger.Matcher{ + Responses: func(r *swagger.Responses) error { + responsesLoc := string(item.Location.ToJSONPointer()) + matchedLocations = append(matchedLocations, responsesLoc) + + if responsesLoc == expectedLoc { + assert.NotNil(t, r) + // Should have 200 response + response200, exists := r.Get("200") + assert.True(t, exists) + assert.NotNil(t, response200) + // Should have default response + assert.NotNil(t, r.Default) + + return walk.ErrTerminate + } + + return nil + }, + }) + + if errors.Is(err, walk.ErrTerminate) { + break + } + require.NoError(t, err) + } + + assert.Contains(t, matchedLocations, expectedLoc) +} + +func TestWalkReferencedResponse_Success(t *testing.T) { + t.Parallel() + + swaggerDoc, err := loadSwaggerDocument(t.Context()) + require.NoError(t, err) + + matchedLocations := []string{} + expectedAssertions := map[string]func(*swagger.ReferencedResponse){ + "/paths/~1users~1{id}/get/responses/200": func(rr *swagger.ReferencedResponse) { + assert.False(t, rr.IsReference()) + assert.NotNil(t, rr.Object) + assert.Equal(t, "Successful response", rr.Object.GetDescription()) + }, + "/paths/~1users~1{id}/get/responses/default": func(rr *swagger.ReferencedResponse) { + assert.False(t, rr.IsReference()) + assert.NotNil(t, rr.Object) + assert.Equal(t, "Error response", rr.Object.GetDescription()) + }, + } + + for item := range swagger.Walk(t.Context(), swaggerDoc) { + err := item.Match(swagger.Matcher{ + ReferencedResponse: func(rr *swagger.ReferencedResponse) error { + responseLoc := string(item.Location.ToJSONPointer()) + matchedLocations = append(matchedLocations, responseLoc) + + if assertFunc, exists := expectedAssertions[responseLoc]; exists { + assertFunc(rr) + } + + return nil + }, + }) + require.NoError(t, err) + } + + for expectedLoc := range expectedAssertions { + assert.Contains(t, matchedLocations, expectedLoc) + } +} + +func TestWalkGlobalResponse_Success(t *testing.T) { + t.Parallel() + + swaggerDoc, err := loadSwaggerDocument(t.Context()) + require.NoError(t, err) + + matchedLocations := []string{} + expectedLoc := "/responses/ErrorResponse" + + for item := range 
swagger.Walk(t.Context(), swaggerDoc) { + err := item.Match(swagger.Matcher{ + Response: func(r *swagger.Response) error { + responseLoc := string(item.Location.ToJSONPointer()) + matchedLocations = append(matchedLocations, responseLoc) + + if responseLoc == expectedLoc { + assert.Equal(t, "Error response", r.GetDescription()) + assert.NotNil(t, r.Schema) + + return walk.ErrTerminate + } + + return nil + }, + }) + + if errors.Is(err, walk.ErrTerminate) { + break + } + require.NoError(t, err) + } + + assert.Contains(t, matchedLocations, expectedLoc) +} + +func TestWalkResponse_Success(t *testing.T) { + t.Parallel() + + swaggerDoc, err := loadSwaggerDocument(t.Context()) + require.NoError(t, err) + + matchedLocations := []string{} + expectedLoc := "/paths/~1users~1{id}/get/responses/200" + + for item := range swagger.Walk(t.Context(), swaggerDoc) { + err := item.Match(swagger.Matcher{ + Response: func(r *swagger.Response) error { + responseLoc := string(item.Location.ToJSONPointer()) + matchedLocations = append(matchedLocations, responseLoc) + + if responseLoc == expectedLoc { + assert.Equal(t, "Successful response", r.GetDescription()) + assert.NotNil(t, r.Schema) + + return walk.ErrTerminate + } + + return nil + }, + }) + + if errors.Is(err, walk.ErrTerminate) { + break + } + require.NoError(t, err) + } + + assert.Contains(t, matchedLocations, expectedLoc) +} + +func TestWalkHeader_Success(t *testing.T) { + t.Parallel() + + swaggerDoc, err := loadSwaggerDocument(t.Context()) + require.NoError(t, err) + + matchedLocations := []string{} + expectedLoc := "/paths/~1users~1{id}/get/responses/200/headers/X-Rate-Limit" + + for item := range swagger.Walk(t.Context(), swaggerDoc) { + err := item.Match(swagger.Matcher{ + Header: func(h *swagger.Header) error { + headerLoc := string(item.Location.ToJSONPointer()) + matchedLocations = append(matchedLocations, headerLoc) + + if headerLoc == expectedLoc { + assert.Equal(t, "integer", h.GetType()) + assert.Equal(t, "Rate limit remaining", h.GetDescription()) + + return walk.ErrTerminate + } + + return nil + }, + }) + + if errors.Is(err, walk.ErrTerminate) { + break + } + require.NoError(t, err) + } + + assert.Contains(t, matchedLocations, expectedLoc) +} + +func TestWalkItems_Success(t *testing.T) { + t.Parallel() + + swaggerDoc, err := loadSwaggerDocument(t.Context()) + require.NoError(t, err) + + matchedLocations := []string{} + expectedLoc := "/paths/~1pets/get/parameters/0/items" + + for item := range swagger.Walk(t.Context(), swaggerDoc) { + err := item.Match(swagger.Matcher{ + Items: func(i *swagger.Items) error { + itemsLoc := string(item.Location.ToJSONPointer()) + matchedLocations = append(matchedLocations, itemsLoc) + + if itemsLoc == expectedLoc { + assert.Equal(t, "string", i.GetType()) + + return walk.ErrTerminate + } + + return nil + }, + }) + + if errors.Is(err, walk.ErrTerminate) { + break + } + require.NoError(t, err) + } + + assert.Contains(t, matchedLocations, expectedLoc) +} + +func TestWalkSecurityScheme_Success(t *testing.T) { + t.Parallel() + + swaggerDoc, err := loadSwaggerDocument(t.Context()) + require.NoError(t, err) + + matchedLocations := []string{} + expectedAssertions := map[string]func(*swagger.SecurityScheme){ + "/securityDefinitions/apiKey": func(ss *swagger.SecurityScheme) { + assert.Equal(t, swagger.SecuritySchemeTypeAPIKey, ss.GetType()) + assert.Equal(t, "X-API-Key", ss.GetName()) + assert.Equal(t, swagger.SecuritySchemeInHeader, ss.GetIn()) + }, + "/securityDefinitions/oauth2": func(ss *swagger.SecurityScheme) { + 
assert.Equal(t, swagger.SecuritySchemeTypeOAuth2, ss.GetType()) + assert.Equal(t, swagger.OAuth2FlowAccessCode, ss.GetFlow()) + assert.Equal(t, "https://example.com/oauth/authorize", ss.GetAuthorizationURL()) + assert.Equal(t, "https://example.com/oauth/token", ss.GetTokenURL()) + }, + } + + for item := range swagger.Walk(t.Context(), swaggerDoc) { + err := item.Match(swagger.Matcher{ + SecurityScheme: func(ss *swagger.SecurityScheme) error { + schemeLoc := string(item.Location.ToJSONPointer()) + matchedLocations = append(matchedLocations, schemeLoc) + + if assertFunc, exists := expectedAssertions[schemeLoc]; exists { + assertFunc(ss) + } + + return nil + }, + }) + require.NoError(t, err) + } + + for expectedLoc := range expectedAssertions { + assert.Contains(t, matchedLocations, expectedLoc) + } +} + +func TestWalkAny_Success(t *testing.T) { + t.Parallel() + + swaggerDoc, err := loadSwaggerDocument(t.Context()) + require.NoError(t, err) + + visitCounts := make(map[string]int) + + for item := range swagger.Walk(t.Context(), swaggerDoc) { + err := item.Match(swagger.Matcher{ + Any: func(model any) error { + location := string(item.Location.ToJSONPointer()) + visitCounts[location]++ + return nil + }, + }) + require.NoError(t, err) + } + + // Verify we visited key locations + assert.Positive(t, visitCounts["/"], "Should visit root") + assert.Positive(t, visitCounts["/info"], "Should visit info") + assert.Positive(t, visitCounts["/paths"], "Should visit paths") + assert.Positive(t, visitCounts["/definitions/User"], "Should visit User definition") + + // Should have visited many locations + assert.Greater(t, len(visitCounts), 30, "Should visit many locations in comprehensive document") +} + +func TestWalk_Terminate_Success(t *testing.T) { + t.Parallel() + + swaggerDoc, err := loadSwaggerDocument(t.Context()) + require.NoError(t, err) + + visits := 0 + + for item := range swagger.Walk(t.Context(), swaggerDoc) { + err := item.Match(swagger.Matcher{ + Swagger: func(s *swagger.Swagger) error { + return walk.ErrTerminate + }, + Any: func(a any) error { + visits++ + return nil + }, + }) + + if errors.Is(err, walk.ErrTerminate) { + break + } + require.NoError(t, err) + } + + assert.Equal(t, 1, visits, "expected only one visit before terminating") +} diff --git a/yml/config.go b/yml/config.go index a992dee..772411e 100644 --- a/yml/config.go +++ b/yml/config.go @@ -22,16 +22,37 @@ const ( OutputFormatYAML OutputFormat = "yaml" ) +type IndentationStyle string + +const ( + IndentationStyleSpace IndentationStyle = "space" + IndentationStyleTab IndentationStyle = "tab" +) + +func (i IndentationStyle) ToIndent() string { + switch i { + case IndentationStyleSpace: + return " " + case IndentationStyleTab: + return "\t" + default: + return "" + } +} + type Config struct { - KeyStringStyle yaml.Style // The default string style to use when creating new keys - ValueStringStyle yaml.Style // The default string style to use when creating new nodes - Indentation int // The indentation level of the document - OutputFormat OutputFormat // The output format to use when marshalling - OriginalFormat OutputFormat // The original input format, helps detect when we are changing formats + KeyStringStyle yaml.Style // The default string style to use when creating new keys + ValueStringStyle yaml.Style // The default string style to use when creating new nodes + Indentation int // The indentation level of the document + IndentationStyle IndentationStyle // The indentation style of the document valid for JSON only + 
OutputFormat OutputFormat // The output format to use when marshalling + OriginalFormat OutputFormat // The original input format, helps detect when we are changing formats + TrailingNewline bool // Whether the original document had a trailing newline } var defaultConfig = &Config{ Indentation: 2, + IndentationStyle: IndentationStyleSpace, KeyStringStyle: 0, ValueStringStyle: 0, OutputFormat: OutputFormatYAML, @@ -64,9 +85,12 @@ func GetConfigFromContext(ctx context.Context) *Config { func GetConfigFromDoc(data []byte, doc *yaml.Node) *Config { cfg := *defaultConfig - cfg.OutputFormat, cfg.Indentation = inspectData(data) + cfg.OutputFormat, cfg.Indentation, cfg.IndentationStyle = inspectData(data) cfg.OriginalFormat = cfg.OutputFormat + // Check if the original data had a trailing newline + cfg.TrailingNewline = len(data) > 0 && data[len(data)-1] == '\n' + // Only extract string styles from the document if it's YAML // For JSON input, keep the default YAML styles if cfg.OriginalFormat == OutputFormatYAML { @@ -76,15 +100,19 @@ func GetConfigFromDoc(data []byte, doc *yaml.Node) *Config { return &cfg } -func inspectData(data []byte) (OutputFormat, int) { +func inspectData(data []byte) (OutputFormat, int, IndentationStyle) { lines := bytes.Split(bytes.TrimSpace(data), []byte("\n")) foundIndentation := false foundDocFormat := false indentation := 2 + indentationStyle := IndentationStyleSpace docFormat := OutputFormatYAML + // Track the minimum leading whitespace to establish baseline + minLeadingWhitespace := -1 + for i, line := range lines { trimLine := bytes.TrimSpace(line) @@ -99,18 +127,54 @@ func inspectData(data []byte) (OutputFormat, int) { docFormat = OutputFormatJSON foundDocFormat = true default: - if len(line) != len(trimLine) && !foundIndentation { - indentation = len(line) - len(trimLine) - foundIndentation = true + currentLeading := len(line) - len(trimLine) + + // Track minimum leading whitespace (baseline indentation) + if minLeadingWhitespace == -1 || currentLeading < minLeadingWhitespace { + minLeadingWhitespace = currentLeading + } + + // Look for indentation relative to the baseline + if currentLeading > minLeadingWhitespace && !foundIndentation { + // Extract the indentation (difference from baseline) + leadingWhitespace := line[minLeadingWhitespace:currentLeading] + + if len(leadingWhitespace) > 0 { + // Check the first character to determine tab vs space + if leadingWhitespace[0] == '\t' { + indentationStyle = IndentationStyleTab + // Count consecutive tabs + indentation = 0 + for _, ch := range leadingWhitespace { + if ch == '\t' { + indentation++ + } else { + break + } + } + } else if leadingWhitespace[0] == ' ' { + indentationStyle = IndentationStyleSpace + // Count consecutive spaces + indentation = 0 + for _, ch := range leadingWhitespace { + if ch == ' ' { + indentation++ + } else { + break + } + } + } + foundIndentation = true + } } } // If we have found everything we need or have iterated too long we can stop - if foundIndentation && (foundDocFormat || i > 5) { + if foundIndentation && (foundDocFormat || i > 10) { break } } - return docFormat, indentation + return docFormat, indentation, indentationStyle } func getGlobalStringStyle(doc *yaml.Node, cfg *Config) { diff --git a/yml/config_test.go b/yml/config_test.go index aa44a50..207090a 100644 --- a/yml/config_test.go +++ b/yml/config_test.go @@ -145,13 +145,14 @@ func TestGetConfigFromContext_Success(t *testing.T) { func TestGetConfigFromDoc_Success(t *testing.T) { t.Parallel() tests := []struct { - name string 
- data []byte - doc *yaml.Node - expectedFormat yml.OutputFormat - expectedIndent int - expectedKeyStyle yaml.Style - expectedValueStyle yaml.Style + name string + data []byte + doc *yaml.Node + expectedFormat yml.OutputFormat + expectedIndent int + expectedIndentStyle yml.IndentationStyle + expectedKeyStyle yaml.Style + expectedValueStyle yaml.Style }{ { name: "YAML document with quoted strings", @@ -176,10 +177,11 @@ func TestGetConfigFromDoc_Success(t *testing.T) { }, }, }, - expectedFormat: yml.OutputFormatYAML, - expectedIndent: 2, - expectedKeyStyle: 0, - expectedValueStyle: yaml.DoubleQuotedStyle, + expectedFormat: yml.OutputFormatYAML, + expectedIndent: 2, + expectedIndentStyle: yml.IndentationStyleSpace, + expectedKeyStyle: 0, + expectedValueStyle: yaml.DoubleQuotedStyle, }, { name: "JSON document", @@ -200,10 +202,11 @@ func TestGetConfigFromDoc_Success(t *testing.T) { }, }, }, - expectedFormat: yml.OutputFormatJSON, - expectedIndent: 2, - expectedKeyStyle: 0, - expectedValueStyle: 0, + expectedFormat: yml.OutputFormatJSON, + expectedIndent: 2, + expectedIndentStyle: yml.IndentationStyleSpace, + expectedKeyStyle: 0, + expectedValueStyle: 0, }, { name: "YAML with 4-space indentation", @@ -224,10 +227,55 @@ func TestGetConfigFromDoc_Success(t *testing.T) { }, }, }, - expectedFormat: yml.OutputFormatYAML, - expectedIndent: 4, - expectedKeyStyle: 0, - expectedValueStyle: 0, + expectedFormat: yml.OutputFormatYAML, + expectedIndent: 4, + expectedIndentStyle: yml.IndentationStyleSpace, + expectedKeyStyle: 0, + expectedValueStyle: 0, + }, + { + name: "YAML with single tab indentation", + data: []byte("key1: value1\nnested:\n\tkey2: value2"), + doc: &yaml.Node{ + Kind: yaml.DocumentNode, + Content: []*yaml.Node{ + { + Kind: yaml.MappingNode, + Tag: "!!map", + Content: []*yaml.Node{ + {Value: "key1", Kind: yaml.ScalarNode, Tag: "!!str"}, + {Value: "value1", Kind: yaml.ScalarNode, Tag: "!!str"}, + }, + }, + }, + }, + expectedFormat: yml.OutputFormatYAML, + expectedIndent: 1, + expectedIndentStyle: yml.IndentationStyleTab, + expectedKeyStyle: 0, + expectedValueStyle: 0, + }, + { + name: "YAML with double tab indentation", + data: []byte("key1: value1\nnested:\n\t\tkey2: value2"), + doc: &yaml.Node{ + Kind: yaml.DocumentNode, + Content: []*yaml.Node{ + { + Kind: yaml.MappingNode, + Tag: "!!map", + Content: []*yaml.Node{ + {Value: "key1", Kind: yaml.ScalarNode, Tag: "!!str"}, + {Value: "value1", Kind: yaml.ScalarNode, Tag: "!!str"}, + }, + }, + }, + }, + expectedFormat: yml.OutputFormatYAML, + expectedIndent: 2, + expectedIndentStyle: yml.IndentationStyleTab, + expectedKeyStyle: 0, + expectedValueStyle: 0, }, } @@ -240,6 +288,7 @@ func TestGetConfigFromDoc_Success(t *testing.T) { assert.Equal(t, tt.expectedFormat, result.OutputFormat) assert.Equal(t, tt.expectedFormat, result.OriginalFormat) assert.Equal(t, tt.expectedIndent, result.Indentation) + assert.Equal(t, tt.expectedIndentStyle, result.IndentationStyle) assert.Equal(t, tt.expectedKeyStyle, result.KeyStringStyle) assert.Equal(t, tt.expectedValueStyle, result.ValueStringStyle) }) @@ -294,6 +343,7 @@ func TestGetConfigFromDoc_WithComplexDocument_Success(t *testing.T) { assert.Equal(t, yml.OutputFormatYAML, result.OutputFormat) assert.Equal(t, yml.OutputFormatYAML, result.OriginalFormat) assert.Equal(t, 2, result.Indentation) + assert.Equal(t, yml.IndentationStyleSpace, result.IndentationStyle) assert.Equal(t, yaml.Style(0), result.KeyStringStyle) assert.Equal(t, yaml.Style(0), result.ValueStringStyle) // First string found doesn't have 
style } @@ -339,34 +389,39 @@ func TestGetConfigFromDoc_WithAliasNodes_Success(t *testing.T) { func TestGetConfigFromDoc_EdgeCases_Success(t *testing.T) { t.Parallel() tests := []struct { - name string - data []byte - expectedFormat yml.OutputFormat - expectedIndent int + name string + data []byte + expectedFormat yml.OutputFormat + expectedIndent int + expectedIndentStyle yml.IndentationStyle }{ { - name: "empty data", - data: []byte(""), - expectedFormat: yml.OutputFormatYAML, - expectedIndent: 2, + name: "empty data", + data: []byte(""), + expectedFormat: yml.OutputFormatYAML, + expectedIndent: 2, + expectedIndentStyle: yml.IndentationStyleSpace, }, { - name: "only comments", - data: []byte("# Just a comment\n# Another comment"), - expectedFormat: yml.OutputFormatYAML, - expectedIndent: 2, + name: "only comments", + data: []byte("# Just a comment\n# Another comment"), + expectedFormat: yml.OutputFormatYAML, + expectedIndent: 2, + expectedIndentStyle: yml.IndentationStyleSpace, }, { - name: "only whitespace", - data: []byte(" \n \n "), - expectedFormat: yml.OutputFormatYAML, - expectedIndent: 2, + name: "only whitespace", + data: []byte(" \n \n "), + expectedFormat: yml.OutputFormatYAML, + expectedIndent: 2, + expectedIndentStyle: yml.IndentationStyleSpace, }, { - name: "JSON with no indentation", - data: []byte(`{"key":"value"}`), - expectedFormat: yml.OutputFormatJSON, - expectedIndent: 2, + name: "JSON with no indentation", + data: []byte(`{"key":"value"}`), + expectedFormat: yml.OutputFormatJSON, + expectedIndent: 2, + expectedIndentStyle: yml.IndentationStyleSpace, }, } @@ -390,6 +445,7 @@ func TestGetConfigFromDoc_EdgeCases_Success(t *testing.T) { require.NotNil(t, result) assert.Equal(t, tt.expectedFormat, result.OutputFormat) assert.Equal(t, tt.expectedIndent, result.Indentation) + assert.Equal(t, tt.expectedIndentStyle, result.IndentationStyle) }) } } From 28d2d3b8cd538f893c34b3337dd2f76a3d6d1841 Mon Sep 17 00:00:00 2001 From: Tristan Cartledge Date: Fri, 17 Oct 2025 08:45:10 +1000 Subject: [PATCH 4/5] fix --- RELEASE_NOTES.md | 9 --------- 1 file changed, 9 deletions(-) delete mode 100644 RELEASE_NOTES.md diff --git a/RELEASE_NOTES.md b/RELEASE_NOTES.md deleted file mode 100644 index 4f9aa9e..0000000 --- a/RELEASE_NOTES.md +++ /dev/null @@ -1,9 +0,0 @@ -## Release Notes - -### Fixes -- Snip TUI select-all/deselect bug -- Clean reachability and tag cleanup -- Documentation updates - -### Related Pull Request -- #65 \ No newline at end of file From aa8b51161408f682d4dc6ecf530ca13dfd8e6d37 Mon Sep 17 00:00:00 2001 From: Tristan Cartledge Date: Fri, 17 Oct 2025 13:41:37 +1000 Subject: [PATCH 5/5] fix --- json/json.go | 2 +- json/json_test.go | 9 +++++++ yml/config.go | 6 ++++- yml/config_test.go | 58 ++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 73 insertions(+), 2 deletions(-) diff --git a/json/json.go b/json/json.go index 0f6c76e..220057d 100644 --- a/json/json.go +++ b/json/json.go @@ -35,7 +35,7 @@ func YAMLToJSONWithConfig(node *yaml.Node, indent string, indentCount int, trail indent: indentStr, buffer: &bytes.Buffer{}, currentCol: 0, - forceCompact: indentCount == 0, // Force compact mode when no indentation + forceCompact: len(indentStr) == 0, // Force compact mode when no indentation } // Write the JSON diff --git a/json/json_test.go b/json/json_test.go index ca6dbbf..f189e1b 100644 --- a/json/json_test.go +++ b/json/json_test.go @@ -383,6 +383,15 @@ middle: second`, indentCount: 1, expectedJSON: "{\n\t\"zebra\": \"last\",\n\t\"apple\": 
\"first\",\n\t\"middle\": \"second\"\n}\n", }, + { + name: "empty indent string with non-zero count (compact)", + yamlInput: `name: John +age: 30`, + indent: "", + indentCount: 5, + expectedJSON: `{"name":"John","age":30} +`, + }, } for _, tt := range tests { diff --git a/yml/config.go b/yml/config.go index 772411e..b587959 100644 --- a/yml/config.go +++ b/yml/config.go @@ -127,7 +127,11 @@ func inspectData(data []byte) (OutputFormat, int, IndentationStyle) { docFormat = OutputFormatJSON foundDocFormat = true default: - currentLeading := len(line) - len(trimLine) + // Calculate leading whitespace by counting from the start + currentLeading := 0 + for currentLeading < len(line) && (line[currentLeading] == ' ' || line[currentLeading] == '\t') { + currentLeading++ + } // Track minimum leading whitespace (baseline indentation) if minLeadingWhitespace == -1 || currentLeading < minLeadingWhitespace { diff --git a/yml/config_test.go b/yml/config_test.go index 207090a..18afc71 100644 --- a/yml/config_test.go +++ b/yml/config_test.go @@ -449,3 +449,61 @@ func TestGetConfigFromDoc_EdgeCases_Success(t *testing.T) { }) } } + +func TestGetConfigFromDoc_TrailingWhitespace_Success(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + data []byte + expectedIndent int + }{ + { + name: "YAML with trailing whitespace on lines", + data: []byte(`key: value +nested: + child: value `), + expectedIndent: 2, + }, + { + name: "YAML with trailing tabs", + data: []byte("key: value\t\t\nnested:\t\t\n child: value\t\t"), + expectedIndent: 2, + }, + { + name: "line with only leading brace and trailing whitespace", + data: []byte("{ "), + expectedIndent: 2, // default when no indentation found + }, + { + name: "mixed leading and trailing whitespace", + data: []byte(` key: value + nested: + child: value `), + expectedIndent: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + // Create a minimal document structure + doc := &yaml.Node{ + Kind: yaml.DocumentNode, + Content: []*yaml.Node{ + { + Kind: yaml.MappingNode, + Tag: "!!map", + Content: []*yaml.Node{}, + }, + }, + } + + result := yml.GetConfigFromDoc(tt.data, doc) + + require.NotNil(t, result) + assert.Equal(t, tt.expectedIndent, result.Indentation, "should correctly calculate indentation ignoring trailing whitespace") + }) + } +}