Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions device/awg/special_handshake_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@ type SpecialHandshakeHandler struct {
IsSet bool
}

func (handler *SpecialHandshakeHandler) Validate() error {
func (handler *SpecialHandshakeHandler) Validate(maxSegmentSize int) error {
var errs []error
if err := handler.SpecialJunk.Validate(); err != nil {
if err := handler.SpecialJunk.Validate(maxSegmentSize); err != nil {
errs = append(errs, err)
}
if err := handler.ControlledJunk.Validate(); err != nil {
if err := handler.ControlledJunk.Validate(maxSegmentSize); err != nil {
errs = append(errs, err)
}
return errors.Join(errs...)
Expand Down
4 changes: 4 additions & 0 deletions device/awg/tag_junk_packet_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,7 @@ func (tg *TagJunkPacketGenerator) IpcGetFields() IpcFields {
Value: tg.tagValue,
}
}

func (tg *TagJunkPacketGenerator) Size() int {
return tg.packetSize
}
10 changes: 8 additions & 2 deletions device/awg/tag_junk_packet_generators.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package awg

import "fmt"
import (
"fmt"
)

type TagJunkPacketGenerators struct {
tagGenerators []TagJunkPacketGenerator
Expand All @@ -20,7 +22,7 @@ func (generators *TagJunkPacketGenerators) IsDefined() bool {
}

// validate that packets were defined consecutively
func (generators *TagJunkPacketGenerators) Validate() error {
func (generators *TagJunkPacketGenerators) Validate(maxSegmentSize int) error {
seen := make([]bool, len(generators.tagGenerators))
for _, generator := range generators.tagGenerators {
index, err := generator.nameIndex()
Expand All @@ -32,6 +34,10 @@ func (generators *TagJunkPacketGenerators) Validate() error {
} else {
seen[index-1] = true
}

if generator.Size() > maxSegmentSize {
return fmt.Errorf("junk packet %s must not exceed %d bytes", generator.name, maxSegmentSize)
}
}

for _, found := range seen {
Expand Down
2 changes: 1 addition & 1 deletion device/awg/tag_junk_packet_generators_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ func TestTagJunkGeneratorHandlerValidate(t *testing.T) {
generators.AppendGenerator(gen)
}

err := generators.Validate()
err := generators.Validate(1500)
if tt.wantErr {
require.Error(t, err)
require.Contains(t, err.Error(), tt.errMsg)
Expand Down
2 changes: 1 addition & 1 deletion device/device.go
Original file line number Diff line number Diff line change
Expand Up @@ -819,7 +819,7 @@ func (device *Device) handlePostConfig(tempAwg *awg.Protocol) error {
}

if tempAwg.HandshakeHandler.IsSet {
if err := tempAwg.HandshakeHandler.Validate(); err != nil {
if err := tempAwg.HandshakeHandler.Validate(MaxSegmentSize); err != nil {
errs = append(errs, ipcErrorf(
ipc.IpcErrorInvalid, "handshake handler validate: %w", err))
} else {
Expand Down
8 changes: 4 additions & 4 deletions device/uapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -406,26 +406,26 @@ func (device *Device) handleDeviceLine(key, value string, tempAwg *awg.Protocol)
return nil
}

generators, err := awg.Parse(key, value)
generator, err := awg.Parse(key, value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "invalid %s: %w", key, err)
}
device.log.Verbosef("UAPI: Updating %s", key)
tempAwg.HandshakeHandler.SpecialJunk.AppendGenerator(generators)
tempAwg.HandshakeHandler.SpecialJunk.AppendGenerator(generator)
tempAwg.HandshakeHandler.IsSet = true
case "j1", "j2", "j3":
if len(value) == 0 {
device.log.Verbosef("UAPI: received empty %s", key)
return nil
}

generators, err := awg.Parse(key, value)
generator, err := awg.Parse(key, value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "invalid %s: %w", key, err)
}
device.log.Verbosef("UAPI: Updating %s", key)

tempAwg.HandshakeHandler.ControlledJunk.AppendGenerator(generators)
tempAwg.HandshakeHandler.ControlledJunk.AppendGenerator(generator)
tempAwg.HandshakeHandler.IsSet = true
case "itime":
if len(value) == 0 {
Expand Down