From 41b7495ae337675151c762cbafafdb846021c34a Mon Sep 17 00:00:00 2001 From: Tatsushi Demachi Date: Sun, 11 Jan 2015 00:20:08 +0900 Subject: [PATCH] Add non-privileged datagram-oriented ICMP endpoints mode using UDP --- README.md | 4 +- cmd/ping/ping.go | 26 +++- fastping.go | 74 ++++++++--- fastping_test.go | 316 ++++++++++++++++++++++++----------------------- 4 files changed, 242 insertions(+), 178 deletions(-) diff --git a/README.md b/README.md index 552c12f..3aae31d 100644 --- a/README.md +++ b/README.md @@ -44,8 +44,8 @@ callback. For more detail, refer [godoc][godoc] and if you need more example, please see "cmd/ping/ping.go". ## Caution -This package now only implements ICMP ping using raw socket so the program -using this package needs to be run as root user. +This package implements ICMP ping using both raw socket and UDP. If your program +uses this package in raw socket mode, it needs to be run as a root user. ## License go-fastping is under MIT License. See the [LICENSE][license] file for details. diff --git a/cmd/ping/ping.go b/cmd/ping/ping.go index 215eb87..eddd695 100644 --- a/cmd/ping/ping.go +++ b/cmd/ping/ping.go @@ -1,14 +1,16 @@ package main import ( + "flag" "fmt" - "github.com/tatsushid/go-fastping" "net" "os" "os/signal" "strings" "syscall" "time" + + "github.com/tatsushid/go-fastping" ) type response struct { @@ -17,17 +19,31 @@ type response struct { } func main() { - if len(os.Args) != 2 { - fmt.Fprintf(os.Stderr, "Usage: %s {hostname}\n", os.Args[0]) + var useUDP bool + flag.BoolVar(&useUDP, "udp", false, "use non-privileged datagram-oriented UDP as ICMP endpoints") + flag.BoolVar(&useUDP, "u", false, "use non-privileged datagram-oriented UDP as ICMP endpoints (shorthand)") + flag.Usage = func() { + fmt.Fprintf(os.Stderr, "Usage:\n %s [options] hostname\n\nOptions:\n", os.Args[0]) + flag.PrintDefaults() + } + flag.Parse() + + hostname := flag.Arg(0) + if len(hostname) == 0 { + flag.Usage() os.Exit(1) } + p := fastping.NewPinger() + if useUDP { + p.Network("udp") + } netProto := "ip4:icmp" - if strings.Index(os.Args[1], ":") != -1 { + if strings.Index(hostname, ":") != -1 { netProto = "ip6:ipv6-icmp" } - ra, err := net.ResolveIPAddr(netProto, os.Args[1]) + ra, err := net.ResolveIPAddr(netProto, hostname) if err != nil { fmt.Println(err) os.Exit(1) diff --git a/fastping.go b/fastping.go index 6488384..f22ca0f 100644 --- a/fastping.go +++ b/fastping.go @@ -29,8 +29,9 @@ // it calls "receive" callback. After that, MaxRTT time passed, it calls // "idle" callback. If you need more example, please see "cmd/ping/ping.go". // -// This library needs to run as a superuser for sending ICMP packets so when -// you run go test, please run as a following +// This library needs to run as a superuser for sending ICMP packets when +// privileged raw ICMP endpoints is used so in such a case, to run go test +// for the package, please run like a following // // sudo go test // @@ -54,6 +55,11 @@ import ( const TimeSliceLength = 8 +var ( + ipv4Proto = map[string]string{"ip": "ip4:icmp", "udp": "udp4"} + ipv6Proto = map[string]string{"ip": "ip6:ipv6-icmp", "udp": "udp6"} +) + func byteSliceOfSize(n int) []byte { b := make([]byte, n) for i := 0; i < len(b); i++ { @@ -98,7 +104,7 @@ func ipv4Payload(b []byte) []byte { type packet struct { bytes []byte - addr *net.IPAddr + addr net.Addr } type context struct { @@ -120,6 +126,7 @@ type Pinger struct { seq int // key string is IPAddr.String() addrs map[string]*net.IPAddr + network string hasIPv4 bool hasIPv6 bool ctx *context @@ -147,6 +154,7 @@ func NewPinger() *Pinger { id: rand.Intn(0xffff), seq: rand.Intn(0xffff), addrs: make(map[string]*net.IPAddr), + network: "ip", hasIPv4: false, hasIPv6: false, Size: TimeSliceLength, @@ -157,6 +165,23 @@ func NewPinger() *Pinger { } } +// Network sets a network endpoints for ICMP ping and returns the previous +// setting. network arg should be "ip" or "udp" string or if others are +// specified, it returns an error. If this function isn't called, Pinger +// uses "ip" as default. +func (p *Pinger) Network(network string) (string, error) { + origNet := p.network + switch network { + case "ip": + fallthrough + case "udp": + p.network = network + default: + return origNet, errors.New(network + " can't be used as ICMP endpoint") + } + return origNet, nil +} + // AddIP adds an IP address to Pinger. ipaddr arg should be a string like // "192.0.2.1". func (p *Pinger) AddIP(ipaddr string) error { @@ -315,14 +340,14 @@ func (p *Pinger) run(once bool) { p.debugln("Run(): Start") var conn, conn6 *icmp.PacketConn if p.hasIPv4 { - if conn = p.listen("ip4:icmp"); conn == nil { + if conn = p.listen(ipv4Proto[p.network]); conn == nil { return } defer conn.Close() } if p.hasIPv6 { - if conn6 = p.listen("ip6:ipv6-icmp"); conn6 == nil { + if conn6 = p.listen(ipv6Proto[p.network]); conn6 == nil { return } defer conn6.Close() @@ -438,10 +463,14 @@ func (p *Pinger) sendICMP(conn, conn6 *icmp.PacketConn) (map[string]*net.IPAddr, } queue[key] = addr + var dst net.Addr = addr + if p.network == "udp" { + dst = &net.UDPAddr{IP: addr.IP, Zone: addr.Zone} + } p.debugln("sendICMP(): Invoke goroutine") wg.Add(1) - go func(conn *icmp.PacketConn, ra *net.IPAddr, b []byte) { + go func(conn *icmp.PacketConn, ra net.Addr, b []byte) { for { if _, err := conn.WriteTo(bytes, ra); err != nil { if neterr, ok := err.(*net.OpError); ok { @@ -454,7 +483,7 @@ func (p *Pinger) sendICMP(conn, conn6 *icmp.PacketConn) (map[string]*net.IPAddr, } p.debugln("sendICMP(): WriteTo End") wg.Done() - }(cn, addr, bytes) + }(cn, dst, bytes) } wg.Wait() p.debugln("sendICMP(): End") @@ -498,13 +527,8 @@ func (p *Pinger) recvICMP(conn *icmp.PacketConn, recv chan<- *packet, ctx *conte } p.debugln("recvICMP(): p.recv <- packet") - addr, ok := ra.(*net.IPAddr) - if !ok { - continue - } - select { - case recv <- &packet{bytes: bytes, addr: addr}: + case recv <- &packet{bytes: bytes, addr: ra}: case <-ctx.stop: p.debugln("recvICMP(): <-ctx.stop") wg.Done() @@ -515,7 +539,17 @@ func (p *Pinger) recvICMP(conn *icmp.PacketConn, recv chan<- *packet, ctx *conte } func (p *Pinger) procRecv(recv *packet, queue map[string]*net.IPAddr) { - addr := recv.addr.String() + var ipaddr *net.IPAddr + switch adr := recv.addr.(type) { + case *net.IPAddr: + ipaddr = adr + case *net.UDPAddr: + ipaddr = &net.IPAddr{IP: adr.IP, Zone: adr.Zone} + default: + return + } + + addr := ipaddr.String() p.mu.Lock() if _, ok := p.addrs[addr]; !ok { p.mu.Unlock() @@ -525,10 +559,14 @@ func (p *Pinger) procRecv(recv *packet, queue map[string]*net.IPAddr) { var bytes []byte var proto int - if isIPv4(recv.addr.IP) { - bytes = ipv4Payload(recv.bytes) + if isIPv4(ipaddr.IP) { + if p.network == "ip" { + bytes = ipv4Payload(recv.bytes) + } else { + bytes = recv.bytes + } proto = iana.ProtocolICMP - } else if isIPv6(recv.addr.IP) { + } else if isIPv6(ipaddr.IP) { bytes = recv.bytes proto = iana.ProtocolIPv6ICMP } else { @@ -563,7 +601,7 @@ func (p *Pinger) procRecv(recv *packet, queue map[string]*net.IPAddr) { handler := p.OnRecv p.mu.Unlock() if handler != nil { - handler(recv.addr, rtt) + handler(ipaddr, rtt) } } } diff --git a/fastping_test.go b/fastping_test.go index cef5ab9..99ca01b 100644 --- a/fastping_test.go +++ b/fastping_test.go @@ -38,177 +38,187 @@ func TestAddIP(t *testing.T) { } func TestRun(t *testing.T) { - p := NewPinger() + for _, network := range []string{"ip", "udp"} { + p := NewPinger() + p.Network(network) - if err := p.AddIP("127.0.0.1"); err != nil { - t.Fatalf("AddIP failed: %v", err) - } - - if err := p.AddIP("127.0.0.100"); err != nil { - t.Fatalf("AddIP failed: %v", err) - } - - if err := p.AddIP("::1"); err != nil { - t.Fatalf("AddIP failed: %v", err) - } - - found1, found100, foundv6 := false, false, false - called, idle := false, false - p.OnRecv = func(ip *net.IPAddr, d time.Duration) { - called = true - if ip.String() == "127.0.0.1" { - found1 = true - } else if ip.String() == "127.0.0.100" { - found100 = true - } else if ip.String() == "::1" { - foundv6 = true + if err := p.AddIP("127.0.0.1"); err != nil { + t.Fatalf("AddIP failed: %v", err) } - } - p.OnIdle = func() { - idle = true - } + if err := p.AddIP("127.0.0.100"); err != nil { + t.Fatalf("AddIP failed: %v", err) + } - err := p.Run() - if err != nil { - t.Fatalf("Pinger returns error: %v", err) - } - if !called { - t.Fatalf("Pinger didn't get any responses") - } - if !idle { - t.Fatalf("Pinger didn't call OnIdle function") - } - if !found1 { - t.Fatalf("Pinger `127.0.0.1` didn't respond") - } - if found100 { - t.Fatalf("Pinger `127.0.0.100` responded") - } - if !foundv6 { - t.Fatalf("Pinger `::1` didn't responded") + if err := p.AddIP("::1"); err != nil { + t.Fatalf("AddIP failed: %v", err) + } + + found1, found100, foundv6 := false, false, false + called, idle := false, false + p.OnRecv = func(ip *net.IPAddr, d time.Duration) { + called = true + if ip.String() == "127.0.0.1" { + found1 = true + } else if ip.String() == "127.0.0.100" { + found100 = true + } else if ip.String() == "::1" { + foundv6 = true + } + } + + p.OnIdle = func() { + idle = true + } + + err := p.Run() + if err != nil { + t.Fatalf("Pinger returns error: %v", err) + } + if !called { + t.Fatalf("Pinger didn't get any responses") + } + if !idle { + t.Fatalf("Pinger didn't call OnIdle function") + } + if !found1 { + t.Fatalf("Pinger `127.0.0.1` didn't respond") + } + if found100 { + t.Fatalf("Pinger `127.0.0.100` responded") + } + if !foundv6 { + t.Fatalf("Pinger `::1` didn't responded") + } } } func TestMultiRun(t *testing.T) { - p1 := NewPinger() - p2 := NewPinger() + for _, network := range []string{"ip", "udp"} { + p1 := NewPinger() + p1.Network(network) + p2 := NewPinger() + p2.Network(network) - if err := p1.AddIP("127.0.0.1"); err != nil { - t.Fatalf("AddIP 1 failed: %v", err) - } - - if err := p2.AddIP("127.0.0.1"); err != nil { - t.Fatalf("AddIP 2 failed: %v", err) - } - - var mu sync.Mutex - res1 := 0 - p1.OnRecv = func(*net.IPAddr, time.Duration) { - mu.Lock() - res1++ - mu.Unlock() - } - - res2 := 0 - p2.OnRecv = func(*net.IPAddr, time.Duration) { - mu.Lock() - res2++ - mu.Unlock() - } - - p1.MaxRTT, p2.MaxRTT = time.Millisecond*100, time.Millisecond*100 - - if err := p1.Run(); err != nil { - t.Fatalf("Pinger 1 returns error: %v", err) - } - if res1 == 0 { - t.Fatalf("Pinger 1 didn't get any responses") - } - if res2 > 0 { - t.Fatalf("Pinger 2 got response") - } - - res1, res2 = 0, 0 - if err := p2.Run(); err != nil { - t.Fatalf("Pinger 2 returns error: %v", err) - } - if res1 > 0 { - t.Fatalf("Pinger 1 got response") - } - if res2 == 0 { - t.Fatalf("Pinger 2 didn't get any responses") - } - - res1, res2 = 0, 0 - errch1, errch2 := make(chan error), make(chan error) - go func(ch chan error) { - err := p1.Run() - if err != nil { - ch <- err + if err := p1.AddIP("127.0.0.1"); err != nil { + t.Fatalf("AddIP 1 failed: %v", err) } - }(errch1) - go func(ch chan error) { - err := p2.Run() - if err != nil { - ch <- err + + if err := p2.AddIP("127.0.0.1"); err != nil { + t.Fatalf("AddIP 2 failed: %v", err) + } + + var mu sync.Mutex + res1 := 0 + p1.OnRecv = func(*net.IPAddr, time.Duration) { + mu.Lock() + res1++ + mu.Unlock() + } + + res2 := 0 + p2.OnRecv = func(*net.IPAddr, time.Duration) { + mu.Lock() + res2++ + mu.Unlock() + } + + p1.MaxRTT, p2.MaxRTT = time.Millisecond*100, time.Millisecond*100 + + if err := p1.Run(); err != nil { + t.Fatalf("Pinger 1 returns error: %v", err) + } + if res1 == 0 { + t.Fatalf("Pinger 1 didn't get any responses") + } + if res2 > 0 { + t.Fatalf("Pinger 2 got response") + } + + res1, res2 = 0, 0 + if err := p2.Run(); err != nil { + t.Fatalf("Pinger 2 returns error: %v", err) + } + if res1 > 0 { + t.Fatalf("Pinger 1 got response") + } + if res2 == 0 { + t.Fatalf("Pinger 2 didn't get any responses") + } + + res1, res2 = 0, 0 + errch1, errch2 := make(chan error), make(chan error) + go func(ch chan error) { + err := p1.Run() + if err != nil { + ch <- err + } + }(errch1) + go func(ch chan error) { + err := p2.Run() + if err != nil { + ch <- err + } + }(errch2) + ticker := time.NewTicker(time.Millisecond * 200) + select { + case err := <-errch1: + t.Fatalf("Pinger 1 returns error: %v", err) + case err := <-errch2: + t.Fatalf("Pinger 2 returns error: %v", err) + case <-ticker.C: + break + } + mu.Lock() + defer mu.Unlock() + if res1 != 1 { + t.Fatalf("Pinger 1 didn't get correct response") + } + if res2 != 1 { + t.Fatalf("Pinger 2 didn't get correct response") } - }(errch2) - ticker := time.NewTicker(time.Millisecond * 200) - select { - case err := <-errch1: - t.Fatalf("Pinger 1 returns error: %v", err) - case err := <-errch2: - t.Fatalf("Pinger 2 returns error: %v", err) - case <-ticker.C: - break - } - mu.Lock() - defer mu.Unlock() - if res1 != 1 { - t.Fatalf("Pinger 1 didn't get correct response") - } - if res2 != 1 { - t.Fatalf("Pinger 2 didn't get correct response") } } func TestRunLoop(t *testing.T) { - p := NewPinger() + for _, network := range []string{"ip", "udp"} { + p := NewPinger() + p.Network(network) - if err := p.AddIP("127.0.0.1"); err != nil { - t.Fatalf("AddIP failed: %v", err) - } - p.MaxRTT = time.Millisecond * 100 - - recvCount, idleCount := 0, 0 - p.OnRecv = func(*net.IPAddr, time.Duration) { - recvCount++ - } - - p.OnIdle = func() { - idleCount++ - } - - var err error - p.RunLoop() - ticker := time.NewTicker(time.Millisecond * 250) - select { - case <-p.Done(): - if err = p.Err(); err != nil { - t.Fatalf("Pinger returns error %v", err) + if err := p.AddIP("127.0.0.1"); err != nil { + t.Fatalf("AddIP failed: %v", err) } - case <-ticker.C: - break - } - ticker.Stop() - p.Stop() + p.MaxRTT = time.Millisecond * 100 - if recvCount < 2 { - t.Fatalf("Pinger receive count less than 2") - } - if idleCount < 2 { - t.Fatalf("Pinger idle count less than 2") + recvCount, idleCount := 0, 0 + p.OnRecv = func(*net.IPAddr, time.Duration) { + recvCount++ + } + + p.OnIdle = func() { + idleCount++ + } + + var err error + p.RunLoop() + ticker := time.NewTicker(time.Millisecond * 250) + select { + case <-p.Done(): + if err = p.Err(); err != nil { + t.Fatalf("Pinger returns error %v", err) + } + case <-ticker.C: + break + } + ticker.Stop() + p.Stop() + + if recvCount < 2 { + t.Fatalf("Pinger receive count less than 2") + } + if idleCount < 2 { + t.Fatalf("Pinger idle count less than 2") + } } }