Skip to content

Commit ba294cb

Browse files
authored
feat: add bandwidth limit
1 parent 3d57e07 commit ba294cb

File tree

3 files changed

+100
-9
lines changed

3 files changed

+100
-9
lines changed

src/cli/cli.go

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ func Run() (err error) {
9494
&cli.StringFlag{Name: "ports", Value: "9009,9010,9011,9012,9013", Usage: "ports of the relay"},
9595
&cli.IntFlag{Name: "port", Value: 9009, Usage: "base port for the relay"},
9696
&cli.IntFlag{Name: "transfers", Value: 5, Usage: "number of ports to use for relay"},
97+
&cli.Int64Flag{Name: "bandwidth", Value: 0, Usage: "maximum bandwidth per transfer in megabytes (0 = unlimited)"},
9798
},
9899
},
99100
}
@@ -736,17 +737,27 @@ func relay(c *cli.Context) (err error) {
736737
}
737738
}
738739

740+
// Get bandwidth limit in megabytes and convert to bytes
741+
bandwidthMB := c.Int64("bandwidth")
742+
var bandwidthBytes int64
743+
if bandwidthMB > 0 {
744+
bandwidthBytes = bandwidthMB * 1024 * 1024
745+
log.Infof("bandwidth limit set to %d MB (%d bytes) per transfer", bandwidthMB, bandwidthBytes)
746+
} else {
747+
log.Info("no bandwidth limit set")
748+
}
749+
739750
tcpPorts := strings.Join(ports[1:], ",")
740751
for i, port := range ports {
741752
if i == 0 {
742753
continue
743754
}
744755
go func(portStr string) {
745-
err := tcp.Run(debugString, host, portStr, determinePass(c))
756+
err := tcp.RunWithBandwidthLimit(debugString, host, portStr, determinePass(c), bandwidthBytes)
746757
if err != nil {
747758
panic(err)
748759
}
749760
}(port)
750761
}
751-
return tcp.Run(debugString, host, ports[0], determinePass(c), tcpPorts)
762+
return tcp.RunWithBandwidthLimit(debugString, host, ports[0], determinePass(c), bandwidthBytes, tcpPorts)
752763
}

src/tcp/options.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,13 @@ func WithRoomTTL(ttl time.Duration) serverOptsFunc {
4343
}
4444
}
4545

46+
func WithMaxBandwidth(maxBytes int64) serverOptsFunc {
47+
return func(s *server) error {
48+
s.maxBandwidth = maxBytes
49+
return nil
50+
}
51+
}
52+
4653
func containsSlice(s []string, e string) bool {
4754
for _, ss := range s {
4855
if e == ss {

src/tcp/tcp.go

Lines changed: 80 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,18 @@ type server struct {
2626

2727
roomCleanupInterval time.Duration
2828
roomTTL time.Duration
29+
maxBandwidth int64 // maximum bytes allowed per transfer (0 = unlimited)
2930

3031
stopRoomCleanup chan struct{}
3132
}
3233

3334
type roomInfo struct {
34-
first *comm.Comm
35-
second *comm.Comm
36-
opened time.Time
37-
full bool
35+
first *comm.Comm
36+
second *comm.Comm
37+
opened time.Time
38+
full bool
39+
bytesTransferred int64 // total bytes transferred in this room
40+
transferMutex sync.Mutex // protects bytesTransferred
3841
}
3942

4043
type roomMap struct {
@@ -74,6 +77,11 @@ func Run(debugLevel, host, port, password string, banner ...string) (err error)
7477
return RunWithOptionsAsync(host, port, password, WithBanner(banner...), WithLogLevel(debugLevel))
7578
}
7679

80+
// RunWithBandwidthLimit starts a tcp listener with bandwidth limiting, run async
81+
func RunWithBandwidthLimit(debugLevel, host, port, password string, maxBandwidthBytes int64, banner ...string) (err error) {
82+
return RunWithOptionsAsync(host, port, password, WithBanner(banner...), WithLogLevel(debugLevel), WithMaxBandwidth(maxBandwidthBytes))
83+
}
84+
7785
func (s *server) start() (err error) {
7886
log.SetLevel(s.debugLevel)
7987

@@ -361,7 +369,7 @@ func (s *server) clientCommunication(port string, c *comm.Comm) (room string, er
361369
// start piping
362370
go func(com1, com2 *comm.Comm, wg *sync.WaitGroup) {
363371
log.Debug("starting pipes")
364-
pipe(com1.Connection(), com2.Connection())
372+
s.pipe(com1.Connection(), com2.Connection(), room)
365373
wg.Done()
366374
log.Debug("done piping")
367375
}(otherConnection, c, &wg)
@@ -432,8 +440,9 @@ func chanFromConn(conn net.Conn) chan []byte {
432440
}
433441

434442
// pipe creates a full-duplex pipe between the two sockets and
435-
// transfers data from one to the other.
436-
func pipe(conn1 net.Conn, conn2 net.Conn) {
443+
// transfers data from one to the other. It tracks bandwidth usage
444+
// and enforces limits if configured.
445+
func (s *server) pipe(conn1 net.Conn, conn2 net.Conn, room string) {
437446
chan1 := chanFromConn(conn1)
438447
chan2 := chanFromConn(conn2)
439448

@@ -443,16 +452,80 @@ func pipe(conn1 net.Conn, conn2 net.Conn) {
443452
if b1 == nil {
444453
return
445454
}
455+
456+
// Check bandwidth limit before writing
457+
if s.maxBandwidth > 0 {
458+
s.rooms.Lock()
459+
roomInfo, exists := s.rooms.rooms[room]
460+
s.rooms.Unlock()
461+
462+
if exists {
463+
roomInfo.transferMutex.Lock()
464+
newTotal := roomInfo.bytesTransferred + int64(len(b1))
465+
466+
if newTotal > s.maxBandwidth {
467+
roomInfo.transferMutex.Unlock()
468+
log.Warnf("bandwidth limit exceeded for room %s: %d bytes (limit: %d bytes)",
469+
room, newTotal, s.maxBandwidth)
470+
// Close both connections immediately
471+
conn1.Close()
472+
conn2.Close()
473+
// Delete the room
474+
s.deleteRoom(room)
475+
return
476+
}
477+
478+
roomInfo.bytesTransferred = newTotal
479+
s.rooms.Lock()
480+
s.rooms.rooms[room] = roomInfo
481+
s.rooms.Unlock()
482+
roomInfo.transferMutex.Unlock()
483+
}
484+
}
485+
446486
if _, err := conn2.Write(b1); err != nil {
447487
log.Errorf("write error on channel 1: %v", err)
488+
return
448489
}
449490

450491
case b2 := <-chan2:
451492
if b2 == nil {
452493
return
453494
}
495+
496+
// Check bandwidth limit before writing
497+
if s.maxBandwidth > 0 {
498+
s.rooms.Lock()
499+
roomInfo, exists := s.rooms.rooms[room]
500+
s.rooms.Unlock()
501+
502+
if exists {
503+
roomInfo.transferMutex.Lock()
504+
newTotal := roomInfo.bytesTransferred + int64(len(b2))
505+
506+
if newTotal > s.maxBandwidth {
507+
roomInfo.transferMutex.Unlock()
508+
log.Warnf("bandwidth limit exceeded for room %s: %d bytes (limit: %d bytes)",
509+
room, newTotal, s.maxBandwidth)
510+
// Close both connections immediately
511+
conn1.Close()
512+
conn2.Close()
513+
// Delete the room
514+
s.deleteRoom(room)
515+
return
516+
}
517+
518+
roomInfo.bytesTransferred = newTotal
519+
s.rooms.Lock()
520+
s.rooms.rooms[room] = roomInfo
521+
s.rooms.Unlock()
522+
roomInfo.transferMutex.Unlock()
523+
}
524+
}
525+
454526
if _, err := conn1.Write(b2); err != nil {
455527
log.Errorf("write error on channel 2: %v", err)
528+
return
456529
}
457530
}
458531
}

0 commit comments

Comments
 (0)