diff --git a/decode.go b/decode.go index b5423d1c..b296eade 100644 --- a/decode.go +++ b/decode.go @@ -16,6 +16,7 @@ package yaml import ( + "context" "encoding" "encoding/base64" "fmt" @@ -312,6 +313,7 @@ func (p *parser) mapping() *Node { // Decoder, unmarshals a node into a provided value. type decoder struct { + ctx context.Context doc *Node aliases map[*Node]bool terrors []*UnmarshalError @@ -336,8 +338,9 @@ var ( ifaceType = generalMapType.Elem() ) -func newDecoder() *decoder { +func newDecoder(ctx context.Context) *decoder { d := &decoder{ + ctx: ctx, stringMapType: stringMapType, generalMapType: generalMapType, uniqueKeys: true, @@ -383,6 +386,24 @@ func (d *decoder) callUnmarshaler(n *Node, u Unmarshaler) (good bool) { } } +func (d *decoder) callUnmarshalerWithContext(n *Node, u UnmarshalerWithContext) (good bool) { + err := u.UnmarshalYAML(d.ctx, n) + switch e := err.(type) { + case nil: + return true + case *TypeError: + d.terrors = append(d.terrors, e.Errors...) + return false + default: + d.terrors = append(d.terrors, &UnmarshalError{ + Err: err, + Line: n.Line, + Column: n.Column, + }) + return false + } +} + func (d *decoder) callObsoleteUnmarshaler(n *Node, u obsoleteUnmarshaler) (good bool) { terrlen := len(d.terrors) err := u.UnmarshalYAML(func(v any) (err error) { @@ -434,6 +455,10 @@ func (d *decoder) prepare(n *Node, out reflect.Value) (newout reflect.Value, unm } if out.CanAddr() { outi := out.Addr().Interface() + if u, ok := outi.(UnmarshalerWithContext); ok { + good = d.callUnmarshalerWithContext(n, u) + return out, true, good + } if u, ok := outi.(Unmarshaler); ok { good = d.callUnmarshaler(n, u) return out, true, good diff --git a/decode_test.go b/decode_test.go index 1b29a8c3..4dfa1e46 100644 --- a/decode_test.go +++ b/decode_test.go @@ -17,6 +17,7 @@ package yaml_test import ( "bytes" + "context" "encoding" "errors" "fmt" @@ -2123,6 +2124,29 @@ a: } } +type typeWithContext struct { + value map[string]any + context context.Context +} + +func (t *typeWithContext) UnmarshalYAML(ctx context.Context, value *yaml.Node) error { + t.context = ctx + t.value = make(map[string]any) + value.Decode(&t.value) + return nil +} + +func TestDecodeWithContext(t *testing.T) { + type contextKey struct{} + ctx := context.WithValue(context.Background(), contextKey{}, "value") + var v typeWithContext + err := yaml.UnmarshalWithContext(ctx, []byte("foo: bar"), &v) + + assert.NoError(t, err) + assert.Equal(t, v.context.Value(contextKey{}), "value") + assert.DeepEqual(t, v.value, map[string]any{"foo": "bar"}) +} + //var data []byte //func init() { // var err error diff --git a/encode.go b/encode.go index ce66dee9..6e993ca2 100644 --- a/encode.go +++ b/encode.go @@ -16,6 +16,7 @@ package yaml import ( + "context" "encoding" "fmt" "io" @@ -31,6 +32,7 @@ import ( ) type encoder struct { + context context.Context emitter libyaml.Emitter event libyaml.Event out []byte @@ -39,9 +41,10 @@ type encoder struct { doneInit bool } -func newEncoder() *encoder { +func newEncoder(ctx context.Context) *encoder { e := &encoder{ emitter: libyaml.NewEmitter(), + context: ctx, } e.emitter.SetOutputString(&e.out) e.emitter.SetUnicode(true) @@ -140,6 +143,17 @@ func (e *encoder) marshal(tag string, in reflect.Value) { case time.Duration: e.stringv(tag, reflect.ValueOf(value.String())) return + case MarshalerWithContext: + v, err := value.MarshalYAML(e.context) + if err != nil { + fail(err) + } + if v == nil { + e.nilv() + return + } + e.marshal(tag, reflect.ValueOf(v)) + return case Marshaler: v, err := value.MarshalYAML() if err != nil { diff --git a/encode_test.go b/encode_test.go index 168a1cba..920b4359 100644 --- a/encode_test.go +++ b/encode_test.go @@ -17,6 +17,7 @@ package yaml_test import ( "bytes" + "context" "fmt" "math" "net" @@ -1407,3 +1408,11 @@ func TestUnicodeWhitespaceHandling(t *testing.T) { }) } } + +func TestMarshalWithContext(t *testing.T) { + type contextKey struct{} + ctx := context.WithValue(context.Background(), contextKey{}, "value") + data, err := yaml.MarshalWithContext(ctx, map[string]string{"foo": "bar"}) + assert.NoError(t, err) + assert.Equal(t, "foo: bar\n", string(data)) +} diff --git a/yaml.go b/yaml.go index a218035c..f83153fe 100644 --- a/yaml.go +++ b/yaml.go @@ -21,6 +21,7 @@ package yaml import ( + "context" "errors" "fmt" "io" @@ -40,6 +41,11 @@ type Unmarshaler interface { UnmarshalYAML(value *Node) error } +// UnmarshalerWithContext is the same as Unmarshaler, but it also receives a context. +type UnmarshalerWithContext interface { + UnmarshalYAML(ctx context.Context, value *Node) error +} + type obsoleteUnmarshaler interface { UnmarshalYAML(unmarshal func(any) error) error } @@ -54,6 +60,11 @@ type Marshaler interface { MarshalYAML() (any, error) } +// MarshalerWithContext is the same as Marshaler, but it also receives a context. +type MarshalerWithContext interface { + MarshalYAML(ctx context.Context) (any, error) +} + // Unmarshal decodes the first document found within the in byte slice // and assigns decoded values into the out value. // @@ -88,7 +99,13 @@ type Marshaler interface { // See the documentation of Marshal for the format of tags and a list of // supported tag options. func Unmarshal(in []byte, out any) (err error) { - return unmarshal(in, out, false) + return unmarshal(context.Background(), in, out, false) +} + +// UnmarshalWithContext is the same as Unmarshal, but also provides +// a context to the UnmarshalerWithContext implementations. +func UnmarshalWithContext(ctx context.Context, in []byte, out any) (err error) { + return unmarshal(ctx, in, out, false) } // A Decoder reads and decodes YAML values from an input stream. @@ -119,7 +136,16 @@ func (dec *Decoder) KnownFields(enable bool) { // See the documentation for Unmarshal for details about the // conversion of YAML into a Go value. func (dec *Decoder) Decode(v any) (err error) { - d := newDecoder() + return dec.DecodeWithContext(context.Background(), v) +} + +// DecodeWithContext reads the next YAML-encoded value from its input +// and stores it in the value pointed to by v. +// +// See the documentation for Unmarshal for details about the +// conversion of YAML into a Go value. +func (dec *Decoder) DecodeWithContext(ctx context.Context, v any) (err error) { + d := newDecoder(ctx) d.knownFields = dec.knownFields defer handleErr(&err) node := dec.parser.parse() @@ -142,7 +168,11 @@ func (dec *Decoder) Decode(v any) (err error) { // See the documentation for Unmarshal for details about the // conversion of YAML into a Go value. func (n *Node) Decode(v any) (err error) { - d := newDecoder() + return n.DecodeWithContext(context.Background(), v) +} + +func (n *Node) DecodeWithContext(ctx context.Context, v any) (err error) { + d := newDecoder(ctx) defer handleErr(&err) out := reflect.ValueOf(v) if out.Kind() == reflect.Pointer && !out.IsNil() { @@ -155,9 +185,9 @@ func (n *Node) Decode(v any) (err error) { return nil } -func unmarshal(in []byte, out any, strict bool) (err error) { +func unmarshal(ctx context.Context, in []byte, out any, strict bool) (err error) { defer handleErr(&err) - d := newDecoder() + d := newDecoder(ctx) p := newParser(in) defer p.destroy() node := p.parse() @@ -217,8 +247,14 @@ func unmarshal(in []byte, out any, strict bool) (err error) { // yaml.Marshal(&T{B: 2}) // Returns "b: 2\n" // yaml.Marshal(&T{F: 1}} // Returns "a: 1\nb: 0\n" func Marshal(in any) (out []byte, err error) { + return MarshalWithContext(context.Background(), in) +} + +// MarshalWithContext is the same as Marshal, but also provides +// a context to the MarshalerWithContext implementations. +func MarshalWithContext(ctx context.Context, in any) (out []byte, err error) { defer handleErr(&err) - e := newEncoder() + e := newEncoder(ctx) defer e.destroy() e.marshalDoc("", reflect.ValueOf(in)) e.finish() @@ -258,8 +294,17 @@ func (e *Encoder) Encode(v any) (err error) { // See the documentation for Marshal for details about the // conversion of Go values into YAML. func (n *Node) Encode(v any) (err error) { + return n.EncodeWithContext(context.Background(), v) +} + +// EncodeWithContext is the same as Encode, but also provides +// a context to the MarshalerWithContext implementations. +// +// See the documentation for Marshal for details about the +// conversion of Go values into YAML. +func (n *Node) EncodeWithContext(ctx context.Context, v any) (err error) { defer handleErr(&err) - e := newEncoder() + e := newEncoder(ctx) defer e.destroy() e.marshalDoc("", reflect.ValueOf(v)) e.finish()