From 2e4e1973d990e2deec5a6b91c4ddc31f49136a6a Mon Sep 17 00:00:00 2001 From: Eugene Burkov Date: Wed, 21 Apr 2021 13:41:18 +0300 Subject: [PATCH] Pull request: 2606 substitute nclient4 with fork Merge in DNS/adguard-home from 2606-rm-nclient4 to master Closes #2606. Squashed commit of the following: commit a9abc3ac27b19ef0ab6c4dea8610d97034b24ec2 Author: Eugene Burkov Date: Wed Apr 21 13:22:56 2021 +0300 nclient4: rm commit abcd0042a7f0d1fbf7ebb398a06414bbc7fc2528 Author: Eugene Burkov Date: Wed Apr 21 13:19:11 2021 +0300 all: clear changes commit d1bc3b83f00be07bf9bb97cfe6ef773e07ae8710 Merge: 1b1ab0b9 c2667558 Author: Eugene Burkov Date: Wed Apr 21 12:58:48 2021 +0300 Merge branch 'master' into 2606-rm-nclient4 commit 1b1ab0b9796552854ecffefe8e79ca24b472fed0 Author: Eugene Burkov Date: Mon Apr 19 18:53:04 2021 +0300 dhcpd: subst nclient4 with fork --- go.mod | 6 +- go.sum | 12 +- internal/dhcpd/checkother.go | 2 +- internal/dhcpd/nclient4/client.go | 554 ------------------------- internal/dhcpd/nclient4/client_test.go | 345 --------------- internal/dhcpd/nclient4/conn_unix.go | 140 ------- internal/dhcpd/nclient4/ipv4.go | 377 ----------------- 7 files changed, 5 insertions(+), 1431 deletions(-) delete mode 100644 internal/dhcpd/nclient4/client.go delete mode 100644 internal/dhcpd/nclient4/client_test.go delete mode 100644 internal/dhcpd/nclient4/conn_unix.go delete mode 100644 internal/dhcpd/nclient4/ipv4.go diff --git a/go.mod b/go.mod index 1b6934b8..35ffa592 100644 --- a/go.mod +++ b/go.mod @@ -16,14 +16,11 @@ require ( github.com/gobuffalo/packr/v2 v2.8.1 // indirect github.com/google/go-cmp v0.5.5 // indirect github.com/google/renameio v1.0.1-0.20210406141108-81588dbe0453 - github.com/hugelgupf/socketpair v0.0.0-20190730060125-05d35a94e714 github.com/insomniacslk/dhcp v0.0.0-20210310193751-cfd4d47082c2 github.com/kardianos/service v1.2.0 github.com/karrick/godirwalk v1.16.1 // indirect github.com/lucas-clemente/quic-go v0.20.1 - github.com/mdlayher/ethernet v0.0.0-20190606142754-0394541c37b7 github.com/mdlayher/netlink v1.4.0 - github.com/mdlayher/raw v0.0.0-20191009151244-50f2db8cc065 github.com/miekg/dns v1.1.40 github.com/rogpeppe/go-internal v1.7.0 // indirect github.com/satori/go.uuid v1.2.0 @@ -31,7 +28,6 @@ require ( github.com/spf13/cobra v1.1.3 // indirect github.com/stretchr/testify v1.7.0 github.com/ti-mo/netfilter v0.4.0 - github.com/u-root/u-root v7.0.0+incompatible go.etcd.io/bbolt v1.3.5 golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83 golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4 @@ -43,3 +39,5 @@ require ( gopkg.in/yaml.v2 v2.4.0 howett.net/plist v0.0.0-20201203080718-1454fab16a06 ) + +replace github.com/insomniacslk/dhcp => github.com/AdguardTeam/dhcp v0.0.0-20210420175708-50b0efd52063 diff --git a/go.sum b/go.sum index 8cd19351..27c83e57 100644 --- a/go.sum +++ b/go.sum @@ -18,11 +18,12 @@ dmitri.shuralyov.com/html/belt v0.0.0-20180602232347-f7d459c86be0/go.mod h1:JLBr dmitri.shuralyov.com/service/change v0.0.0-20181023043359-a85b471d5412/go.mod h1:a1inKt/atXimZ4Mv927x+r7UpyzRUf4emIoiiSC2TN4= dmitri.shuralyov.com/state v0.0.0-20180228185332-28bcc343414c/go.mod h1:0PRwlb0D6DFvNNtx+9ybjezNCa8XF0xaYcETyp6rHWU= git.apache.org/thrift.git v0.0.0-20180902110319-2566ecd5d999/go.mod h1:fPE2ZNJGynbRyZ4dJvy6G277gSllfV2HJqblrnkyeyg= +github.com/AdguardTeam/dhcp v0.0.0-20210420175708-50b0efd52063 h1:RBsQppxEJEqHApY6WDBkM2H0UG5wt57RcT0El2WGdp8= +github.com/AdguardTeam/dhcp v0.0.0-20210420175708-50b0efd52063/go.mod h1:TKl4jN3Voofo4UJIicyNhWGp/nlQqQkFxmwIFTvBkKI= github.com/AdguardTeam/dnsproxy v0.37.1 h1:vULyF1+xSI7vV99m8GD2hmOuCQrpu87awyeSe5qtFbA= github.com/AdguardTeam/dnsproxy v0.37.1/go.mod h1:xkJWEuTr550gPDmB9azsciKZzSXjf9wMn+Ji54PQ4gE= github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4= github.com/AdguardTeam/golibs v0.4.2/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4= -github.com/AdguardTeam/golibs v0.4.4 h1:cM9UySQiYFW79zo5XRwnaIWVzfW4eNXmZktMrWbthpw= github.com/AdguardTeam/golibs v0.4.4/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4= github.com/AdguardTeam/golibs v0.4.5 h1:RRA9ZsmbJEN4OllAx0BcfvSbRBxxpWluJijBYmtp13U= github.com/AdguardTeam/golibs v0.4.5/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4= @@ -68,7 +69,6 @@ github.com/cheekybits/genny v1.0.0/go.mod h1:+tQajlRqAUrPI7DOSpB0XAqZYtQakVtB7wX github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/coreos/bbolt v1.3.2/go.mod h1:iRUV2dpdMOn7Bo10OQBFzIJO9kkE559Wcmn+qkEiiKk= github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= -github.com/coreos/etcd v3.3.13+incompatible h1:8F3hqu9fGYLBifCmRCJsicFqDx/D68Rt3q1JMazcgBQ= github.com/coreos/etcd v3.3.13+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= github.com/coreos/go-etcd v2.0.0+incompatible/go.mod h1:Jez6KQU2B/sWsbdaef3ED8NzMklzPG4d5KIOhIy30Tk= github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= @@ -125,7 +125,6 @@ github.com/gobuffalo/packr/v2 v2.5.1/go.mod h1:8f9c96ITobJlPzI44jj+4tHnEKNt0xXWS github.com/gobuffalo/packr/v2 v2.8.1 h1:tkQpju6i3EtMXJ9uoF5GT6kB+LMTimDWD8Xvbz6zDVA= github.com/gobuffalo/packr/v2 v2.8.1/go.mod h1:c/PLlOuTU+p3SybaJATW3H6lX/iK7xEz5OeMf+NnJpg= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= -github.com/gogo/protobuf v1.2.1 h1:/s5zKNz0uPFCZ5hddgPdo2TK2TVrUNMn0OOX8/aZMTE= github.com/gogo/protobuf v1.2.1/go.mod h1:hp+jE20tsWTFYpLwKvXlhS1hjn+gTNwPg2I6zVXpSg4= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/groupcache v0.0.0-20190129154638-5b532d6fd5ef/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= @@ -133,7 +132,6 @@ github.com/golang/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:tluoj9z5200j github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y= -github.com/golang/mock v1.4.4 h1:l75CXGRSwbaYNpl/Z2X1XIIAMSCquvXgpVZDhwEIJsc= github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71h+4= github.com/golang/mock v1.5.0 h1:jlYHihg//f7RRwuPfptm04yp4s7O6Kw8EZiVYIGcH0g= github.com/golang/mock v1.5.0/go.mod h1:CWnOUgYIOo4TcNZ0wHX3YZCqsaM1I1Jvs6v3mP3KVu8= @@ -205,11 +203,7 @@ github.com/hugelgupf/socketpair v0.0.0-20190730060125-05d35a94e714 h1:/jC7qQFrv8 github.com/hugelgupf/socketpair v0.0.0-20190730060125-05d35a94e714/go.mod h1:2Goc3h8EklBH5mspfHFxBnEoURQCGzQQH1ga9Myjvis= github.com/inconshreveable/mousetrap v1.0.0 h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= -github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= -github.com/insomniacslk/dhcp v0.0.0-20210310193751-cfd4d47082c2 h1:NpTIlXznCStsY88jU+Gh1Dy5dt/jYV4z4uU8h2TUOt4= -github.com/insomniacslk/dhcp v0.0.0-20210310193751-cfd4d47082c2/go.mod h1:TKl4jN3Voofo4UJIicyNhWGp/nlQqQkFxmwIFTvBkKI= github.com/jellevandenhooff/dkim v0.0.0-20150330215556-f50fe3d243e1/go.mod h1:E0B/fFc00Y+Rasa88328GlI/XbtyysCtTHZS8h7IrBU= -github.com/jessevdk/go-flags v1.4.0 h1:4IU2WS7AumrZ/40jfhf4QVDMsQwqA7VEHozFRrGARJA= github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= github.com/joho/godotenv v1.3.0 h1:Zjp+RcGpHhGlrMbJzXTrZZPrWj+1vfm90La1wgB6Bhc= github.com/joho/godotenv v1.3.0/go.mod h1:7hK45KPybAkOC6peb+G5yklZfMxEjkZhHbwpqxOKXbg= @@ -655,7 +649,6 @@ google.golang.org/genproto v0.0.0-20190801165951-fa694d86fc64/go.mod h1:DMBHOl98 google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= google.golang.org/genproto v0.0.0-20190911173649-1774047e7e51/go.mod h1:IbNlFCBrqXvoKpeg0TB2l7cyZUmoaFKYIwrEpbDKLA8= google.golang.org/genproto v0.0.0-20191108220845-16a3f7862a1a/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= -google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013 h1:+kGHl1aib/qcwaRi1CbqBZ1rk19r85MNUf8HaBghugY= google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= google.golang.org/grpc v1.14.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw= google.golang.org/grpc v1.16.0/go.mod h1:0JHn/cJsOMiMfNA9+DeHDlAU7KAAB5GDlYFpa9MZMio= @@ -665,7 +658,6 @@ google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiq google.golang.org/grpc v1.21.0/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= -google.golang.org/grpc v1.27.0 h1:rRYRFMVgRv6E0D70Skyfsr28tDXIuuPZyWGMPdMcnXg= google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= diff --git a/internal/dhcpd/checkother.go b/internal/dhcpd/checkother.go index 11eaa685..c4369e35 100644 --- a/internal/dhcpd/checkother.go +++ b/internal/dhcpd/checkother.go @@ -10,9 +10,9 @@ import ( "runtime" "time" - "github.com/AdguardTeam/AdGuardHome/internal/dhcpd/nclient4" "github.com/AdguardTeam/golibs/log" "github.com/insomniacslk/dhcp/dhcpv4" + "github.com/insomniacslk/dhcp/dhcpv4/nclient4" "github.com/insomniacslk/dhcp/dhcpv6" "github.com/insomniacslk/dhcp/dhcpv6/nclient6" "github.com/insomniacslk/dhcp/iana" diff --git a/internal/dhcpd/nclient4/client.go b/internal/dhcpd/nclient4/client.go deleted file mode 100644 index 127a73c3..00000000 --- a/internal/dhcpd/nclient4/client.go +++ /dev/null @@ -1,554 +0,0 @@ -// Copyright 2018 the u-root Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build darwin dragonfly freebsd linux netbsd openbsd solaris -// +build go1.12 - -// Package nclient4 is a small, minimum-functionality client for DHCPv4. -// -// It only supports the 4-way DHCPv4 Discover-Offer-Request-Ack handshake as -// well as the Request-Ack renewal process. -// Originally from here: github.com/insomniacslk/dhcp/dhcpv4/nclient4 -// with the difference that this package can be built on UNIX (not just Linux), -// because github.com/mdlayher/raw package supports it. -package nclient4 - -import ( - "bytes" - "context" - "errors" - "fmt" - "net" - "sync" - "sync/atomic" - "time" - - "github.com/insomniacslk/dhcp/dhcpv4" -) - -const ( - defaultBufferCap = 5 - - // DefaultTimeout is the default value for read-timeout if option WithTimeout is not set - DefaultTimeout = 5 * time.Second - - // DefaultRetries is amount of retries will be done if no answer was received within read-timeout amount of time - DefaultRetries = 3 - - // MaxMessageSize is the value to be used for DHCP option "MaxMessageSize". - MaxMessageSize = 1500 - - // ClientPort is the port that DHCP clients listen on. - ClientPort = 68 - - // ServerPort is the port that DHCP servers and relay agents listen on. - ServerPort = 67 -) - -// DefaultServers is the address of all link-local DHCP servers and -// relay agents. -var DefaultServers = &net.UDPAddr{ - IP: net.IPv4bcast, - Port: ServerPort, -} - -var ( - // ErrNoResponse is returned when no response packet is received. - ErrNoResponse = errors.New("no matching response packet received") - - // ErrNoConn is returned when NewWithConn is called with nil-value as conn. - ErrNoConn = errors.New("conn is nil") - - // ErrNoIfaceHWAddr is returned when NewWithConn is called with nil-value as ifaceHWAddr - ErrNoIfaceHWAddr = errors.New("ifaceHWAddr is nil") -) - -// pendingCh is a channel associated with a pending TransactionID. -type pendingCh struct { - // SendAndRead closes done to indicate that it wishes for no more - // messages for this particular XID. - done <-chan struct{} - - // ch is used by the receive loop to distribute DHCP messages. - ch chan<- *dhcpv4.DHCPv4 -} - -// Logger is a handler which will be used to output logging messages -type Logger interface { - // PrintMessage print _all_ DHCP messages - PrintMessage(prefix string, message *dhcpv4.DHCPv4) - - // Printf is use to print the rest debugging information - Printf(format string, v ...interface{}) -} - -// EmptyLogger prints nothing -type EmptyLogger struct{} - -// Printf is just a dummy function that does nothing -func (e EmptyLogger) Printf(format string, v ...interface{}) {} - -// PrintMessage is just a dummy function that does nothing -func (e EmptyLogger) PrintMessage(prefix string, message *dhcpv4.DHCPv4) {} - -// Printfer is used for actual output of the logger. For example *log.Logger is a Printfer. -type Printfer interface { - // Printf is the function for logging output. Arguments are handled in the manner of fmt.Printf. - Printf(format string, v ...interface{}) -} - -// ShortSummaryLogger is a wrapper for Printfer to implement interface Logger. -// DHCP messages are printed in the short format. -type ShortSummaryLogger struct { - // Printfer is used for actual output of the logger - Printfer -} - -// Printf prints a log message as-is via predefined Printfer -func (s ShortSummaryLogger) Printf(format string, v ...interface{}) { - s.Printfer.Printf(format, v...) -} - -// PrintMessage prints a DHCP message in the short format via predefined Printfer -func (s ShortSummaryLogger) PrintMessage(prefix string, message *dhcpv4.DHCPv4) { - s.Printf("%s: %s", prefix, message) -} - -// DebugLogger is a wrapper for Printfer to implement interface Logger. -// DHCP messages are printed in the long format. -type DebugLogger struct { - // Printfer is used for actual output of the logger - Printfer -} - -// Printf prints a log message as-is via predefined Printfer -func (d DebugLogger) Printf(format string, v ...interface{}) { - d.Printfer.Printf(format, v...) -} - -// PrintMessage prints a DHCP message in the long format via predefined Printfer -func (d DebugLogger) PrintMessage(prefix string, message *dhcpv4.DHCPv4) { - d.Printf("%s: %s", prefix, message.Summary()) -} - -// Client is an IPv4 DHCP client. -type Client struct { - ifaceHWAddr net.HardwareAddr - conn net.PacketConn - timeout time.Duration - retry int - logger Logger - - // bufferCap is the channel capacity for each TransactionID. - bufferCap int - - // serverAddr is the UDP address to send all packets to. - // - // This may be an actual broadcast address, or a unicast address. - serverAddr *net.UDPAddr - - // closed is an atomic bool set to 1 when done is closed. - closed uint32 - - // done is closed to unblock the receive loop. - done chan struct{} - - // wg protects any spawned goroutines, namely the receiveLoop. - wg sync.WaitGroup - - pendingMu sync.Mutex - // pending stores the distribution channels for each pending - // TransactionID. receiveLoop uses this map to determine which channel - // to send a new DHCP message to. - pending map[dhcpv4.TransactionID]*pendingCh -} - -// New returns a client usable with an unconfigured interface. -func New(iface string, opts ...ClientOpt) (*Client, error) { - return new(iface, nil, nil, opts...) -} - -// NewWithConn creates a new DHCP client that sends and receives packets on the -// given interface. -func NewWithConn(conn net.PacketConn, ifaceHWAddr net.HardwareAddr, opts ...ClientOpt) (*Client, error) { - return new(``, conn, ifaceHWAddr, opts...) -} - -func new(iface string, conn net.PacketConn, ifaceHWAddr net.HardwareAddr, opts ...ClientOpt) (*Client, error) { - c := &Client{ - ifaceHWAddr: ifaceHWAddr, - timeout: DefaultTimeout, - retry: DefaultRetries, - serverAddr: DefaultServers, - bufferCap: defaultBufferCap, - conn: conn, - logger: EmptyLogger{}, - - done: make(chan struct{}), - pending: make(map[dhcpv4.TransactionID]*pendingCh), - } - - for _, opt := range opts { - err := opt(c) - if err != nil { - return nil, fmt.Errorf("unable to apply option: %w", err) - } - } - - if c.ifaceHWAddr == nil { - if iface == `` { - return nil, ErrNoIfaceHWAddr - } - - i, err := net.InterfaceByName(iface) - if err != nil { - return nil, fmt.Errorf("unable to get interface information: %w", err) - } - - c.ifaceHWAddr = i.HardwareAddr - } - - if c.conn == nil { - var err error - if iface == `` { - return nil, ErrNoConn - } - c.conn, err = NewRawUDPConn(iface, ClientPort) // broadcast - if err != nil { - return nil, fmt.Errorf("unable to open a broadcasting socket: %w", err) - } - } - c.wg.Add(1) - go c.receiveLoop() - return c, nil -} - -// Close closes the underlying connection. -func (c *Client) Close() error { - // Make sure not to close done twice. - if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) { - return nil - } - - err := c.conn.Close() - - // Closing c.done sets off a chain reaction: - // - // Any SendAndRead unblocks trying to receive more messages, which - // means rem() gets called. - // - // rem() should be unblocking receiveLoop if it is blocked. - // - // receiveLoop should then exit gracefully. - close(c.done) - - // Wait for receiveLoop to stop. - c.wg.Wait() - - return err -} - -func (c *Client) isClosed() bool { - return atomic.LoadUint32(&c.closed) != 0 -} - -func (c *Client) receiveLoop() { - defer c.wg.Done() - for { - // TODO: Clients can send a "max packet size" option in their - // packets, IIRC. Choose a reasonable size and set it. - b := make([]byte, MaxMessageSize) - n, _, err := c.conn.ReadFrom(b) - if err != nil { - if !c.isClosed() { - c.logger.Printf("error reading from UDP connection: %v", err) - } - return - } - - msg, err := dhcpv4.FromBytes(b[:n]) - if err != nil { - // Not a valid DHCP packet; keep listening. - continue - } - - if msg.OpCode != dhcpv4.OpcodeBootReply { - // Not a response message. - continue - } - - // This is a somewhat non-standard check, by the looks - // of RFC 2131. It should work as long as the DHCP - // server is spec-compliant for the HWAddr field. - if c.ifaceHWAddr != nil && !bytes.Equal(c.ifaceHWAddr, msg.ClientHWAddr) { - // Not for us. - continue - } - - c.pendingMu.Lock() - p, ok := c.pending[msg.TransactionID] - if ok { - select { - case <-p.done: - close(p.ch) - delete(c.pending, msg.TransactionID) - - // This send may block. - case p.ch <- msg: - } - } - c.pendingMu.Unlock() - } -} - -// ClientOpt is a function that configures the Client. -type ClientOpt func(c *Client) error - -// WithTimeout configures the retransmission timeout. -// -// Default is 5 seconds. -func WithTimeout(d time.Duration) ClientOpt { - return func(c *Client) (err error) { - c.timeout = d - return - } -} - -// WithLogger set the logger (see interface Logger). -func WithLogger(newLogger Logger) ClientOpt { - return func(c *Client) (err error) { - c.logger = newLogger - return - } -} - -// WithUnicast forces client to send messages as unicast frames. -// By default client sends messages as broadcast frames even if server address is defined. -// -// srcAddr is both: -// * The source address of outgoing frames. -// * The address to be listened for incoming frames. -func WithUnicast(srcAddr *net.UDPAddr) ClientOpt { - return func(c *Client) (err error) { - if srcAddr == nil { - srcAddr = &net.UDPAddr{Port: ServerPort} - } - c.conn, err = net.ListenUDP("udp4", srcAddr) - if err != nil { - err = fmt.Errorf("unable to start listening UDP port: %w", err) - } - return - } -} - -// WithHWAddr tells to the Client to receive messages destinated to selected -// hardware address -func WithHWAddr(hwAddr net.HardwareAddr) ClientOpt { - return func(c *Client) (err error) { - c.ifaceHWAddr = hwAddr - return - } -} - -// WithRetry configures the number of retransmissions to attempt. -// -// Default is 3. -func WithRetry(r int) ClientOpt { - return func(c *Client) (err error) { - c.retry = r - return - } -} - -// WithServerAddr configures the address to send messages to. -func WithServerAddr(n *net.UDPAddr) ClientOpt { - return func(c *Client) (err error) { - c.serverAddr = n - return - } -} - -// Matcher matches DHCP packets. -type Matcher func(*dhcpv4.DHCPv4) bool - -// IsMessageType returns a matcher that checks for the message type. -// -// If t is MessageTypeNone, all packets are matched. -func IsMessageType(t dhcpv4.MessageType) Matcher { - return func(p *dhcpv4.DHCPv4) bool { - return p.MessageType() == t || t == dhcpv4.MessageTypeNone - } -} - -// DiscoverOffer sends a DHCPDiscover message and returns the first valid offer -// received. -func (c *Client) DiscoverOffer(ctx context.Context, modifiers ...dhcpv4.Modifier) (offer *dhcpv4.DHCPv4, err error) { - // RFC 2131, Section 4.4.1, Table 5 details what a DISCOVER packet should - // contain. - discover, err := dhcpv4.NewDiscovery(c.ifaceHWAddr, dhcpv4.PrependModifiers(modifiers, - dhcpv4.WithOption(dhcpv4.OptMaxMessageSize(MaxMessageSize)))...) - if err != nil { - err = fmt.Errorf("unable to create a discovery request: %w", err) - return - } - - offer, err = c.SendAndRead(ctx, c.serverAddr, discover, IsMessageType(dhcpv4.MessageTypeOffer)) - if err != nil { - err = fmt.Errorf("got an error while the discovery request: %w", err) - return - } - - return -} - -// Request completes the 4-way Discover-Offer-Request-Ack handshake. -// -// Note that modifiers will be applied *both* to Discover and Request packets. -func (c *Client) Request(ctx context.Context, modifiers ...dhcpv4.Modifier) (offer, ack *dhcpv4.DHCPv4, err error) { - offer, err = c.DiscoverOffer(ctx, modifiers...) - if err != nil { - err = fmt.Errorf("unable to receive an offer: %w", err) - return - } - - // TODO(chrisko): should this be unicast to the server? - request, err := dhcpv4.NewRequestFromOffer(offer, dhcpv4.PrependModifiers(modifiers, - dhcpv4.WithOption(dhcpv4.OptMaxMessageSize(MaxMessageSize)))...) - if err != nil { - err = fmt.Errorf("unable to create a request: %w", err) - return - } - - ack, err = c.SendAndRead(ctx, c.serverAddr, request, nil) - if err != nil { - err = fmt.Errorf("got an error while processing the request: %w", err) - return - } - - return -} - -// ErrTransactionIDInUse is returned if there were an attempt to send a message -// with the same TransactionID as we are already waiting an answer for. -type ErrTransactionIDInUse struct { - // TransactionID is the transaction ID of the message which the error is related to. - TransactionID dhcpv4.TransactionID -} - -// Error is just the method to comply interface "error". -func (err *ErrTransactionIDInUse) Error() string { - return fmt.Sprintf("transaction ID %s already in use", err.TransactionID) -} - -// send sends p to destination and returns a response channel. -// -// Responses will be matched by transaction ID and ClientHWAddr. -// -// The returned lambda function must be called after all desired responses have -// been received in order to return the Transaction ID to the usable pool. -func (c *Client) send(dest *net.UDPAddr, msg *dhcpv4.DHCPv4) (resp <-chan *dhcpv4.DHCPv4, cancel func(), err error) { - c.pendingMu.Lock() - if _, ok := c.pending[msg.TransactionID]; ok { - c.pendingMu.Unlock() - return nil, nil, &ErrTransactionIDInUse{msg.TransactionID} - } - - ch := make(chan *dhcpv4.DHCPv4, c.bufferCap) - done := make(chan struct{}) - c.pending[msg.TransactionID] = &pendingCh{done: done, ch: ch} - c.pendingMu.Unlock() - - cancel = func() { - // Why can't we just close ch here? - // - // Because receiveLoop may potentially be blocked trying to - // send on ch. We gotta unblock it first, and then we can take - // the lock and remove the XID from the pending transaction - // map. - close(done) - - c.pendingMu.Lock() - if p, ok := c.pending[msg.TransactionID]; ok { - close(p.ch) - delete(c.pending, msg.TransactionID) - } - c.pendingMu.Unlock() - } - - if _, err = c.conn.WriteTo(msg.ToBytes(), dest); err != nil { - cancel() - return nil, nil, fmt.Errorf("error writing packet to connection: %w", err) - } - return ch, cancel, nil -} - -// This error should never be visible to users. -// It is used only to increase the timeout in retryFn. -var errDeadlineExceeded = errors.New("INTERNAL ERROR: deadline exceeded") - -// SendAndRead sends a packet p to a destination dest and waits for the first -// response matching `match` as well as its Transaction ID and ClientHWAddr. -// -// If match is nil, the first packet matching the Transaction ID and -// ClientHWAddr is returned. -func (c *Client) SendAndRead(ctx context.Context, dest *net.UDPAddr, p *dhcpv4.DHCPv4, match Matcher) (*dhcpv4.DHCPv4, error) { - var response *dhcpv4.DHCPv4 - err := c.retryFn(func(timeout time.Duration) error { - ch, rem, err := c.send(dest, p) - if err != nil { - return err - } - c.logger.PrintMessage("sent message", p) - defer rem() - - for { - select { - case <-c.done: - return ErrNoResponse - - case <-time.After(timeout): - return errDeadlineExceeded - - case <-ctx.Done(): - return ctx.Err() - - case packet := <-ch: - if match == nil || match(packet) { - c.logger.PrintMessage("received message", packet) - response = packet - return nil - } - } - } - }) - if err == errDeadlineExceeded { - return nil, ErrNoResponse - } - if err != nil { - return nil, err - } - return response, nil -} - -func (c *Client) retryFn(fn func(timeout time.Duration) error) error { - timeout := c.timeout - - // Each retry takes the amount of timeout at worst. - for i := 0; i < c.retry || c.retry < 0; i++ { // TODO: why is this called "retry" if this is "tries" ("retries"+1)? - switch err := fn(timeout); err { - case nil: - // Got it! - return nil - - case errDeadlineExceeded: - // Double timeout, then retry. - timeout *= 2 - - default: - return err - } - } - - return errDeadlineExceeded -} diff --git a/internal/dhcpd/nclient4/client_test.go b/internal/dhcpd/nclient4/client_test.go deleted file mode 100644 index 0eb658ac..00000000 --- a/internal/dhcpd/nclient4/client_test.go +++ /dev/null @@ -1,345 +0,0 @@ -// Copyright 2018 the u-root Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build linux -// github.com/hugelgupf/socketpair is Linux-only -// +build go1.12 - -package nclient4 - -import ( - "bytes" - "context" - "fmt" - "net" - "sync" - "testing" - "time" - - "github.com/AdguardTeam/AdGuardHome/internal/aghtest" - "github.com/hugelgupf/socketpair" - "github.com/insomniacslk/dhcp/dhcpv4" - "github.com/insomniacslk/dhcp/dhcpv4/server4" -) - -func TestMain(m *testing.M) { - aghtest.DiscardLogOutput(m) -} - -type handler struct { - mu sync.Mutex - received []*dhcpv4.DHCPv4 - - // Each received packet can have more than one response (in theory, - // from different servers sending different Advertise, for example). - responses [][]*dhcpv4.DHCPv4 -} - -func (h *handler) handle(conn net.PacketConn, peer net.Addr, m *dhcpv4.DHCPv4) { - h.mu.Lock() - defer h.mu.Unlock() - - h.received = append(h.received, m) - - if len(h.responses) > 0 { - for _, resp := range h.responses[0] { - _, _ = conn.WriteTo(resp.ToBytes(), peer) - } - h.responses = h.responses[1:] - } -} - -func serveAndClient(ctx context.Context, responses [][]*dhcpv4.DHCPv4, opts ...ClientOpt) (*Client, net.PacketConn) { - // Fake PacketConn connection. - clientRawConn, serverRawConn, err := socketpair.PacketSocketPair() - if err != nil { - panic(err) - } - - clientConn := NewBroadcastUDPConn(clientRawConn, &net.UDPAddr{Port: ClientPort}) - serverConn := NewBroadcastUDPConn(serverRawConn, &net.UDPAddr{Port: ServerPort}) - - o := []ClientOpt{WithRetry(1), WithTimeout(2 * time.Second)} - o = append(o, opts...) - mc, err := NewWithConn(clientConn, net.HardwareAddr{0xa, 0xb, 0xc, 0xd, 0xe, 0xf}, o...) - if err != nil { - panic(err) - } - - h := &handler{responses: responses} - s, err := server4.NewServer("", nil, h.handle, server4.WithConn(serverConn)) - if err != nil { - panic(err) - } - go func() { - _ = s.Serve() - }() - - return mc, serverConn -} - -func ComparePacket(got, want *dhcpv4.DHCPv4) error { - if got == nil && got == want { - return nil - } - if (want == nil || got == nil) && (got != want) { - return fmt.Errorf("packet got %v, want %v", got, want) - } - if !bytes.Equal(got.ToBytes(), want.ToBytes()) { - return fmt.Errorf("packet got %v, want %v", got, want) - } - return nil -} - -func pktsExpected(got, want []*dhcpv4.DHCPv4) error { - if len(got) != len(want) { - return fmt.Errorf("got %d packets, want %d packets", len(got), len(want)) - } - - for i := range got { - if err := ComparePacket(got[i], want[i]); err != nil { - return err - } - } - return nil -} - -func newPacketWeirdHWAddr(op dhcpv4.OpcodeType, xid dhcpv4.TransactionID) *dhcpv4.DHCPv4 { - p, err := dhcpv4.New() - if err != nil { - panic(fmt.Sprintf("newpacket: %v", err)) - } - p.OpCode = op - p.TransactionID = xid - p.ClientHWAddr = net.HardwareAddr{0xa, 0xb, 0xc, 0xd, 0xe, 0xf, 1, 2, 3, 4, 5, 6} - return p -} - -func newPacket(op dhcpv4.OpcodeType, xid dhcpv4.TransactionID) *dhcpv4.DHCPv4 { - p, err := dhcpv4.New() - if err != nil { - panic(fmt.Sprintf("newpacket: %v", err)) - } - p.OpCode = op - p.TransactionID = xid - p.ClientHWAddr = net.HardwareAddr{0xa, 0xb, 0xc, 0xd, 0xe, 0xf} - return p -} - -func withBufferCap(n int) ClientOpt { - return func(c *Client) (err error) { - c.bufferCap = n - return - } -} - -func TestSendAndRead(t *testing.T) { - for _, tt := range []struct { - desc string - send *dhcpv4.DHCPv4 - server []*dhcpv4.DHCPv4 - - // If want is nil, we assume server[0] contains what is wanted. - want *dhcpv4.DHCPv4 - wantErr error - }{ - { - desc: "two response packets", - send: newPacket(dhcpv4.OpcodeBootRequest, [4]byte{0x33, 0x33, 0x33, 0x33}), - server: []*dhcpv4.DHCPv4{ - newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}), - newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}), - newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}), - newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}), - newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}), - }, - want: newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}), - }, - { - desc: "one response packet", - send: newPacket(dhcpv4.OpcodeBootRequest, [4]byte{0x33, 0x33, 0x33, 0x33}), - server: []*dhcpv4.DHCPv4{ - newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}), - }, - want: newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}), - }, - { - desc: "one response packet, one invalid XID, one invalid opcode, one invalid hwaddr", - send: newPacket(dhcpv4.OpcodeBootRequest, [4]byte{0x33, 0x33, 0x33, 0x33}), - server: []*dhcpv4.DHCPv4{ - newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x77, 0x33, 0x33, 0x33}), - newPacket(dhcpv4.OpcodeBootRequest, [4]byte{0x33, 0x33, 0x33, 0x33}), - newPacketWeirdHWAddr(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}), - newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}), - }, - want: newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}), - }, - { - desc: "discard wrong XID", - send: newPacket(dhcpv4.OpcodeBootRequest, [4]byte{0x33, 0x33, 0x33, 0x33}), - server: []*dhcpv4.DHCPv4{ - newPacket(dhcpv4.OpcodeBootReply, [4]byte{0, 0, 0, 0}), - }, - want: nil, // Explicitly empty. - wantErr: ErrNoResponse, - }, - { - desc: "no response, timeout", - send: newPacket(dhcpv4.OpcodeBootRequest, [4]byte{0x33, 0x33, 0x33, 0x33}), - wantErr: ErrNoResponse, - }, - } { - t.Run(tt.desc, func(t *testing.T) { - // Both server and client only get 2 seconds. - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - - mc, _ := serveAndClient(ctx, [][]*dhcpv4.DHCPv4{tt.server}, - // Use an unbuffered channel to make sure we - // have no deadlocks. - withBufferCap(0)) - defer mc.Close() - - rcvd, err := mc.SendAndRead(context.Background(), DefaultServers, tt.send, nil) - if err != tt.wantErr { - t.Error(err) - } - - if err = ComparePacket(rcvd, tt.want); err != nil { - t.Errorf("got unexpected packets: %v", err) - } - }) - } -} - -func TestParallelSendAndRead(t *testing.T) { - pkt := newPacket(dhcpv4.OpcodeBootRequest, [4]byte{0x33, 0x33, 0x33, 0x33}) - - // Both the server and client only get 2 seconds. - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - - mc, _ := serveAndClient(ctx, [][]*dhcpv4.DHCPv4{}, - WithTimeout(10*time.Second), - // Use an unbuffered channel to make sure nothing blocks. - withBufferCap(0)) - defer mc.Close() - - var wg sync.WaitGroup - - wg.Add(1) - go func() { - defer wg.Done() - if _, err := mc.SendAndRead(context.Background(), DefaultServers, pkt, nil); err != ErrNoResponse { - t.Errorf("SendAndRead(%v) = %v, want %v", pkt, err, ErrNoResponse) - } - }() - - wg.Add(1) - go func() { - defer wg.Done() - - time.Sleep(4 * time.Second) - - if err := mc.Close(); err != nil { - t.Errorf("closing failed: %v", err) - } - }() - - wg.Wait() -} - -func TestReuseXID(t *testing.T) { - pkt := newPacket(dhcpv4.OpcodeBootRequest, [4]byte{0x33, 0x33, 0x33, 0x33}) - - // Both the server and client only get 2 seconds. - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - - mc, _ := serveAndClient(ctx, [][]*dhcpv4.DHCPv4{}) - defer mc.Close() - - if _, err := mc.SendAndRead(context.Background(), DefaultServers, pkt, nil); err != ErrNoResponse { - t.Errorf("SendAndRead(%v) = %v, want %v", pkt, err, ErrNoResponse) - } - - if _, err := mc.SendAndRead(context.Background(), DefaultServers, pkt, nil); err != ErrNoResponse { - t.Errorf("SendAndRead(%v) = %v, want %v", pkt, err, ErrNoResponse) - } -} - -func TestSimpleSendAndReadDiscardGarbage(t *testing.T) { - pkt := newPacket(dhcpv4.OpcodeBootRequest, [4]byte{0x33, 0x33, 0x33, 0x33}) - - responses := newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}) - - // Both the server and client only get 2 seconds. - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - - mc, udpConn := serveAndClient(ctx, [][]*dhcpv4.DHCPv4{{responses}}) - defer mc.Close() - - // Too short for valid DHCPv4 packet. - _, _ = udpConn.WriteTo([]byte{0x01}, nil) - _, _ = udpConn.WriteTo([]byte{0x01, 0x2}, nil) - - rcvd, err := mc.SendAndRead(ctx, DefaultServers, pkt, nil) - if err != nil { - t.Errorf("SendAndRead(%v) = %v, want nil", pkt, err) - } - - if err = ComparePacket(rcvd, responses); err != nil { - t.Errorf("got unexpected packets: %v", err) - } -} - -func TestMultipleSendAndRead(t *testing.T) { - for _, tt := range []struct { - desc string - send []*dhcpv4.DHCPv4 - server [][]*dhcpv4.DHCPv4 - wantErr []error - }{ - { - desc: "two requests, two responses", - send: []*dhcpv4.DHCPv4{ - newPacket(dhcpv4.OpcodeBootRequest, [4]byte{0x33, 0x33, 0x33, 0x33}), - newPacket(dhcpv4.OpcodeBootRequest, [4]byte{0x44, 0x44, 0x44, 0x44}), - }, - server: [][]*dhcpv4.DHCPv4{ - { // Response for first packet. - newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x33, 0x33, 0x33, 0x33}), - }, - { // Response for second packet. - newPacket(dhcpv4.OpcodeBootReply, [4]byte{0x44, 0x44, 0x44, 0x44}), - }, - }, - wantErr: []error{ - nil, - nil, - }, - }, - } { - // Both server and client only get 2 seconds. - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - - mc, _ := serveAndClient(ctx, tt.server) - defer mc.Close() - - for i, send := range tt.send { - ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - rcvd, err := mc.SendAndRead(ctx, DefaultServers, send, nil) - - if wantErr := tt.wantErr[i]; err != wantErr { - t.Errorf("SendAndReadOne(%v): got %v, want %v", send, err, wantErr) - } - if err = pktsExpected([]*dhcpv4.DHCPv4{rcvd}, tt.server[i]); err != nil { - t.Errorf("got unexpected packets: %v", err) - } - } - } -} diff --git a/internal/dhcpd/nclient4/conn_unix.go b/internal/dhcpd/nclient4/conn_unix.go deleted file mode 100644 index 39009d69..00000000 --- a/internal/dhcpd/nclient4/conn_unix.go +++ /dev/null @@ -1,140 +0,0 @@ -// Copyright 2018 the u-root Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// +build darwin dragonfly freebsd linux netbsd openbsd solaris -// +build go1.12 - -package nclient4 - -import ( - "errors" - "io" - "net" - - "github.com/mdlayher/ethernet" - "github.com/mdlayher/raw" - "github.com/u-root/u-root/pkg/uio" -) - -// BroadcastMac is the broadcast MAC address. -// -// Any UDP packet sent to this address is broadcast on the subnet. -var BroadcastMac = net.HardwareAddr([]byte{255, 255, 255, 255, 255, 255}) - -// ErrUDPAddrIsRequired is an error used when a passed argument is not of type "*net.UDPAddr". -var ErrUDPAddrIsRequired = errors.New("must supply UDPAddr") - -// NewRawUDPConn returns a UDP connection bound to the interface and port -// given based on a raw packet socket. All packets are broadcasted. -// -// The interface can be completely unconfigured. -func NewRawUDPConn(iface string, port int) (net.PacketConn, error) { - ifc, err := net.InterfaceByName(iface) - if err != nil { - return nil, err - } - rawConn, err := raw.ListenPacket(ifc, uint16(ethernet.EtherTypeIPv4), &raw.Config{LinuxSockDGRAM: true}) - if err != nil { - return nil, err - } - return NewBroadcastUDPConn(rawConn, &net.UDPAddr{Port: port}), nil -} - -// BroadcastRawUDPConn uses a raw socket to send UDP packets to the broadcast -// MAC address. -type BroadcastRawUDPConn struct { - // PacketConn is a raw DGRAM socket. - net.PacketConn - - // boundAddr is the address this RawUDPConn is "bound" to. - // - // Calls to ReadFrom will only return packets destined to this address. - boundAddr *net.UDPAddr -} - -// NewBroadcastUDPConn returns a PacketConn that marshals and unmarshals UDP -// packets, sending them to the broadcast MAC at on rawPacketConn. -// -// Calls to ReadFrom will only return packets destined to boundAddr. -func NewBroadcastUDPConn(rawPacketConn net.PacketConn, boundAddr *net.UDPAddr) net.PacketConn { - return &BroadcastRawUDPConn{ - PacketConn: rawPacketConn, - boundAddr: boundAddr, - } -} - -func udpMatch(addr, bound *net.UDPAddr) bool { - if bound == nil { - return true - } - if bound.IP != nil && !bound.IP.Equal(addr.IP) { - return false - } - return bound.Port == addr.Port -} - -// ReadFrom implements net.PacketConn.ReadFrom. -// -// ReadFrom reads raw IP packets and will try to match them against -// upc.boundAddr. Any matching packets are returned via the given buffer. -func (upc *BroadcastRawUDPConn) ReadFrom(b []byte) (int, net.Addr, error) { - ipHdrMaxLen := IPv4MaximumHeaderSize - udpHdrLen := UDPMinimumSize - - for { - pkt := make([]byte, ipHdrMaxLen+udpHdrLen+len(b)) - n, _, err := upc.PacketConn.ReadFrom(pkt) - if err != nil { - return 0, nil, err - } - if n == 0 { - return 0, nil, io.EOF - } - pkt = pkt[:n] - buf := uio.NewBigEndianBuffer(pkt) - - // To read the header length, access data directly. - ipHdr := IPv4(buf.Data()) - ipHdr = IPv4(buf.Consume(int(ipHdr.HeaderLength()))) - - if ipHdr.TransportProtocol() != UDPProtocolNumber { - continue - } - udpHdr := UDP(buf.Consume(udpHdrLen)) - - addr := &net.UDPAddr{ - IP: ipHdr.DestinationAddress(), - Port: int(udpHdr.DestinationPort()), - } - if !udpMatch(addr, upc.boundAddr) { - continue - } - srcAddr := &net.UDPAddr{ - IP: ipHdr.SourceAddress(), - Port: int(udpHdr.SourcePort()), - } - // Extra padding after end of IP packet should be ignored, - // if not dhcp option parsing will fail. - dhcpLen := int(ipHdr.PayloadLength()) - udpHdrLen - return copy(b, buf.Consume(dhcpLen)), srcAddr, nil - } -} - -// WriteTo implements net.PacketConn.WriteTo and broadcasts all packets at the -// raw socket level. -// -// WriteTo wraps the given packet in the appropriate UDP and IP header before -// sending it on the packet conn. -func (upc *BroadcastRawUDPConn) WriteTo(b []byte, addr net.Addr) (int, error) { - udpAddr, ok := addr.(*net.UDPAddr) - if !ok { - return 0, ErrUDPAddrIsRequired - } - - // Using the boundAddr is not quite right here, but it works. - packet := udp4pkt(b, udpAddr, upc.boundAddr) - - // Broadcasting is not always right, but hell, what the ARP do I know. - return upc.PacketConn.WriteTo(packet, &raw.Addr{HardwareAddr: BroadcastMac}) -} diff --git a/internal/dhcpd/nclient4/ipv4.go b/internal/dhcpd/nclient4/ipv4.go deleted file mode 100644 index 50a2d684..00000000 --- a/internal/dhcpd/nclient4/ipv4.go +++ /dev/null @@ -1,377 +0,0 @@ -// Copyright 2018 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -// This file contains code taken from gVisor. - -// +build darwin dragonfly freebsd linux netbsd openbsd solaris -// +build go1.12 - -package nclient4 - -import ( - "encoding/binary" - "net" - - "github.com/u-root/u-root/pkg/uio" -) - -const ( - versIHL = 0 - tos = 1 - totalLen = 2 - id = 4 - flagsFO = 6 - ttl = 8 - protocol = 9 - checksum = 10 - srcAddr = 12 - dstAddr = 16 -) - -// TransportProtocolNumber is the number of a transport protocol. -type TransportProtocolNumber uint32 - -// IPv4Fields contains the fields of an IPv4 packet. It is used to describe the -// fields of a packet that needs to be encoded. -type IPv4Fields struct { - // IHL is the "internet header length" field of an IPv4 packet. - IHL uint8 - - // TOS is the "type of service" field of an IPv4 packet. - TOS uint8 - - // TotalLength is the "total length" field of an IPv4 packet. - TotalLength uint16 - - // ID is the "identification" field of an IPv4 packet. - ID uint16 - - // Flags is the "flags" field of an IPv4 packet. - Flags uint8 - - // FragmentOffset is the "fragment offset" field of an IPv4 packet. - FragmentOffset uint16 - - // TTL is the "time to live" field of an IPv4 packet. - TTL uint8 - - // Protocol is the "protocol" field of an IPv4 packet. - Protocol uint8 - - // Checksum is the "checksum" field of an IPv4 packet. - Checksum uint16 - - // SrcAddr is the "source ip address" of an IPv4 packet. - SrcAddr net.IP - - // DstAddr is the "destination ip address" of an IPv4 packet. - DstAddr net.IP -} - -// IPv4 represents an ipv4 header stored in a byte array. -// Most of the methods of IPv4 access to the underlying slice without -// checking the boundaries and could panic because of 'index out of range'. -// Always call IsValid() to validate an instance of IPv4 before using other methods. -type IPv4 []byte - -const ( - // IPv4MinimumSize is the minimum size of a valid IPv4 packet. - IPv4MinimumSize = 20 - - // IPv4MaximumHeaderSize is the maximum size of an IPv4 header. Given - // that there are only 4 bits to represents the header length in 32-bit - // units, the header cannot exceed 15*4 = 60 bytes. - IPv4MaximumHeaderSize = 60 - - // IPv4AddressSize is the size, in bytes, of an IPv4 address. - IPv4AddressSize = 4 - - // IPv4Version is the version of the ipv4 protocol. - IPv4Version = 4 -) - -var ( - // IPv4Broadcast is the broadcast address of the IPv4 protocol. - IPv4Broadcast = net.IP{0xff, 0xff, 0xff, 0xff} - - // IPv4Any is the non-routable IPv4 "any" meta address. - IPv4Any = net.IP{0, 0, 0, 0} -) - -// Flags that may be set in an IPv4 packet. -const ( - IPv4FlagMoreFragments = 1 << iota - IPv4FlagDontFragment -) - -// HeaderLength returns the value of the "header length" field of the ipv4 -// header. -func (b IPv4) HeaderLength() uint8 { - return (b[versIHL] & 0xf) * 4 -} - -// Protocol returns the value of the protocol field of the ipv4 header. -func (b IPv4) Protocol() uint8 { - return b[protocol] -} - -// SourceAddress returns the "source address" field of the ipv4 header. -func (b IPv4) SourceAddress() net.IP { - return net.IP(b[srcAddr : srcAddr+IPv4AddressSize]) -} - -// DestinationAddress returns the "destination address" field of the ipv4 -// header. -func (b IPv4) DestinationAddress() net.IP { - return net.IP(b[dstAddr : dstAddr+IPv4AddressSize]) -} - -// TransportProtocol implements Network.TransportProtocol. -func (b IPv4) TransportProtocol() TransportProtocolNumber { - return TransportProtocolNumber(b.Protocol()) -} - -// Payload implements Network.Payload. -func (b IPv4) Payload() []byte { - return b[b.HeaderLength():][:b.PayloadLength()] -} - -// PayloadLength returns the length of the payload portion of the ipv4 packet. -func (b IPv4) PayloadLength() uint16 { - return b.TotalLength() - uint16(b.HeaderLength()) -} - -// TotalLength returns the "total length" field of the ipv4 header. -func (b IPv4) TotalLength() uint16 { - return binary.BigEndian.Uint16(b[totalLen:]) -} - -// SetTotalLength sets the "total length" field of the ipv4 header. -func (b IPv4) SetTotalLength(totalLength uint16) { - binary.BigEndian.PutUint16(b[totalLen:], totalLength) -} - -// SetChecksum sets the checksum field of the ipv4 header. -func (b IPv4) SetChecksum(v uint16) { - binary.BigEndian.PutUint16(b[checksum:], v) -} - -// SetFlagsFragmentOffset sets the "flags" and "fragment offset" fields of the -// ipv4 header. -func (b IPv4) SetFlagsFragmentOffset(flags uint8, offset uint16) { - v := (uint16(flags) << 13) | (offset >> 3) - binary.BigEndian.PutUint16(b[flagsFO:], v) -} - -// SetSourceAddress sets the "source address" field of the ipv4 header. -func (b IPv4) SetSourceAddress(addr net.IP) { - copy(b[srcAddr:srcAddr+IPv4AddressSize], addr.To4()) -} - -// SetDestinationAddress sets the "destination address" field of the ipv4 -// header. -func (b IPv4) SetDestinationAddress(addr net.IP) { - copy(b[dstAddr:dstAddr+IPv4AddressSize], addr.To4()) -} - -// CalculateChecksum calculates the checksum of the ipv4 header. -func (b IPv4) CalculateChecksum() uint16 { - return Checksum(b[:b.HeaderLength()], 0) -} - -// Encode encodes all the fields of the ipv4 header. -func (b IPv4) Encode(i *IPv4Fields) { - b[versIHL] = (4 << 4) | ((i.IHL / 4) & 0xf) - b[tos] = i.TOS - b.SetTotalLength(i.TotalLength) - binary.BigEndian.PutUint16(b[id:], i.ID) - b.SetFlagsFragmentOffset(i.Flags, i.FragmentOffset) - b[ttl] = i.TTL - b[protocol] = i.Protocol - b.SetChecksum(i.Checksum) - copy(b[srcAddr:srcAddr+IPv4AddressSize], i.SrcAddr) - copy(b[dstAddr:dstAddr+IPv4AddressSize], i.DstAddr) -} - -const ( - udpSrcPort = 0 - udpDstPort = 2 - udpLength = 4 - udpChecksum = 6 -) - -// UDPFields contains the fields of a UDP packet. It is used to describe the -// fields of a packet that needs to be encoded. -type UDPFields struct { - // SrcPort is the "source port" field of a UDP packet. - SrcPort uint16 - - // DstPort is the "destination port" field of a UDP packet. - DstPort uint16 - - // Length is the "length" field of a UDP packet. - Length uint16 - - // Checksum is the "checksum" field of a UDP packet. - Checksum uint16 -} - -// UDP represents a UDP header stored in a byte array. -type UDP []byte - -const ( - // UDPMinimumSize is the minimum size of a valid UDP packet. - UDPMinimumSize = 8 - - // UDPProtocolNumber is UDP's transport protocol number. - UDPProtocolNumber TransportProtocolNumber = 17 -) - -// SourcePort returns the "source port" field of the udp header. -func (b UDP) SourcePort() uint16 { - return binary.BigEndian.Uint16(b[udpSrcPort:]) -} - -// DestinationPort returns the "destination port" field of the udp header. -func (b UDP) DestinationPort() uint16 { - return binary.BigEndian.Uint16(b[udpDstPort:]) -} - -// Length returns the "length" field of the udp header. -func (b UDP) Length() uint16 { - return binary.BigEndian.Uint16(b[udpLength:]) -} - -// SetSourcePort sets the "source port" field of the udp header. -func (b UDP) SetSourcePort(port uint16) { - binary.BigEndian.PutUint16(b[udpSrcPort:], port) -} - -// SetDestinationPort sets the "destination port" field of the udp header. -func (b UDP) SetDestinationPort(port uint16) { - binary.BigEndian.PutUint16(b[udpDstPort:], port) -} - -// SetChecksum sets the "checksum" field of the udp header. -func (b UDP) SetChecksum(checksum uint16) { - binary.BigEndian.PutUint16(b[udpChecksum:], checksum) -} - -// Payload returns the data contained in the UDP datagram. -func (b UDP) Payload() []byte { - return b[UDPMinimumSize:] -} - -// Checksum returns the "checksum" field of the udp header. -func (b UDP) Checksum() uint16 { - return binary.BigEndian.Uint16(b[udpChecksum:]) -} - -// CalculateChecksum calculates the checksum of the udp packet, given the total -// length of the packet and the checksum of the network-layer pseudo-header -// (excluding the total length) and the checksum of the payload. -func (b UDP) CalculateChecksum(partialChecksum, totalLen uint16) uint16 { - // Add the length portion of the checksum to the pseudo-checksum. - tmp := make([]byte, 2) - binary.BigEndian.PutUint16(tmp, totalLen) - checksum := Checksum(tmp, partialChecksum) - - // Calculate the rest of the checksum. - return Checksum(b[:UDPMinimumSize], checksum) -} - -// Encode encodes all the fields of the udp header. -func (b UDP) Encode(u *UDPFields) { - binary.BigEndian.PutUint16(b[udpSrcPort:], u.SrcPort) - binary.BigEndian.PutUint16(b[udpDstPort:], u.DstPort) - binary.BigEndian.PutUint16(b[udpLength:], u.Length) - binary.BigEndian.PutUint16(b[udpChecksum:], u.Checksum) -} - -func calculateChecksum(buf []byte, initial uint32) uint16 { - v := initial - - l := len(buf) - if l&1 != 0 { - l-- - v += uint32(buf[l]) << 8 - } - - for i := 0; i < l; i += 2 { - v += (uint32(buf[i]) << 8) + uint32(buf[i+1]) - } - - return ChecksumCombine(uint16(v), uint16(v>>16)) -} - -// Checksum calculates the checksum (as defined in RFC 1071) of the bytes in the -// given byte array. -// -// The initial checksum must have been computed on an even number of bytes. -func Checksum(buf []byte, initial uint16) uint16 { - return calculateChecksum(buf, uint32(initial)) -} - -// ChecksumCombine combines the two uint16 to form their checksum. This is done -// by adding them and the carry. -// -// Note that checksum a must have been computed on an even number of bytes. -func ChecksumCombine(a, b uint16) uint16 { - v := uint32(a) + uint32(b) - return uint16(v + v>>16) -} - -// PseudoHeaderChecksum calculates the pseudo-header checksum for the -// given destination protocol and network address, ignoring the length -// field. Pseudo-headers are needed by transport layers when calculating -// their own checksum. -func PseudoHeaderChecksum(protocol TransportProtocolNumber, srcAddr, dstAddr net.IP) uint16 { - xsum := Checksum([]byte(srcAddr), 0) - xsum = Checksum([]byte(dstAddr), xsum) - return Checksum([]byte{0, uint8(protocol)}, xsum) -} - -func udp4pkt(packet []byte, dest, src *net.UDPAddr) []byte { - ipLen := IPv4MinimumSize - udpLen := UDPMinimumSize - - h := make([]byte, 0, ipLen+udpLen+len(packet)) - hdr := uio.NewBigEndianBuffer(h) - - ipv4fields := &IPv4Fields{ - IHL: IPv4MinimumSize, - TotalLength: uint16(ipLen + udpLen + len(packet)), - TTL: 64, // Per RFC 1700's recommendation for IP time to live - Protocol: uint8(UDPProtocolNumber), - SrcAddr: src.IP.To4(), - DstAddr: dest.IP.To4(), - } - ipv4hdr := IPv4(hdr.WriteN(ipLen)) - ipv4hdr.Encode(ipv4fields) - ipv4hdr.SetChecksum(^ipv4hdr.CalculateChecksum()) - - udphdr := UDP(hdr.WriteN(udpLen)) - udphdr.Encode(&UDPFields{ - SrcPort: uint16(src.Port), - DstPort: uint16(dest.Port), - Length: uint16(udpLen + len(packet)), - }) - - xsum := Checksum(packet, PseudoHeaderChecksum( - ipv4hdr.TransportProtocol(), ipv4fields.SrcAddr, ipv4fields.DstAddr)) - udphdr.SetChecksum(^udphdr.CalculateChecksum(xsum, udphdr.Length())) - - hdr.WriteBytes(packet) - return hdr.Data() -}