Skip to content
130 changes: 124 additions & 6 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ package yaml
import (
"encoding"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"math"
Expand Down Expand Up @@ -319,11 +321,12 @@ type decoder struct {
stringMapType reflect.Type
generalMapType reflect.Type

knownFields bool
uniqueKeys bool
decodeCount int
aliasCount int
aliasDepth int
knownFields bool
uniqueKeys bool
decodeCount int
aliasCount int
aliasDepth int
fallbackToJSON bool

mergedFields map[any]bool
}
Expand Down Expand Up @@ -383,6 +386,114 @@ func (d *decoder) callUnmarshaler(n *Node, u Unmarshaler) (good bool) {
}
}

// callJSONUnmarshaler feeds the YAML node n to a json.Unmarshaler.
//
// The node is first decoded into a generic value, which is then handed to
// unmarshalJSON together with u. Any failure is recorded in d.terrors and
// reported by returning false; true means u accepted the data.
func (d *decoder) callJSONUnmarshaler(n *Node, u json.Unmarshaler) bool {
	// Step 1: decode the node into an untyped intermediate value.
	var intermediate any
	if decodeErr := n.Decode(&intermediate); decodeErr != nil {
		var typeErr *TypeError
		if !errors.As(decodeErr, &typeErr) {
			// Not a TypeError: record it once, positioned at this node.
			d.terrors = append(d.terrors, &UnmarshalError{
				Err:    decodeErr,
				Line:   n.Line,
				Column: n.Column,
			})
			return false
		}
		// A TypeError already carries individual errors; keep them as they are.
		d.terrors = append(d.terrors, typeErr.Errors...)
		return false
	}

	// Step 2: hand the intermediate value to the JSON unmarshaler.
	jsonErr := unmarshalJSON(intermediate, u)
	if jsonErr == nil {
		return true
	}
	d.terrors = append(d.terrors, &UnmarshalError{
		Err:    jsonErr,
		Line:   n.Line,
		Column: n.Column,
	})
	return false
}

// unmarshalJSON passes input to unmarshaler as JSON bytes.
//
// The value is first run through normalizeJSON (JSON objects require string
// keys), then encoded with json.Marshal, and the resulting bytes are given to
// unmarshaler.UnmarshalJSON. The first error from any of these steps is
// returned unchanged.
//
// NOTE: the round trip (value -> JSON -> target type) costs extra work, but
// UnmarshalJSON only accepts JSON bytes, so a type implementing
// json.Unmarshaler cannot be fed any more directly.
func unmarshalJSON(input any, unmarshaler json.Unmarshaler) error {
	normalized, normErr := normalizeJSON(input)
	if normErr != nil {
		return normErr
	}

	encoded, encErr := json.Marshal(normalized)
	if encErr != nil {
		return encErr
	}

	return unmarshaler.UnmarshalJSON(encoded)
}

// normalizeJSON converts a YAML-parsed structure into a form suitable for JSON marshaling.
//
// The JSON specification requires that object keys be strings, so every map
// (whatever its key type) is rebuilt as a map[string]any, and every slice or
// array is rebuilt as a []any with its elements normalized recursively. Any
// other value is returned unchanged.
//
// Non-string keys are rendered through their JSON encoding: 1 becomes "1" and
// true becomes "true". Keys whose JSON encoding is itself a JSON string (for
// example named string types, or types implementing encoding.TextMarshaler
// such as time.Time) are unquoted, so the object key is the text itself rather
// than the text wrapped in an extra pair of quotes.
//
// An error is returned when a key cannot be encoded, or when two distinct keys
// collapse to the same string (for example 1 and "1").
func normalizeJSON(input any) (any, error) {
	x := reflect.ValueOf(input)
	switch x.Kind() {
	case reflect.Map:
		m := make(map[string]any, x.Len())

		iter := x.MapRange()
		for iter.Next() {
			// Normalize the value first, then the key, so nested errors
			// surface in the same order as before.
			jv, err := normalizeJSON(iter.Value().Interface())
			if err != nil {
				return nil, err
			}
			key, err := jsonObjectKey(iter.Key().Interface())
			if err != nil {
				return nil, err
			}
			if _, dup := m[key]; dup {
				return nil, fmt.Errorf("duplicate key %q found when converting to JSON object", key)
			}
			m[key] = jv
		}
		return m, nil

	case reflect.Slice, reflect.Array:
		a := make([]any, x.Len())
		for i := range a {
			jv, err := normalizeJSON(x.Index(i).Interface())
			if err != nil {
				return nil, err
			}
			a[i] = jv
		}
		return a, nil

	default:
		// Scalars, nil and everything else are left for json.Marshal to handle.
		return input, nil
	}
}

// jsonObjectKey renders a single map key as the string used for the JSON object key.
func jsonObjectKey(k any) (string, error) {
	if s, ok := k.(string); ok {
		return s, nil
	}

	// Convert non-string keys to JSON and then to strings.
	jk, err := normalizeJSON(k)
	if err != nil {
		return "", err
	}
	raw, err := json.Marshal(jk)
	if err != nil {
		return "", err
	}

	// A key that encodes as a JSON string would otherwise keep its quotes and
	// escape sequences inside the object key; decode it back to plain text.
	if len(raw) > 0 && raw[0] == '"' {
		var s string
		if err := json.Unmarshal(raw, &s); err != nil {
			return "", err
		}
		return s, nil
	}
	return string(raw), nil
}

func (d *decoder) callObsoleteUnmarshaler(n *Node, u obsoleteUnmarshaler) (good bool) {
terrlen := len(d.terrors)
err := u.UnmarshalYAML(func(v any) (err error) {
Expand Down Expand Up @@ -442,6 +553,13 @@ func (d *decoder) prepare(n *Node, out reflect.Value) (newout reflect.Value, unm
good = d.callObsoleteUnmarshaler(n, u)
return out, true, good
}
if d.fallbackToJSON {
// Try JSON unmarshaler as a fallback.
if u, ok := outi.(json.Unmarshaler); ok {
good = d.callJSONUnmarshaler(n, u)
return out, true, good
}
}
}
}
return out, false, false
Expand Down Expand Up @@ -892,7 +1010,7 @@ func isStringMap(n *Node) bool {
}

func (d *decoder) mappingStruct(n *Node, out reflect.Value) (good bool) {
sinfo, err := getStructInfo(out.Type())
sinfo, err := getStructInfo(out.Type(), d.fallbackToJSON)
if err != nil {
panic(err)
}
Expand Down
95 changes: 95 additions & 0 deletions decode_internal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
package yaml

import (
"encoding/json"
"testing"

"go.yaml.in/yaml/v4/internal/testutil/assert"
)

// TestUnmarshaler is a fixture type that implements json.Unmarshaler and
// carries a scalar, a slice and a map field.
type TestUnmarshaler struct {
	Value string
	Array []int
	Map   map[string]int
}

// UnmarshalJSON decodes data into t using the default struct decoding of
// encoding/json. On failure the error is returned and t is left untouched.
func (t *TestUnmarshaler) UnmarshalJSON(data []byte) error {
	// plain has the same fields as TestUnmarshaler but none of its methods,
	// so json.Unmarshal does not call back into this method.
	type plain TestUnmarshaler

	decoded := plain{}
	err := json.Unmarshal(data, &decoded)
	if err == nil {
		*t = TestUnmarshaler(decoded)
	}
	return err
}

// Test_unmarshalJSON runs unmarshalJSON against a table of inputs whose "Map"
// entry uses different key types, checking either the decoded TestUnmarshaler
// or the reported error.
func Test_unmarshalJSON(t *testing.T) {
	// payload builds the input document; only the "Map" entry varies per case.
	payload := func(m any) map[string]any {
		return map[string]any{
			"Value": "hello",
			"Array": []int{1, 2, 3},
			"Map":   m,
		}
	}
	// want builds the expected result for the given normalized map.
	want := func(m map[string]int) *TestUnmarshaler {
		return &TestUnmarshaler{
			Value: "hello",
			Array: []int{1, 2, 3},
			Map:   m,
		}
	}

	cases := []struct {
		name     string
		input    any
		expected any
		wantErr  string
	}{
		{
			name:     "map with string keys",
			input:    payload(map[string]int{"a": 1, "b": 2}),
			expected: want(map[string]int{"a": 1, "b": 2}),
		},
		{
			name:     "map with int keys",
			input:    payload(map[int]int{1: 1, 2: 2}),
			expected: want(map[string]int{"1": 1, "2": 2}),
		},
		{
			name:     "map with any keys",
			input:    payload(map[any]int{1: 1, "b": 2, true: 3}),
			expected: want(map[string]int{"1": 1, "b": 2, "true": 3}),
		},
		{
			name:    "map with duplicate keys",
			input:   payload(map[any]int{1: 1, "1": 2}),
			wantErr: `duplicate key "1" found when converting to JSON object`,
		},
	}

	for _, tc := range cases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			got := &TestUnmarshaler{}
			err := unmarshalJSON(tc.input, got)
			if tc.wantErr == "" {
				assert.NoError(t, err)
				assert.DeepEqualf(t, tc.expected, got, "unmarshalJSON() result")
				return
			}
			assert.ErrorMatchesf(t, tc.wantErr, err, "unmarshalJSON() error")
		})
	}
}
39 changes: 32 additions & 7 deletions encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package yaml

import (
"encoding"
"encoding/json"
"fmt"
"io"
"reflect"
Expand All @@ -31,12 +32,13 @@ import (
)

type encoder struct {
emitter libyaml.Emitter
event libyaml.Event
out []byte
flow bool
indent int
doneInit bool
emitter libyaml.Emitter
event libyaml.Event
out []byte
flow bool
indent int
doneInit bool
fallbackToJSON bool
}

func newEncoder() *encoder {
Expand Down Expand Up @@ -157,6 +159,29 @@ func (e *encoder) marshal(tag string, in reflect.Value) {
fail(err)
}
in = reflect.ValueOf(string(text))
case json.Marshaler:
if !e.fallbackToJSON {
break // do the normal thing
}
// Fallback to JSON marshaling.
// Marshal to JSON,
// then unmarshal to an interface{},
// then marshal that value to YAML.
// NOTE: This double conversion (Source type -> JSON -> interface{} -> YAML) adds overhead,
// but is necessary to support types that implement json.Marshaler. There is no
// more direct way to invoke MarshalJSON for YAML output, so this trade-off is
// required for compatibility.

data, err := value.MarshalJSON()
if err != nil {
fail(err)
}
var v any
if err := json.Unmarshal(data, &v); err != nil {
fail(err)
}
e.marshal(tag, reflect.ValueOf(v))
return
case nil:
e.nilv()
return
Expand Down Expand Up @@ -216,7 +241,7 @@ func (e *encoder) fieldByIndex(v reflect.Value, index []int) (field reflect.Valu
}

func (e *encoder) structv(tag string, in reflect.Value) {
sinfo, err := getStructInfo(in.Type())
sinfo, err := getStructInfo(in.Type(), e.fallbackToJSON)
if err != nil {
panic(err)
}
Expand Down
Loading