Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package yaml

import (
"context"
"encoding"
"encoding/base64"
"fmt"
Expand Down Expand Up @@ -312,6 +313,7 @@ func (p *parser) mapping() *Node {
// Decoder, unmarshals a node into a provided value.

type decoder struct {
ctx context.Context
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Having a context stored in a struct is an anti pattern.

For me, there is a problem here.

But it might be needed if we don't want a PR that breaks everything and methods signature.

doc *Node
aliases map[*Node]bool
terrors []*UnmarshalError
Expand All @@ -336,8 +338,9 @@ var (
ifaceType = generalMapType.Elem()
)

func newDecoder() *decoder {
func newDecoder(ctx context.Context) *decoder {
d := &decoder{
ctx: ctx,
stringMapType: stringMapType,
generalMapType: generalMapType,
uniqueKeys: true,
Expand Down Expand Up @@ -383,6 +386,24 @@ func (d *decoder) callUnmarshaler(n *Node, u Unmarshaler) (good bool) {
}
}

func (d *decoder) callUnmarshalerWithContext(n *Node, u UnmarshalerWithContext) (good bool) {
err := u.UnmarshalYAML(d.ctx, n)
switch e := err.(type) {
case nil:
return true
case *TypeError:
d.terrors = append(d.terrors, e.Errors...)
return false
default:
d.terrors = append(d.terrors, &UnmarshalError{
Err: err,
Line: n.Line,
Column: n.Column,
})
return false
}
}

func (d *decoder) callObsoleteUnmarshaler(n *Node, u obsoleteUnmarshaler) (good bool) {
terrlen := len(d.terrors)
err := u.UnmarshalYAML(func(v any) (err error) {
Expand Down Expand Up @@ -434,6 +455,10 @@ func (d *decoder) prepare(n *Node, out reflect.Value) (newout reflect.Value, unm
}
if out.CanAddr() {
outi := out.Addr().Interface()
if u, ok := outi.(UnmarshalerWithContext); ok {
good = d.callUnmarshalerWithContext(n, u)
return out, true, good
}
if u, ok := outi.(Unmarshaler); ok {
good = d.callUnmarshaler(n, u)
return out, true, good
Expand Down
24 changes: 24 additions & 0 deletions decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package yaml_test

import (
"bytes"
"context"
"encoding"
"errors"
"fmt"
Expand Down Expand Up @@ -2123,6 +2124,29 @@ a:
}
}

type typeWithContext struct {
value map[string]any
context context.Context
}

func (t *typeWithContext) UnmarshalYAML(ctx context.Context, value *yaml.Node) error {
t.context = ctx
t.value = make(map[string]any)
value.Decode(&t.value)
return nil
}

func TestDecodeWithContext(t *testing.T) {
type contextKey struct{}
ctx := context.WithValue(context.Background(), contextKey{}, "value")
var v typeWithContext
err := yaml.UnmarshalWithContext(ctx, []byte("foo: bar"), &v)

assert.NoError(t, err)
assert.Equal(t, v.context.Value(contextKey{}), "value")
assert.DeepEqual(t, v.value, map[string]any{"foo": "bar"})
}

//var data []byte
//func init() {
// var err error
Expand Down
16 changes: 15 additions & 1 deletion encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package yaml

import (
"context"
"encoding"
"fmt"
"io"
Expand All @@ -31,6 +32,7 @@ import (
)

type encoder struct {
context context.Context
emitter libyaml.Emitter
event libyaml.Event
out []byte
Expand All @@ -39,9 +41,10 @@ type encoder struct {
doneInit bool
}

func newEncoder() *encoder {
func newEncoder(ctx context.Context) *encoder {
e := &encoder{
emitter: libyaml.NewEmitter(),
context: ctx,
}
e.emitter.SetOutputString(&e.out)
e.emitter.SetUnicode(true)
Expand Down Expand Up @@ -140,6 +143,17 @@ func (e *encoder) marshal(tag string, in reflect.Value) {
case time.Duration:
e.stringv(tag, reflect.ValueOf(value.String()))
return
case MarshalerWithContext:
v, err := value.MarshalYAML(e.context)
if err != nil {
fail(err)
}
if v == nil {
e.nilv()
return
}
e.marshal(tag, reflect.ValueOf(v))
return
case Marshaler:
v, err := value.MarshalYAML()
if err != nil {
Expand Down
9 changes: 9 additions & 0 deletions encode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package yaml_test

import (
"bytes"
"context"
"fmt"
"math"
"net"
Expand Down Expand Up @@ -1407,3 +1408,11 @@ func TestUnicodeWhitespaceHandling(t *testing.T) {
})
}
}

func TestMarshalWithContext(t *testing.T) {
type contextKey struct{}
ctx := context.WithValue(context.Background(), contextKey{}, "value")
data, err := yaml.MarshalWithContext(ctx, map[string]string{"foo": "bar"})
assert.NoError(t, err)
assert.Equal(t, "foo: bar\n", string(data))
}
59 changes: 52 additions & 7 deletions yaml.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
package yaml

import (
"context"
"errors"
"fmt"
"io"
Expand All @@ -40,6 +41,11 @@ type Unmarshaler interface {
UnmarshalYAML(value *Node) error
}

// UnmarshalerWithContext is the same as Unmarshaler, but it also receives a context.
type UnmarshalerWithContext interface {
UnmarshalYAML(ctx context.Context, value *Node) error
}

type obsoleteUnmarshaler interface {
UnmarshalYAML(unmarshal func(any) error) error
}
Expand All @@ -54,6 +60,11 @@ type Marshaler interface {
MarshalYAML() (any, error)
}

// MarshalerWithContext is the same as Marshaler, but it also receives a context.
type MarshalerWithContext interface {
MarshalYAML(ctx context.Context) (any, error)
}

// Unmarshal decodes the first document found within the in byte slice
// and assigns decoded values into the out value.
//
Expand Down Expand Up @@ -88,7 +99,13 @@ type Marshaler interface {
// See the documentation of Marshal for the format of tags and a list of
// supported tag options.
func Unmarshal(in []byte, out any) (err error) {
return unmarshal(in, out, false)
return unmarshal(context.Background(), in, out, false)
}

// UnmarshalWithContext is the same as Unmarshal, but also provides
// a context to the UnmarshalerWithContext implementations.
func UnmarshalWithContext(ctx context.Context, in []byte, out any) (err error) {
return unmarshal(ctx, in, out, false)
}

// A Decoder reads and decodes YAML values from an input stream.
Expand Down Expand Up @@ -119,7 +136,16 @@ func (dec *Decoder) KnownFields(enable bool) {
// See the documentation for Unmarshal for details about the
// conversion of YAML into a Go value.
func (dec *Decoder) Decode(v any) (err error) {
d := newDecoder()
return dec.DecodeWithContext(context.Background(), v)
}

// DecodeWithContext reads the next YAML-encoded value from its input
// and stores it in the value pointed to by v.
//
// See the documentation for Unmarshal for details about the
// conversion of YAML into a Go value.
func (dec *Decoder) DecodeWithContext(ctx context.Context, v any) (err error) {
d := newDecoder(ctx)
d.knownFields = dec.knownFields
defer handleErr(&err)
node := dec.parser.parse()
Expand All @@ -142,7 +168,11 @@ func (dec *Decoder) Decode(v any) (err error) {
// See the documentation for Unmarshal for details about the
// conversion of YAML into a Go value.
func (n *Node) Decode(v any) (err error) {
d := newDecoder()
return n.DecodeWithContext(context.Background(), v)
}

func (n *Node) DecodeWithContext(ctx context.Context, v any) (err error) {
d := newDecoder(ctx)
defer handleErr(&err)
out := reflect.ValueOf(v)
if out.Kind() == reflect.Pointer && !out.IsNil() {
Expand All @@ -155,9 +185,9 @@ func (n *Node) Decode(v any) (err error) {
return nil
}

func unmarshal(in []byte, out any, strict bool) (err error) {
func unmarshal(ctx context.Context, in []byte, out any, strict bool) (err error) {
defer handleErr(&err)
d := newDecoder()
d := newDecoder(ctx)
p := newParser(in)
defer p.destroy()
node := p.parse()
Expand Down Expand Up @@ -217,8 +247,14 @@ func unmarshal(in []byte, out any, strict bool) (err error) {
// yaml.Marshal(&T{B: 2}) // Returns "b: 2\n"
// yaml.Marshal(&T{F: 1}} // Returns "a: 1\nb: 0\n"
func Marshal(in any) (out []byte, err error) {
return MarshalWithContext(context.Background(), in)
}

// MarshalWithContext is the same as Marshal, but also provides
// a context to the MarshalerWithContext implementations.
func MarshalWithContext(ctx context.Context, in any) (out []byte, err error) {
defer handleErr(&err)
e := newEncoder()
e := newEncoder(ctx)
defer e.destroy()
e.marshalDoc("", reflect.ValueOf(in))
e.finish()
Expand Down Expand Up @@ -258,8 +294,17 @@ func (e *Encoder) Encode(v any) (err error) {
// See the documentation for Marshal for details about the
// conversion of Go values into YAML.
func (n *Node) Encode(v any) (err error) {
return n.EncodeWithContext(context.Background(), v)
}

// EncodeWithContext is the same as Encode, but also provides
// a context to the MarshalerWithContext implementations.
//
// See the documentation for Marshal for details about the
// conversion of Go values into YAML.
func (n *Node) EncodeWithContext(ctx context.Context, v any) (err error) {
defer handleErr(&err)
e := newEncoder()
e := newEncoder(ctx)
defer e.destroy()
e.marshalDoc("", reflect.ValueOf(v))
e.finish()
Expand Down