@@ -99,15 +99,20 @@ func (b *recipientBuilder) Build(r Recipient, cek []byte, calg jwa.ContentEncryp
9999 rawKey = raw
100100 }
101101
102- // Extract ECDH-ES specific parameters if needed
102+ // Extract ECDH-ES specific parameters if needed.
103103 var apu , apv []byte
104- if b .headers != nil {
105- if val , ok := b .headers .AgreementPartyUInfo (); ok {
106- apu = val
107- }
108- if val , ok := b .headers .AgreementPartyVInfo (); ok {
109- apv = val
110- }
104+
105+ hdr := b .headers
106+ if hdr == nil {
107+ hdr = NewHeaders ()
108+ }
109+
110+ if val , ok := hdr .AgreementPartyUInfo (); ok {
111+ apu = val
112+ }
113+
114+ if val , ok := hdr .AgreementPartyVInfo (); ok {
115+ apv = val
111116 }
112117
113118 // Create the encrypter using the new jwebb pattern
@@ -116,20 +121,20 @@ func (b *recipientBuilder) Build(r Recipient, cek []byte, calg jwa.ContentEncryp
116121 return nil , fmt .Errorf (`jwe.Encrypt: recipientBuilder: failed to create encrypter: %w` , err )
117122 }
118123
119- if hdrs := b .headers ; hdrs != nil {
120- _ = r .SetHeaders (hdrs )
121- }
124+ _ = r .SetHeaders (hdr )
122125
123- if err := r .Headers ().Set (AlgorithmKey , b .alg ); err != nil {
126+ // Populate headers with stuff that we automatically set
127+ if err := hdr .Set (AlgorithmKey , b .alg ); err != nil {
124128 return nil , fmt .Errorf (`failed to set header: %w` , err )
125129 }
126130
127131 if keyID != "" {
128- if err := r . Headers () .Set (KeyIDKey , keyID ); err != nil {
132+ if err := hdr .Set (KeyIDKey , keyID ); err != nil {
129133 return nil , fmt .Errorf (`failed to set header: %w` , err )
130134 }
131135 }
132136
137+ // Handle the encrypted key
133138 var rawCEK []byte
134139 enckey , err := enc .EncryptKey (cek )
135140 if err != nil {
@@ -143,8 +148,9 @@ func (b *recipientBuilder) Build(r Recipient, cek []byte, calg jwa.ContentEncryp
143148 }
144149 }
145150
151+ // finally, anything specific should go here
146152 if hp , ok := enckey .(populater ); ok {
147- if err := hp .Populate (r . Headers () ); err != nil {
153+ if err := hp .Populate (hdr ); err != nil {
148154 return nil , fmt .Errorf (`failed to populate: %w` , err )
149155 }
150156 }
@@ -154,7 +160,9 @@ func (b *recipientBuilder) Build(r Recipient, cek []byte, calg jwa.ContentEncryp
154160
155161// Encrypt generates a JWE message for the given payload and returns
156162// it in serialized form, which can be in either compact or
157- // JSON format. Default is compact.
163+ // JSON format. Default is compact. When JSON format is specified and
164+ // there is only one recipient, the resulting serialization is
165+ // automatically converted to flattened JSON serialization format.
158166//
159167// You must pass at least one key to `jwe.Encrypt()` by using `jwe.WithKey()`
160168// option.
@@ -172,6 +180,10 @@ func (b *recipientBuilder) Build(r Recipient, cek []byte, calg jwa.ContentEncryp
172180//
173181// Look for options that return `jwe.EncryptOption` or `jws.EncryptDecryptOption`
174182// for a complete list of options that can be passed to this function.
183+ //
184+ // As of v3.0.12, users can specify `jwe.WithLegacyHeaderMerging()` to
185+ // disable header merging behavior that was the default prior to v3.0.12.
186+ // Read the documentation for `jwe.WithLegacyHeaderMerging()` for more information.
175187func Encrypt (payload []byte , options ... EncryptOption ) ([]byte , error ) {
176188 ec := encryptContextPool .Get ()
177189 defer encryptContextPool .Put (ec )
@@ -410,10 +422,26 @@ func (dc *decryptContext) decryptContent(msg *Message, alg jwa.KeyEncryptionAlgo
410422 Tag (msg .tag ).
411423 CEK (dc .cek )
412424
413- if v , ok := recipient .Headers ().Algorithm (); ! ok || v != alg {
414- // algorithms don't match
425+ // The "alg" header can be in either protected/unprotected headers.
426+ // prefer per-recipient headers (as it might be the case that the algorithm differs
427+ // by each recipient), then look at protected headers.
428+ var algMatched bool
429+ for _ , hdr := range []Headers {recipient .Headers (), protectedHeaders } {
430+ v , ok := hdr .Algorithm ()
431+ if ! ok {
432+ continue
433+ }
434+
435+ if v == alg {
436+ algMatched = true
437+ break
438+ }
439+ // if we found something but didn't match, it's a failure
415440 return nil , fmt .Errorf (`jwe.Decrypt: key (%q) and recipient (%q) algorithms do not match` , alg , v )
416441 }
442+ if ! algMatched {
443+ return nil , fmt .Errorf (`jwe.Decrypt: failed to find "alg" header in either protected or per-recipient headers` )
444+ }
417445
418446 h2 , err := protectedHeaders .Clone ()
419447 if err != nil {
@@ -534,11 +562,12 @@ func (dc *decryptContext) decryptContent(msg *Message, alg jwa.KeyEncryptionAlgo
534562
535563// encryptContext holds the state during JWE encryption, similar to JWS signContext
536564type encryptContext struct {
537- calg jwa.ContentEncryptionAlgorithm
538- compression jwa.CompressionAlgorithm
539- format int
540- builders []* recipientBuilder
541- protected Headers
565+ calg jwa.ContentEncryptionAlgorithm
566+ compression jwa.CompressionAlgorithm
567+ format int
568+ builders []* recipientBuilder
569+ protected Headers
570+ legacyHeaderMerging bool
542571}
543572
544573var encryptContextPool = pool .New (allocEncryptContext , freeEncryptContext )
@@ -561,6 +590,7 @@ func freeEncryptContext(ec *encryptContext) *encryptContext {
561590}
562591
563592func (ec * encryptContext ) ProcessOptions (options []EncryptOption ) error {
593+ ec .legacyHeaderMerging = true
564594 var mergeProtected bool
565595 var useRawCEK bool
566596 for _ , option := range options {
@@ -577,7 +607,11 @@ func (ec *encryptContext) ProcessOptions(options []EncryptOption) error {
577607 if v == jwa .DIRECT () || v == jwa .ECDH_ES () {
578608 useRawCEK = true
579609 }
580- ec .builders = append (ec .builders , & recipientBuilder {alg : v , key : wk .key , headers : wk .headers })
610+ ec .builders = append (ec .builders , & recipientBuilder {
611+ alg : v ,
612+ key : wk .key ,
613+ headers : wk .headers ,
614+ })
581615 case identContentEncryptionAlgorithm {}:
582616 var c jwa.ContentEncryptionAlgorithm
583617 if err := option .Value (& c ); err != nil {
@@ -616,6 +650,12 @@ func (ec *encryptContext) ProcessOptions(options []EncryptOption) error {
616650 return err
617651 }
618652 ec .format = fmtOpt
653+ case identLegacyHeaderMerging {}:
654+ var v bool
655+ if err := option .Value (& v ); err != nil {
656+ return err
657+ }
658+ ec .legacyHeaderMerging = v
619659 }
620660 }
621661
@@ -732,7 +772,8 @@ func (ec *encryptContext) EncryptMessage(payload []byte, cek []byte) ([]byte, er
732772 }
733773 }
734774
735- recipients := recipientSlicePool .GetCapacity (len (ec .builders ))
775+ lbuilders := len (ec .builders )
776+ recipients := recipientSlicePool .GetCapacity (lbuilders )
736777 defer recipientSlicePool .Put (recipients )
737778
738779 for i , builder := range ec .builders {
@@ -767,14 +808,55 @@ func (ec *encryptContext) EncryptMessage(payload []byte, cek []byte) ([]byte, er
767808 }
768809 }
769810
770- // If there's only one recipient, you want to include that in the
771- // protected header
772- if len (recipients ) == 1 {
811+ // fmtCompact does not have per-recipient headers, nor a "header" field.
812+ // In this mode, we're going to have to merge everything to the protected
813+ // header.
814+ if ec .format == fmtCompact {
815+ // We have already established that the number of builders is 1 in
816+ // ec.ProcessOptions(). But we're going to be pedantic
817+ if lbuilders != 1 {
818+ return nil , fmt .Errorf (`internal error: expected exactly one recipient builder (got %d)` , lbuilders )
819+ }
820+
821+ // when we're using compact format, we can safely merge per-recipient
822+ // headers into the protected header, if any
773823 h , err := protected .Merge (recipients [0 ].Headers ())
774824 if err != nil {
775- return nil , fmt .Errorf (`failed to merge protected headers: %w` , err )
825+ return nil , fmt .Errorf (`failed to merge protected headers for compact serialization : %w` , err )
776826 }
777827 protected = h
828+ // per-recipient headers, if any, will be ignored in compact format
829+ } else {
830+ // If it got here, it's JSON (could be pretty mode, too).
831+ if lbuilders == 1 {
832+ // If it got here, then we're doing flattened JSON serialization.
833+ // In this mode, we should merge per-recipient headers into the protected header,
834+ // but we also need to make sure that the "header" field is reset so that
835+ // it does not contain the same fields as the protected header.
836+ //
837+ // However, old behavior was to merge per-recipient headers into the
838+ // protected header when there was only one recipient, AND leave the
839+ // original "header" field as is, so we need to support that for backwards compatibility.
840+ //
841+ // The legacy merging only takes effect when there is exactly one recipient.
842+ //
843+ // This behavior can be disabled by passing jwe.WithLegacyHeaderMerging(false)
844+ // If the user has explicitly asked for merging, do it
845+ h , err := protected .Merge (recipients [0 ].Headers ())
846+ if err != nil {
847+ return nil , fmt .Errorf (`failed to merge protected headers for flattenend JSON format: %w` , err )
848+ }
849+ protected = h
850+
851+ if ! ec .legacyHeaderMerging {
852+ // Clear per-recipient headers, since they have been merged.
853+ // But we only do it when legacy merging is disabled.
854+ // Note: we should probably introduce a Reset() method in v4
855+ if err := recipients [0 ].SetHeaders (NewHeaders ()); err != nil {
856+ return nil , fmt .Errorf (`failed to clear per-recipient headers after merging: %w` , err )
857+ }
858+ }
859+ }
778860 }
779861
780862 aad , err := protected .Encode ()
0 commit comments