Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .generated-golangci-depguard.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ linters:
- $all
- '!$test'
- '!**/pkg/testutil/**/*.go'
- '!**/pkg/**/test/*.go'
deny:
- desc: Production code should not depend on test utilities.
pkg: github.com/GoogleCloudPlatform/khi/pkg/testutil
25 changes: 24 additions & 1 deletion cmd/kubernetes-history-inspector/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@ import (
"github.com/GoogleCloudPlatform/khi/pkg/model/k8s"
"github.com/GoogleCloudPlatform/khi/pkg/parameters"
"github.com/GoogleCloudPlatform/khi/pkg/server"
"github.com/GoogleCloudPlatform/khi/pkg/server/option"
"github.com/GoogleCloudPlatform/khi/pkg/server/upload"
"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"

inspectioncore_contract "github.com/GoogleCloudPlatform/khi/pkg/task/inspection/inspectioncore/contract"

Expand Down Expand Up @@ -152,6 +155,21 @@ func run() int {

slog.Info("Starting Kubernetes History Inspector server...")

// Setting up options or parameters needed to instanciate gin.Engine
serverMode := gin.ReleaseMode
server.DefaultServerFactory.AddOptions(option.Required())

corsConfig := cors.DefaultConfig()
corsConfig.AllowAllOrigins = true
server.DefaultServerFactory.AddOptions(option.CORS(corsConfig))

if *parameters.Debug.Verbose {
server.DefaultServerFactory.AddOptions(
option.AccessLog("/api/v3/inspection", "/api/v3/popup"), // ignoreing noisy paths
)
serverMode = gin.DebugMode
}

uploadFileStoreFolder := "/tmp"

if parameters.Common.UploadFileStoreFolder != nil {
Expand All @@ -167,7 +185,12 @@ func run() int {
ServerBasePath: *parameters.Server.BasePath,
UploadFileStore: upload.DefaultUploadFileStore,
}
engine := server.CreateKHIServer(inspectionServer, &config)
engine, err := server.DefaultServerFactory.CreateInstance(serverMode)
if err != nil {
slog.Error(fmt.Sprintf("Failed to create a server instance\n%v", err))
return 1
}
engine = server.CreateKHIServer(engine, inspectionServer, &config)

if parameters.Auth.OAuthEnabled() {
err := accesstoken.DefaultOAuthTokenResolver.SetServer(engine)
Expand Down
10 changes: 5 additions & 5 deletions pkg/core/inspection/taskbase/cached_task.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ import (
inspectioncore_contract "github.com/GoogleCloudPlatform/khi/pkg/task/inspection/inspectioncore/contract"
)

// PreviousTaskResult is the combination of the cached value and a digest of its dependency.
type PreviousTaskResult[T any] struct {
// CacheableTaskResult is the combination of the cached value and a digest of its dependency.
type CacheableTaskResult[T any] struct {
// Value is the value used previous run.
Value T
// DependencyDigest is a string representation of digest of its inputs.
Expand All @@ -35,11 +35,11 @@ type PreviousTaskResult[T any] struct {
}

// NewCachedTask generates a task which can reuse the value last time.
func NewCachedTask[T any](taskID taskid.TaskImplementationID[T], depdendencies []taskid.UntypedTaskReference, f func(ctx context.Context, prevValue PreviousTaskResult[T]) (PreviousTaskResult[T], error), labelOpt ...coretask.LabelOpt) coretask.Task[T] {
func NewCachedTask[T any](taskID taskid.TaskImplementationID[T], depdendencies []taskid.UntypedTaskReference, f func(ctx context.Context, prevValue CacheableTaskResult[T]) (CacheableTaskResult[T], error), labelOpt ...coretask.LabelOpt) coretask.Task[T] {
return coretask.NewTask(taskID, depdendencies, func(ctx context.Context) (T, error) {
inspectionSharedMap := khictx.MustGetValue(ctx, inspectioncore_contract.GlobalSharedMap)
cacheKey := typedmap.NewTypedKey[PreviousTaskResult[T]](fmt.Sprintf("cached_result-%s", taskID.String()))
cachedResult := typedmap.GetOrDefault(inspectionSharedMap, cacheKey, PreviousTaskResult[T]{
cacheKey := typedmap.NewTypedKey[CacheableTaskResult[T]](fmt.Sprintf("cached_result-%s", taskID.String()))
cachedResult := typedmap.GetOrDefault(inspectionSharedMap, cacheKey, CacheableTaskResult[T]{
Value: *new(T),
DependencyDigest: "",
})
Expand Down
8 changes: 4 additions & 4 deletions pkg/core/inspection/taskbase/cached_task_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ import (
)

func TestCachedTask(t *testing.T) {
prevValues := []PreviousTaskResult[string]{}
prevValues := []CacheableTaskResult[string]{}
testTaskID := taskid.NewDefaultImplementationID[string]("foo")
task := NewCachedTask(testTaskID, []taskid.UntypedTaskReference{}, func(ctx context.Context, prevValue PreviousTaskResult[string]) (PreviousTaskResult[string], error) {
task := NewCachedTask(testTaskID, []taskid.UntypedTaskReference{}, func(ctx context.Context, prevValue CacheableTaskResult[string]) (CacheableTaskResult[string], error) {
prevValues = append(prevValues, prevValue)
return PreviousTaskResult[string]{
return CacheableTaskResult[string]{
Value: "foo",
DependencyDigest: "foo",
}, nil
Expand All @@ -45,7 +45,7 @@ func TestCachedTask(t *testing.T) {
t.Errorf("unexpected task error result %v", err)
}

if diff := cmp.Diff(prevValues, []PreviousTaskResult[string]{
if diff := cmp.Diff(prevValues, []CacheableTaskResult[string]{
{
Value: "",
DependencyDigest: "",
Expand Down
127 changes: 127 additions & 0 deletions pkg/server/option/option.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package option

import (
"slices"

"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"
)

// Option defines an interface for configuring a Gin engine.
type Option interface {
// ID returns a unique identifier for the option.
ID() string
// Order returns the order in which this option should be applied relative to other options.
Order() int
// Apply applies the option's configuration to the given Gin engine.
// It returns an error if the application fails.
Apply(engine *gin.Engine) error
}

func ApplyOptions(engine *gin.Engine, options []Option) error {
slices.SortFunc(options, func(a, b Option) int { return a.Order() - b.Order() })
for _, option := range options {
err := option.Apply(engine)
if err != nil {
return err
}
}
return nil
}

// requiredOption is an Option implementation for setting several required middleware in KHI.
type requiredOption struct {
}

// Required creates a new Option to set several required middlewares and gin server mode.
func Required() Option {
return &requiredOption{}
}

func (s *requiredOption) ID() string {
return "required"
}

// Order returns the application order for the server mode option.
func (s *requiredOption) Order() int {
return 0
}

// Apply adds required middlewares (currently the recovery is the only middleware.)
func (s *requiredOption) Apply(engine *gin.Engine) error {
engine.Use(gin.Recovery())
return nil
}

var _ Option = (*requiredOption)(nil)

// corsOption is an Option implementation for enabling CORS.
type corsOption struct {
corsConfig cors.Config
}

// CORS creates a new Option to enable CORS.
func CORS(config cors.Config) Option {
return &corsOption{config}
}

func (c *corsOption) ID() string {
return "cors"
}

// Order returns the application order for the CORS option.
func (c *corsOption) Order() int {
return 1
}

// Apply configures the Gin engine to use the gin-contrib/cors middleware with all origins allowed.
func (c *corsOption) Apply(engine *gin.Engine) error {
engine.Use(cors.New(c.corsConfig))
return nil
}

var _ Option = (*corsOption)(nil)

type accessLogOption struct {
ignoredPath []string
}

// AccessLog creates a new Option to log access logs with ignoreing the provided paths.
func AccessLog(ignoredPath ...string) Option {
return &accessLogOption{
ignoredPath: ignoredPath,
}
}

// Apply implements Option.
func (l *accessLogOption) Apply(engine *gin.Engine) error {
engine.Use(gin.LoggerWithConfig(gin.LoggerConfig{
SkipPaths: l.ignoredPath,
}))
return nil
}

// Order implements Option.
func (l *accessLogOption) Order() int {
return 2
}

func (l *accessLogOption) ID() string {
return "access-log"
}

var _ Option = (*accessLogOption)(nil)
163 changes: 163 additions & 0 deletions pkg/server/option/option_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package option

import (
"errors"
"fmt"
"net/http/httptest"
"testing"

"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"
)

// mockOption is a helper for testing.
type mockOption struct {
id string
order int
apply func(e *gin.Engine) error
}

func (m *mockOption) ID() string {
return m.id
}

func (m *mockOption) Order() int {
return m.order
}

func (m *mockOption) Apply(engine *gin.Engine) error {
if m.apply != nil {
return m.apply(engine)
}
return nil
}

func TestApplyOptions(t *testing.T) {
gin.SetMode(gin.TestMode)

testCases := []struct {
name string
options []Option
expectOrder []string
wantErr bool
}{
{
name: "should apply in correct order",
options: []Option{
&mockOption{id: "option-2", order: 2},
&mockOption{id: "option-1", order: 1},
&mockOption{id: "option-3", order: 3},
},
expectOrder: []string{"option-1", "option-2", "option-3"},
wantErr: false,
},
{
name: "should handle empty options",
options: []Option{},
expectOrder: []string{},
wantErr: false,
},
{
name: "should return error on apply failure",
options: []Option{
&mockOption{id: "good-option", order: 1},
&mockOption{id: "bad-option", order: 2, apply: func(e *gin.Engine) error {
return errors.New("apply failed")
}},
},
expectOrder: []string{"good-option", "bad-option"}, // bad-option will not be used but it must be called once in the order.
wantErr: true,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
engine := gin.New()
var appliedOrder []string

// Wrap mock options to record apply calls
recordingOptions := make([]Option, len(tc.options))
for i, opt := range tc.options {
mock, ok := opt.(*mockOption)
if !ok {
t.Fatalf("test setup error: expected mockOption")
}
// copy mock to avoid closure issues
originalApply := mock.apply
recordingOptions[i] = &mockOption{
id: mock.id,
order: mock.order,
apply: func(e *gin.Engine) error {
appliedOrder = append(appliedOrder, mock.id)
if originalApply != nil {
return originalApply(e)
}
return nil
},
}
}

err := ApplyOptions(engine, recordingOptions)

if (err != nil) != tc.wantErr {
t.Fatalf("ApplyOptions() error = %v, wantErr %v", err, tc.wantErr)
}

if fmt.Sprint(appliedOrder) != fmt.Sprint(tc.expectOrder) {
t.Errorf("ApplyOptions() applied in wrong order. got=%v, want=%v", appliedOrder, tc.expectOrder)
}
})
}
}

func TestCorsOption(t *testing.T) {
gin.SetMode(gin.TestMode)

config := cors.Config{
AllowOrigins: []string{"http://localhost:4200"},
}

opt := CORS(config)
engine := gin.New()

if err := opt.Apply(engine); err != nil {
t.Fatalf("Apply() failed: %v", err)
}

// Check ID and Order
if opt.ID() != "cors" {
t.Errorf("ID() got = %q, want = \"cors\"", opt.ID())
}
if opt.Order() != 1 {
t.Errorf("Order() got = %d, want = 1", opt.Order())
}

// Check if CORS header is present by making a request
engine.GET("/test", func(c *gin.Context) {
c.String(200, "ok")
})

req := httptest.NewRequest("GET", "/test", nil)
req.Header.Set("Origin", "http://localhost:4200")
w := httptest.NewRecorder()
engine.ServeHTTP(w, req)

gotHeader := w.Header().Get("Access-Control-Allow-Origin")
if gotHeader != "http://localhost:4200" {
t.Errorf("Access-Control-Allow-Origin header not set correctly. got=%q, want=%q", gotHeader, "http://localhost:4200")
}
}
Loading
Loading