Skip to content

Commit 8923c67

Browse files
authored
[jwe] Add option to explicitly clear per-recipient headers ("header") for flattened JSON serialization (#1477)
* Implement WithLegacyHeaderMerging * Rethink WithLegacyHeaderMerging * appease linter * Re-generate options * Tweak example * Update Changes * Make sure jwe works correctl around "header" field * Update Changes * Remove unnecessary checks for apu/apv
1 parent f2f199d commit 8923c67

File tree

8 files changed

+391
-43
lines changed

8 files changed

+391
-43
lines changed

Changes

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,30 @@ v3 has many incompatibilities with v2. To see the full list of differences betwe
55
v2 and v3, please read the Changes-v3.md file (https://github.com/lestrrat-go/jwx/blob/develop/v3/Changes-v3.md)
66

77
v3.0.12 UNRELEASED
8+
* [jwe] As part of the next change, now per-recipient headers that are empty
9+
are no longer serialized in flattened JSON serialization.
10+
11+
* [jwe] Introduce `jwe.WithLegacyHeaderMerging(bool)` option to control header
12+
merging behavior in during JWE encryption. This only applies to flattened
13+
JSON serialization.
14+
15+
Previously, when using flattened JSON serialization (i.e. you specified
16+
JSON serialization via `jwe.WithJSON()` and only supplied one key), per-recipient
17+
headers were merged into the protected headers during encryption, and then
18+
were left to be included in the final serialization as-is. This caused duplicate
19+
headers to be present in both the protected headers and the per-recipient headers.
20+
21+
Since there maybe users who rely on this behavior already, instead of changing the
22+
default behavior to fix this duplication, a new option to `jwe.Encrypt()` was added
23+
to allow clearing the per-recipient headers after merging to leave the `"headers"`
24+
field empty. This in effect makes the flattened JSON serialization more similar to
25+
the compact serialization, where there are no per-recipient headers present, and
26+
leaves the headers disjoint.
27+
28+
Note that in compact mode, there are no per-recipient headers and thus the
29+
headers need to be merged regardless. In full JSON serialization, we never
30+
merge the headers, so it is left up to the user to keep the headers disjoint.
31+
832
* [jws] Calling the deprecated `jws.NewSigner()` function for the time will cause
933
legacy signers to be loaded automatically. Previously, you had to explicitly
1034
call `jws.Settings(jws.WithLegacySigners(true))` to enable legacy signers.

examples/jwe_encrypt_json_example_test.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,12 @@ func Example_jwe_encrypt_json() {
2929
}
3030

3131
const payload = `Lorem ipsum`
32-
encrypted, err := jwe.Encrypt([]byte(payload), jwe.WithJSON(), jwe.WithKey(jwa.RSA_OAEP(), pubkey))
32+
encrypted, err := jwe.Encrypt(
33+
[]byte(payload),
34+
jwe.WithJSON(), // Toggle JSON serialization. Because there's only one key (recipient), this will produce Flattened JSON serialization
35+
jwe.WithLegacyHeaderMerging(false), // Disable legacy header merging
36+
jwe.WithKey(jwa.RSA_OAEP(), pubkey), // Public key for encryption
37+
)
3338
if err != nil {
3439
fmt.Printf("failed to encrypt payload: %s\n", err)
3540
return

jwe/jwe.go

Lines changed: 110 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -99,15 +99,20 @@ func (b *recipientBuilder) Build(r Recipient, cek []byte, calg jwa.ContentEncryp
9999
rawKey = raw
100100
}
101101

102-
// Extract ECDH-ES specific parameters if needed
102+
// Extract ECDH-ES specific parameters if needed.
103103
var apu, apv []byte
104-
if b.headers != nil {
105-
if val, ok := b.headers.AgreementPartyUInfo(); ok {
106-
apu = val
107-
}
108-
if val, ok := b.headers.AgreementPartyVInfo(); ok {
109-
apv = val
110-
}
104+
105+
hdr := b.headers
106+
if hdr == nil {
107+
hdr = NewHeaders()
108+
}
109+
110+
if val, ok := hdr.AgreementPartyUInfo(); ok {
111+
apu = val
112+
}
113+
114+
if val, ok := hdr.AgreementPartyVInfo(); ok {
115+
apv = val
111116
}
112117

113118
// Create the encrypter using the new jwebb pattern
@@ -116,20 +121,20 @@ func (b *recipientBuilder) Build(r Recipient, cek []byte, calg jwa.ContentEncryp
116121
return nil, fmt.Errorf(`jwe.Encrypt: recipientBuilder: failed to create encrypter: %w`, err)
117122
}
118123

119-
if hdrs := b.headers; hdrs != nil {
120-
_ = r.SetHeaders(hdrs)
121-
}
124+
_ = r.SetHeaders(hdr)
122125

123-
if err := r.Headers().Set(AlgorithmKey, b.alg); err != nil {
126+
// Populate headers with stuff that we automatically set
127+
if err := hdr.Set(AlgorithmKey, b.alg); err != nil {
124128
return nil, fmt.Errorf(`failed to set header: %w`, err)
125129
}
126130

127131
if keyID != "" {
128-
if err := r.Headers().Set(KeyIDKey, keyID); err != nil {
132+
if err := hdr.Set(KeyIDKey, keyID); err != nil {
129133
return nil, fmt.Errorf(`failed to set header: %w`, err)
130134
}
131135
}
132136

137+
// Handle the encrypted key
133138
var rawCEK []byte
134139
enckey, err := enc.EncryptKey(cek)
135140
if err != nil {
@@ -143,8 +148,9 @@ func (b *recipientBuilder) Build(r Recipient, cek []byte, calg jwa.ContentEncryp
143148
}
144149
}
145150

151+
// finally, anything specific should go here
146152
if hp, ok := enckey.(populater); ok {
147-
if err := hp.Populate(r.Headers()); err != nil {
153+
if err := hp.Populate(hdr); err != nil {
148154
return nil, fmt.Errorf(`failed to populate: %w`, err)
149155
}
150156
}
@@ -154,7 +160,9 @@ func (b *recipientBuilder) Build(r Recipient, cek []byte, calg jwa.ContentEncryp
154160

155161
// Encrypt generates a JWE message for the given payload and returns
156162
// it in serialized form, which can be in either compact or
157-
// JSON format. Default is compact.
163+
// JSON format. Default is compact. When JSON format is specified and
164+
// there is only one recipient, the resulting serialization is
165+
// automatically converted to flattened JSON serialization format.
158166
//
159167
// You must pass at least one key to `jwe.Encrypt()` by using `jwe.WithKey()`
160168
// option.
@@ -172,6 +180,10 @@ func (b *recipientBuilder) Build(r Recipient, cek []byte, calg jwa.ContentEncryp
172180
//
173181
// Look for options that return `jwe.EncryptOption` or `jws.EncryptDecryptOption`
174182
// for a complete list of options that can be passed to this function.
183+
//
184+
// As of v3.0.12, users can specify `jwe.WithLegacyHeaderMerging()` to
185+
// disable header merging behavior that was the default prior to v3.0.12.
186+
// Read the documentation for `jwe.WithLegacyHeaderMerging()` for more information.
175187
func Encrypt(payload []byte, options ...EncryptOption) ([]byte, error) {
176188
ec := encryptContextPool.Get()
177189
defer encryptContextPool.Put(ec)
@@ -410,10 +422,26 @@ func (dc *decryptContext) decryptContent(msg *Message, alg jwa.KeyEncryptionAlgo
410422
Tag(msg.tag).
411423
CEK(dc.cek)
412424

413-
if v, ok := recipient.Headers().Algorithm(); !ok || v != alg {
414-
// algorithms don't match
425+
// The "alg" header can be in either protected/unprotected headers.
426+
// prefer per-recipient headers (as it might be the case that the algorithm differs
427+
// by each recipient), then look at protected headers.
428+
var algMatched bool
429+
for _, hdr := range []Headers{recipient.Headers(), protectedHeaders} {
430+
v, ok := hdr.Algorithm()
431+
if !ok {
432+
continue
433+
}
434+
435+
if v == alg {
436+
algMatched = true
437+
break
438+
}
439+
// if we found something but didn't match, it's a failure
415440
return nil, fmt.Errorf(`jwe.Decrypt: key (%q) and recipient (%q) algorithms do not match`, alg, v)
416441
}
442+
if !algMatched {
443+
return nil, fmt.Errorf(`jwe.Decrypt: failed to find "alg" header in either protected or per-recipient headers`)
444+
}
417445

418446
h2, err := protectedHeaders.Clone()
419447
if err != nil {
@@ -534,11 +562,12 @@ func (dc *decryptContext) decryptContent(msg *Message, alg jwa.KeyEncryptionAlgo
534562

535563
// encryptContext holds the state during JWE encryption, similar to JWS signContext
536564
type encryptContext struct {
537-
calg jwa.ContentEncryptionAlgorithm
538-
compression jwa.CompressionAlgorithm
539-
format int
540-
builders []*recipientBuilder
541-
protected Headers
565+
calg jwa.ContentEncryptionAlgorithm
566+
compression jwa.CompressionAlgorithm
567+
format int
568+
builders []*recipientBuilder
569+
protected Headers
570+
legacyHeaderMerging bool
542571
}
543572

544573
var encryptContextPool = pool.New(allocEncryptContext, freeEncryptContext)
@@ -561,6 +590,7 @@ func freeEncryptContext(ec *encryptContext) *encryptContext {
561590
}
562591

563592
func (ec *encryptContext) ProcessOptions(options []EncryptOption) error {
593+
ec.legacyHeaderMerging = true
564594
var mergeProtected bool
565595
var useRawCEK bool
566596
for _, option := range options {
@@ -577,7 +607,11 @@ func (ec *encryptContext) ProcessOptions(options []EncryptOption) error {
577607
if v == jwa.DIRECT() || v == jwa.ECDH_ES() {
578608
useRawCEK = true
579609
}
580-
ec.builders = append(ec.builders, &recipientBuilder{alg: v, key: wk.key, headers: wk.headers})
610+
ec.builders = append(ec.builders, &recipientBuilder{
611+
alg: v,
612+
key: wk.key,
613+
headers: wk.headers,
614+
})
581615
case identContentEncryptionAlgorithm{}:
582616
var c jwa.ContentEncryptionAlgorithm
583617
if err := option.Value(&c); err != nil {
@@ -616,6 +650,12 @@ func (ec *encryptContext) ProcessOptions(options []EncryptOption) error {
616650
return err
617651
}
618652
ec.format = fmtOpt
653+
case identLegacyHeaderMerging{}:
654+
var v bool
655+
if err := option.Value(&v); err != nil {
656+
return err
657+
}
658+
ec.legacyHeaderMerging = v
619659
}
620660
}
621661

@@ -732,7 +772,8 @@ func (ec *encryptContext) EncryptMessage(payload []byte, cek []byte) ([]byte, er
732772
}
733773
}
734774

735-
recipients := recipientSlicePool.GetCapacity(len(ec.builders))
775+
lbuilders := len(ec.builders)
776+
recipients := recipientSlicePool.GetCapacity(lbuilders)
736777
defer recipientSlicePool.Put(recipients)
737778

738779
for i, builder := range ec.builders {
@@ -767,14 +808,55 @@ func (ec *encryptContext) EncryptMessage(payload []byte, cek []byte) ([]byte, er
767808
}
768809
}
769810

770-
// If there's only one recipient, you want to include that in the
771-
// protected header
772-
if len(recipients) == 1 {
811+
// fmtCompact does not have per-recipient headers, nor a "header" field.
812+
// In this mode, we're going to have to merge everything to the protected
813+
// header.
814+
if ec.format == fmtCompact {
815+
// We have already established that the number of builders is 1 in
816+
// ec.ProcessOptions(). But we're going to be pedantic
817+
if lbuilders != 1 {
818+
return nil, fmt.Errorf(`internal error: expected exactly one recipient builder (got %d)`, lbuilders)
819+
}
820+
821+
// when we're using compact format, we can safely merge per-recipient
822+
// headers into the protected header, if any
773823
h, err := protected.Merge(recipients[0].Headers())
774824
if err != nil {
775-
return nil, fmt.Errorf(`failed to merge protected headers: %w`, err)
825+
return nil, fmt.Errorf(`failed to merge protected headers for compact serialization: %w`, err)
776826
}
777827
protected = h
828+
// per-recipient headers, if any, will be ignored in compact format
829+
} else {
830+
// If it got here, it's JSON (could be pretty mode, too).
831+
if lbuilders == 1 {
832+
// If it got here, then we're doing flattened JSON serialization.
833+
// In this mode, we should merge per-recipient headers into the protected header,
834+
// but we also need to make sure that the "header" field is reset so that
835+
// it does not contain the same fields as the protected header.
836+
//
837+
// However, old behavior was to merge per-recipient headers into the
838+
// protected header when there was only one recipient, AND leave the
839+
// original "header" field as is, so we need to support that for backwards compatibility.
840+
//
841+
// The legacy merging only takes effect when there is exactly one recipient.
842+
//
843+
// This behavior can be disabled by passing jwe.WithLegacyHeaderMerging(false)
844+
// If the user has explicitly asked for merging, do it
845+
h, err := protected.Merge(recipients[0].Headers())
846+
if err != nil {
847+
return nil, fmt.Errorf(`failed to merge protected headers for flattenend JSON format: %w`, err)
848+
}
849+
protected = h
850+
851+
if !ec.legacyHeaderMerging {
852+
// Clear per-recipient headers, since they have been merged.
853+
// But we only do it when legacy merging is disabled.
854+
// Note: we should probably introduce a Reset() method in v4
855+
if err := recipients[0].SetHeaders(NewHeaders()); err != nil {
856+
return nil, fmt.Errorf(`failed to clear per-recipient headers after merging: %w`, err)
857+
}
858+
}
859+
}
778860
}
779861

780862
aad, err := protected.Encode()

0 commit comments

Comments
 (0)