1
0
mirror of https://github.com/taigrr/go-fastping synced 2025-01-18 05:03:15 -08:00

Add non-privileged datagram-oriented ICMP endpoints mode using UDP

This commit is contained in:
Tatsushi Demachi 2015-01-11 00:20:08 +09:00
parent 805934dbbf
commit 41b7495ae3
4 changed files with 242 additions and 178 deletions

View File

@ -44,8 +44,8 @@ callback. For more detail, refer [godoc][godoc] and if you need more example,
please see "cmd/ping/ping.go". please see "cmd/ping/ping.go".
## Caution ## Caution
This package now only implements ICMP ping using raw socket so the program This package implements ICMP ping using both raw socket and UDP. If your program
using this package needs to be run as root user. uses this package in raw socket mode, it needs to be run as a root user.
## License ## License
go-fastping is under MIT License. See the [LICENSE][license] file for details. go-fastping is under MIT License. See the [LICENSE][license] file for details.

View File

@ -1,14 +1,16 @@
package main package main
import ( import (
"flag"
"fmt" "fmt"
"github.com/tatsushid/go-fastping"
"net" "net"
"os" "os"
"os/signal" "os/signal"
"strings" "strings"
"syscall" "syscall"
"time" "time"
"github.com/tatsushid/go-fastping"
) )
type response struct { type response struct {
@ -17,17 +19,31 @@ type response struct {
} }
func main() { func main() {
if len(os.Args) != 2 { var useUDP bool
fmt.Fprintf(os.Stderr, "Usage: %s {hostname}\n", os.Args[0]) flag.BoolVar(&useUDP, "udp", false, "use non-privileged datagram-oriented UDP as ICMP endpoints")
flag.BoolVar(&useUDP, "u", false, "use non-privileged datagram-oriented UDP as ICMP endpoints (shorthand)")
flag.Usage = func() {
fmt.Fprintf(os.Stderr, "Usage:\n %s [options] hostname\n\nOptions:\n", os.Args[0])
flag.PrintDefaults()
}
flag.Parse()
hostname := flag.Arg(0)
if len(hostname) == 0 {
flag.Usage()
os.Exit(1) os.Exit(1)
} }
p := fastping.NewPinger() p := fastping.NewPinger()
if useUDP {
p.Network("udp")
}
netProto := "ip4:icmp" netProto := "ip4:icmp"
if strings.Index(os.Args[1], ":") != -1 { if strings.Index(hostname, ":") != -1 {
netProto = "ip6:ipv6-icmp" netProto = "ip6:ipv6-icmp"
} }
ra, err := net.ResolveIPAddr(netProto, os.Args[1]) ra, err := net.ResolveIPAddr(netProto, hostname)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
os.Exit(1) os.Exit(1)

View File

@ -29,8 +29,9 @@
// it calls "receive" callback. After that, MaxRTT time passed, it calls // it calls "receive" callback. After that, MaxRTT time passed, it calls
// "idle" callback. If you need more example, please see "cmd/ping/ping.go". // "idle" callback. If you need more example, please see "cmd/ping/ping.go".
// //
// This library needs to run as a superuser for sending ICMP packets so when // This library needs to run as a superuser for sending ICMP packets when
// you run go test, please run as a following // privileged raw ICMP endpoints is used so in such a case, to run go test
// for the package, please run like a following
// //
// sudo go test // sudo go test
// //
@ -54,6 +55,11 @@ import (
const TimeSliceLength = 8 const TimeSliceLength = 8
var (
ipv4Proto = map[string]string{"ip": "ip4:icmp", "udp": "udp4"}
ipv6Proto = map[string]string{"ip": "ip6:ipv6-icmp", "udp": "udp6"}
)
func byteSliceOfSize(n int) []byte { func byteSliceOfSize(n int) []byte {
b := make([]byte, n) b := make([]byte, n)
for i := 0; i < len(b); i++ { for i := 0; i < len(b); i++ {
@ -98,7 +104,7 @@ func ipv4Payload(b []byte) []byte {
type packet struct { type packet struct {
bytes []byte bytes []byte
addr *net.IPAddr addr net.Addr
} }
type context struct { type context struct {
@ -120,6 +126,7 @@ type Pinger struct {
seq int seq int
// key string is IPAddr.String() // key string is IPAddr.String()
addrs map[string]*net.IPAddr addrs map[string]*net.IPAddr
network string
hasIPv4 bool hasIPv4 bool
hasIPv6 bool hasIPv6 bool
ctx *context ctx *context
@ -147,6 +154,7 @@ func NewPinger() *Pinger {
id: rand.Intn(0xffff), id: rand.Intn(0xffff),
seq: rand.Intn(0xffff), seq: rand.Intn(0xffff),
addrs: make(map[string]*net.IPAddr), addrs: make(map[string]*net.IPAddr),
network: "ip",
hasIPv4: false, hasIPv4: false,
hasIPv6: false, hasIPv6: false,
Size: TimeSliceLength, Size: TimeSliceLength,
@ -157,6 +165,23 @@ func NewPinger() *Pinger {
} }
} }
// Network sets a network endpoints for ICMP ping and returns the previous
// setting. network arg should be "ip" or "udp" string or if others are
// specified, it returns an error. If this function isn't called, Pinger
// uses "ip" as default.
func (p *Pinger) Network(network string) (string, error) {
origNet := p.network
switch network {
case "ip":
fallthrough
case "udp":
p.network = network
default:
return origNet, errors.New(network + " can't be used as ICMP endpoint")
}
return origNet, nil
}
// AddIP adds an IP address to Pinger. ipaddr arg should be a string like // AddIP adds an IP address to Pinger. ipaddr arg should be a string like
// "192.0.2.1". // "192.0.2.1".
func (p *Pinger) AddIP(ipaddr string) error { func (p *Pinger) AddIP(ipaddr string) error {
@ -315,14 +340,14 @@ func (p *Pinger) run(once bool) {
p.debugln("Run(): Start") p.debugln("Run(): Start")
var conn, conn6 *icmp.PacketConn var conn, conn6 *icmp.PacketConn
if p.hasIPv4 { if p.hasIPv4 {
if conn = p.listen("ip4:icmp"); conn == nil { if conn = p.listen(ipv4Proto[p.network]); conn == nil {
return return
} }
defer conn.Close() defer conn.Close()
} }
if p.hasIPv6 { if p.hasIPv6 {
if conn6 = p.listen("ip6:ipv6-icmp"); conn6 == nil { if conn6 = p.listen(ipv6Proto[p.network]); conn6 == nil {
return return
} }
defer conn6.Close() defer conn6.Close()
@ -438,10 +463,14 @@ func (p *Pinger) sendICMP(conn, conn6 *icmp.PacketConn) (map[string]*net.IPAddr,
} }
queue[key] = addr queue[key] = addr
var dst net.Addr = addr
if p.network == "udp" {
dst = &net.UDPAddr{IP: addr.IP, Zone: addr.Zone}
}
p.debugln("sendICMP(): Invoke goroutine") p.debugln("sendICMP(): Invoke goroutine")
wg.Add(1) wg.Add(1)
go func(conn *icmp.PacketConn, ra *net.IPAddr, b []byte) { go func(conn *icmp.PacketConn, ra net.Addr, b []byte) {
for { for {
if _, err := conn.WriteTo(bytes, ra); err != nil { if _, err := conn.WriteTo(bytes, ra); err != nil {
if neterr, ok := err.(*net.OpError); ok { if neterr, ok := err.(*net.OpError); ok {
@ -454,7 +483,7 @@ func (p *Pinger) sendICMP(conn, conn6 *icmp.PacketConn) (map[string]*net.IPAddr,
} }
p.debugln("sendICMP(): WriteTo End") p.debugln("sendICMP(): WriteTo End")
wg.Done() wg.Done()
}(cn, addr, bytes) }(cn, dst, bytes)
} }
wg.Wait() wg.Wait()
p.debugln("sendICMP(): End") p.debugln("sendICMP(): End")
@ -498,13 +527,8 @@ func (p *Pinger) recvICMP(conn *icmp.PacketConn, recv chan<- *packet, ctx *conte
} }
p.debugln("recvICMP(): p.recv <- packet") p.debugln("recvICMP(): p.recv <- packet")
addr, ok := ra.(*net.IPAddr)
if !ok {
continue
}
select { select {
case recv <- &packet{bytes: bytes, addr: addr}: case recv <- &packet{bytes: bytes, addr: ra}:
case <-ctx.stop: case <-ctx.stop:
p.debugln("recvICMP(): <-ctx.stop") p.debugln("recvICMP(): <-ctx.stop")
wg.Done() wg.Done()
@ -515,7 +539,17 @@ func (p *Pinger) recvICMP(conn *icmp.PacketConn, recv chan<- *packet, ctx *conte
} }
func (p *Pinger) procRecv(recv *packet, queue map[string]*net.IPAddr) { func (p *Pinger) procRecv(recv *packet, queue map[string]*net.IPAddr) {
addr := recv.addr.String() var ipaddr *net.IPAddr
switch adr := recv.addr.(type) {
case *net.IPAddr:
ipaddr = adr
case *net.UDPAddr:
ipaddr = &net.IPAddr{IP: adr.IP, Zone: adr.Zone}
default:
return
}
addr := ipaddr.String()
p.mu.Lock() p.mu.Lock()
if _, ok := p.addrs[addr]; !ok { if _, ok := p.addrs[addr]; !ok {
p.mu.Unlock() p.mu.Unlock()
@ -525,10 +559,14 @@ func (p *Pinger) procRecv(recv *packet, queue map[string]*net.IPAddr) {
var bytes []byte var bytes []byte
var proto int var proto int
if isIPv4(recv.addr.IP) { if isIPv4(ipaddr.IP) {
bytes = ipv4Payload(recv.bytes) if p.network == "ip" {
bytes = ipv4Payload(recv.bytes)
} else {
bytes = recv.bytes
}
proto = iana.ProtocolICMP proto = iana.ProtocolICMP
} else if isIPv6(recv.addr.IP) { } else if isIPv6(ipaddr.IP) {
bytes = recv.bytes bytes = recv.bytes
proto = iana.ProtocolIPv6ICMP proto = iana.ProtocolIPv6ICMP
} else { } else {
@ -563,7 +601,7 @@ func (p *Pinger) procRecv(recv *packet, queue map[string]*net.IPAddr) {
handler := p.OnRecv handler := p.OnRecv
p.mu.Unlock() p.mu.Unlock()
if handler != nil { if handler != nil {
handler(recv.addr, rtt) handler(ipaddr, rtt)
} }
} }
} }

View File

@ -38,177 +38,187 @@ func TestAddIP(t *testing.T) {
} }
func TestRun(t *testing.T) { func TestRun(t *testing.T) {
p := NewPinger() for _, network := range []string{"ip", "udp"} {
p := NewPinger()
p.Network(network)
if err := p.AddIP("127.0.0.1"); err != nil { if err := p.AddIP("127.0.0.1"); err != nil {
t.Fatalf("AddIP failed: %v", err) t.Fatalf("AddIP failed: %v", err)
}
if err := p.AddIP("127.0.0.100"); err != nil {
t.Fatalf("AddIP failed: %v", err)
}
if err := p.AddIP("::1"); err != nil {
t.Fatalf("AddIP failed: %v", err)
}
found1, found100, foundv6 := false, false, false
called, idle := false, false
p.OnRecv = func(ip *net.IPAddr, d time.Duration) {
called = true
if ip.String() == "127.0.0.1" {
found1 = true
} else if ip.String() == "127.0.0.100" {
found100 = true
} else if ip.String() == "::1" {
foundv6 = true
} }
}
p.OnIdle = func() { if err := p.AddIP("127.0.0.100"); err != nil {
idle = true t.Fatalf("AddIP failed: %v", err)
} }
err := p.Run() if err := p.AddIP("::1"); err != nil {
if err != nil { t.Fatalf("AddIP failed: %v", err)
t.Fatalf("Pinger returns error: %v", err) }
}
if !called { found1, found100, foundv6 := false, false, false
t.Fatalf("Pinger didn't get any responses") called, idle := false, false
} p.OnRecv = func(ip *net.IPAddr, d time.Duration) {
if !idle { called = true
t.Fatalf("Pinger didn't call OnIdle function") if ip.String() == "127.0.0.1" {
} found1 = true
if !found1 { } else if ip.String() == "127.0.0.100" {
t.Fatalf("Pinger `127.0.0.1` didn't respond") found100 = true
} } else if ip.String() == "::1" {
if found100 { foundv6 = true
t.Fatalf("Pinger `127.0.0.100` responded") }
} }
if !foundv6 {
t.Fatalf("Pinger `::1` didn't responded") p.OnIdle = func() {
idle = true
}
err := p.Run()
if err != nil {
t.Fatalf("Pinger returns error: %v", err)
}
if !called {
t.Fatalf("Pinger didn't get any responses")
}
if !idle {
t.Fatalf("Pinger didn't call OnIdle function")
}
if !found1 {
t.Fatalf("Pinger `127.0.0.1` didn't respond")
}
if found100 {
t.Fatalf("Pinger `127.0.0.100` responded")
}
if !foundv6 {
t.Fatalf("Pinger `::1` didn't responded")
}
} }
} }
func TestMultiRun(t *testing.T) { func TestMultiRun(t *testing.T) {
p1 := NewPinger() for _, network := range []string{"ip", "udp"} {
p2 := NewPinger() p1 := NewPinger()
p1.Network(network)
p2 := NewPinger()
p2.Network(network)
if err := p1.AddIP("127.0.0.1"); err != nil { if err := p1.AddIP("127.0.0.1"); err != nil {
t.Fatalf("AddIP 1 failed: %v", err) t.Fatalf("AddIP 1 failed: %v", err)
}
if err := p2.AddIP("127.0.0.1"); err != nil {
t.Fatalf("AddIP 2 failed: %v", err)
}
var mu sync.Mutex
res1 := 0
p1.OnRecv = func(*net.IPAddr, time.Duration) {
mu.Lock()
res1++
mu.Unlock()
}
res2 := 0
p2.OnRecv = func(*net.IPAddr, time.Duration) {
mu.Lock()
res2++
mu.Unlock()
}
p1.MaxRTT, p2.MaxRTT = time.Millisecond*100, time.Millisecond*100
if err := p1.Run(); err != nil {
t.Fatalf("Pinger 1 returns error: %v", err)
}
if res1 == 0 {
t.Fatalf("Pinger 1 didn't get any responses")
}
if res2 > 0 {
t.Fatalf("Pinger 2 got response")
}
res1, res2 = 0, 0
if err := p2.Run(); err != nil {
t.Fatalf("Pinger 2 returns error: %v", err)
}
if res1 > 0 {
t.Fatalf("Pinger 1 got response")
}
if res2 == 0 {
t.Fatalf("Pinger 2 didn't get any responses")
}
res1, res2 = 0, 0
errch1, errch2 := make(chan error), make(chan error)
go func(ch chan error) {
err := p1.Run()
if err != nil {
ch <- err
} }
}(errch1)
go func(ch chan error) { if err := p2.AddIP("127.0.0.1"); err != nil {
err := p2.Run() t.Fatalf("AddIP 2 failed: %v", err)
if err != nil { }
ch <- err
var mu sync.Mutex
res1 := 0
p1.OnRecv = func(*net.IPAddr, time.Duration) {
mu.Lock()
res1++
mu.Unlock()
}
res2 := 0
p2.OnRecv = func(*net.IPAddr, time.Duration) {
mu.Lock()
res2++
mu.Unlock()
}
p1.MaxRTT, p2.MaxRTT = time.Millisecond*100, time.Millisecond*100
if err := p1.Run(); err != nil {
t.Fatalf("Pinger 1 returns error: %v", err)
}
if res1 == 0 {
t.Fatalf("Pinger 1 didn't get any responses")
}
if res2 > 0 {
t.Fatalf("Pinger 2 got response")
}
res1, res2 = 0, 0
if err := p2.Run(); err != nil {
t.Fatalf("Pinger 2 returns error: %v", err)
}
if res1 > 0 {
t.Fatalf("Pinger 1 got response")
}
if res2 == 0 {
t.Fatalf("Pinger 2 didn't get any responses")
}
res1, res2 = 0, 0
errch1, errch2 := make(chan error), make(chan error)
go func(ch chan error) {
err := p1.Run()
if err != nil {
ch <- err
}
}(errch1)
go func(ch chan error) {
err := p2.Run()
if err != nil {
ch <- err
}
}(errch2)
ticker := time.NewTicker(time.Millisecond * 200)
select {
case err := <-errch1:
t.Fatalf("Pinger 1 returns error: %v", err)
case err := <-errch2:
t.Fatalf("Pinger 2 returns error: %v", err)
case <-ticker.C:
break
}
mu.Lock()
defer mu.Unlock()
if res1 != 1 {
t.Fatalf("Pinger 1 didn't get correct response")
}
if res2 != 1 {
t.Fatalf("Pinger 2 didn't get correct response")
} }
}(errch2)
ticker := time.NewTicker(time.Millisecond * 200)
select {
case err := <-errch1:
t.Fatalf("Pinger 1 returns error: %v", err)
case err := <-errch2:
t.Fatalf("Pinger 2 returns error: %v", err)
case <-ticker.C:
break
}
mu.Lock()
defer mu.Unlock()
if res1 != 1 {
t.Fatalf("Pinger 1 didn't get correct response")
}
if res2 != 1 {
t.Fatalf("Pinger 2 didn't get correct response")
} }
} }
func TestRunLoop(t *testing.T) { func TestRunLoop(t *testing.T) {
p := NewPinger() for _, network := range []string{"ip", "udp"} {
p := NewPinger()
p.Network(network)
if err := p.AddIP("127.0.0.1"); err != nil { if err := p.AddIP("127.0.0.1"); err != nil {
t.Fatalf("AddIP failed: %v", err) t.Fatalf("AddIP failed: %v", err)
}
p.MaxRTT = time.Millisecond * 100
recvCount, idleCount := 0, 0
p.OnRecv = func(*net.IPAddr, time.Duration) {
recvCount++
}
p.OnIdle = func() {
idleCount++
}
var err error
p.RunLoop()
ticker := time.NewTicker(time.Millisecond * 250)
select {
case <-p.Done():
if err = p.Err(); err != nil {
t.Fatalf("Pinger returns error %v", err)
} }
case <-ticker.C: p.MaxRTT = time.Millisecond * 100
break
}
ticker.Stop()
p.Stop()
if recvCount < 2 { recvCount, idleCount := 0, 0
t.Fatalf("Pinger receive count less than 2") p.OnRecv = func(*net.IPAddr, time.Duration) {
} recvCount++
if idleCount < 2 { }
t.Fatalf("Pinger idle count less than 2")
p.OnIdle = func() {
idleCount++
}
var err error
p.RunLoop()
ticker := time.NewTicker(time.Millisecond * 250)
select {
case <-p.Done():
if err = p.Err(); err != nil {
t.Fatalf("Pinger returns error %v", err)
}
case <-ticker.C:
break
}
ticker.Stop()
p.Stop()
if recvCount < 2 {
t.Fatalf("Pinger receive count less than 2")
}
if idleCount < 2 {
t.Fatalf("Pinger idle count less than 2")
}
} }
} }