1
0
mirror of https://github.com/taigrr/go-fastping synced 2025-01-18 05:03:15 -08:00

Add non-privileged datagram-oriented ICMP endpoints mode using UDP

This commit is contained in:
Tatsushi Demachi 2015-01-11 00:20:08 +09:00
parent 805934dbbf
commit 41b7495ae3
4 changed files with 242 additions and 178 deletions

View File

@ -44,8 +44,8 @@ callback. For more detail, refer [godoc][godoc] and if you need more example,
please see "cmd/ping/ping.go".
## Caution
This package now only implements ICMP ping using raw socket so the program
using this package needs to be run as root user.
This package implements ICMP ping using both raw socket and UDP. If your program
uses this package in raw socket mode, it needs to be run as a root user.
## License
go-fastping is under MIT License. See the [LICENSE][license] file for details.

View File

@ -1,14 +1,16 @@
package main
import (
"flag"
"fmt"
"github.com/tatsushid/go-fastping"
"net"
"os"
"os/signal"
"strings"
"syscall"
"time"
"github.com/tatsushid/go-fastping"
)
type response struct {
@ -17,17 +19,31 @@ type response struct {
}
func main() {
if len(os.Args) != 2 {
fmt.Fprintf(os.Stderr, "Usage: %s {hostname}\n", os.Args[0])
var useUDP bool
flag.BoolVar(&useUDP, "udp", false, "use non-privileged datagram-oriented UDP as ICMP endpoints")
flag.BoolVar(&useUDP, "u", false, "use non-privileged datagram-oriented UDP as ICMP endpoints (shorthand)")
flag.Usage = func() {
fmt.Fprintf(os.Stderr, "Usage:\n %s [options] hostname\n\nOptions:\n", os.Args[0])
flag.PrintDefaults()
}
flag.Parse()
hostname := flag.Arg(0)
if len(hostname) == 0 {
flag.Usage()
os.Exit(1)
}
p := fastping.NewPinger()
if useUDP {
p.Network("udp")
}
netProto := "ip4:icmp"
if strings.Index(os.Args[1], ":") != -1 {
if strings.Index(hostname, ":") != -1 {
netProto = "ip6:ipv6-icmp"
}
ra, err := net.ResolveIPAddr(netProto, os.Args[1])
ra, err := net.ResolveIPAddr(netProto, hostname)
if err != nil {
fmt.Println(err)
os.Exit(1)

View File

@ -29,8 +29,9 @@
// it calls "receive" callback. After that, MaxRTT time passed, it calls
// "idle" callback. If you need more example, please see "cmd/ping/ping.go".
//
// This library needs to run as a superuser for sending ICMP packets so when
// you run go test, please run as a following
// This library needs to run as a superuser for sending ICMP packets when
// privileged raw ICMP endpoints is used so in such a case, to run go test
// for the package, please run like a following
//
// sudo go test
//
@ -54,6 +55,11 @@ import (
const TimeSliceLength = 8
var (
ipv4Proto = map[string]string{"ip": "ip4:icmp", "udp": "udp4"}
ipv6Proto = map[string]string{"ip": "ip6:ipv6-icmp", "udp": "udp6"}
)
func byteSliceOfSize(n int) []byte {
b := make([]byte, n)
for i := 0; i < len(b); i++ {
@ -98,7 +104,7 @@ func ipv4Payload(b []byte) []byte {
type packet struct {
bytes []byte
addr *net.IPAddr
addr net.Addr
}
type context struct {
@ -120,6 +126,7 @@ type Pinger struct {
seq int
// key string is IPAddr.String()
addrs map[string]*net.IPAddr
network string
hasIPv4 bool
hasIPv6 bool
ctx *context
@ -147,6 +154,7 @@ func NewPinger() *Pinger {
id: rand.Intn(0xffff),
seq: rand.Intn(0xffff),
addrs: make(map[string]*net.IPAddr),
network: "ip",
hasIPv4: false,
hasIPv6: false,
Size: TimeSliceLength,
@ -157,6 +165,23 @@ func NewPinger() *Pinger {
}
}
// Network sets a network endpoints for ICMP ping and returns the previous
// setting. network arg should be "ip" or "udp" string or if others are
// specified, it returns an error. If this function isn't called, Pinger
// uses "ip" as default.
func (p *Pinger) Network(network string) (string, error) {
origNet := p.network
switch network {
case "ip":
fallthrough
case "udp":
p.network = network
default:
return origNet, errors.New(network + " can't be used as ICMP endpoint")
}
return origNet, nil
}
// AddIP adds an IP address to Pinger. ipaddr arg should be a string like
// "192.0.2.1".
func (p *Pinger) AddIP(ipaddr string) error {
@ -315,14 +340,14 @@ func (p *Pinger) run(once bool) {
p.debugln("Run(): Start")
var conn, conn6 *icmp.PacketConn
if p.hasIPv4 {
if conn = p.listen("ip4:icmp"); conn == nil {
if conn = p.listen(ipv4Proto[p.network]); conn == nil {
return
}
defer conn.Close()
}
if p.hasIPv6 {
if conn6 = p.listen("ip6:ipv6-icmp"); conn6 == nil {
if conn6 = p.listen(ipv6Proto[p.network]); conn6 == nil {
return
}
defer conn6.Close()
@ -438,10 +463,14 @@ func (p *Pinger) sendICMP(conn, conn6 *icmp.PacketConn) (map[string]*net.IPAddr,
}
queue[key] = addr
var dst net.Addr = addr
if p.network == "udp" {
dst = &net.UDPAddr{IP: addr.IP, Zone: addr.Zone}
}
p.debugln("sendICMP(): Invoke goroutine")
wg.Add(1)
go func(conn *icmp.PacketConn, ra *net.IPAddr, b []byte) {
go func(conn *icmp.PacketConn, ra net.Addr, b []byte) {
for {
if _, err := conn.WriteTo(bytes, ra); err != nil {
if neterr, ok := err.(*net.OpError); ok {
@ -454,7 +483,7 @@ func (p *Pinger) sendICMP(conn, conn6 *icmp.PacketConn) (map[string]*net.IPAddr,
}
p.debugln("sendICMP(): WriteTo End")
wg.Done()
}(cn, addr, bytes)
}(cn, dst, bytes)
}
wg.Wait()
p.debugln("sendICMP(): End")
@ -498,13 +527,8 @@ func (p *Pinger) recvICMP(conn *icmp.PacketConn, recv chan<- *packet, ctx *conte
}
p.debugln("recvICMP(): p.recv <- packet")
addr, ok := ra.(*net.IPAddr)
if !ok {
continue
}
select {
case recv <- &packet{bytes: bytes, addr: addr}:
case recv <- &packet{bytes: bytes, addr: ra}:
case <-ctx.stop:
p.debugln("recvICMP(): <-ctx.stop")
wg.Done()
@ -515,7 +539,17 @@ func (p *Pinger) recvICMP(conn *icmp.PacketConn, recv chan<- *packet, ctx *conte
}
func (p *Pinger) procRecv(recv *packet, queue map[string]*net.IPAddr) {
addr := recv.addr.String()
var ipaddr *net.IPAddr
switch adr := recv.addr.(type) {
case *net.IPAddr:
ipaddr = adr
case *net.UDPAddr:
ipaddr = &net.IPAddr{IP: adr.IP, Zone: adr.Zone}
default:
return
}
addr := ipaddr.String()
p.mu.Lock()
if _, ok := p.addrs[addr]; !ok {
p.mu.Unlock()
@ -525,10 +559,14 @@ func (p *Pinger) procRecv(recv *packet, queue map[string]*net.IPAddr) {
var bytes []byte
var proto int
if isIPv4(recv.addr.IP) {
bytes = ipv4Payload(recv.bytes)
if isIPv4(ipaddr.IP) {
if p.network == "ip" {
bytes = ipv4Payload(recv.bytes)
} else {
bytes = recv.bytes
}
proto = iana.ProtocolICMP
} else if isIPv6(recv.addr.IP) {
} else if isIPv6(ipaddr.IP) {
bytes = recv.bytes
proto = iana.ProtocolIPv6ICMP
} else {
@ -563,7 +601,7 @@ func (p *Pinger) procRecv(recv *packet, queue map[string]*net.IPAddr) {
handler := p.OnRecv
p.mu.Unlock()
if handler != nil {
handler(recv.addr, rtt)
handler(ipaddr, rtt)
}
}
}

View File

@ -38,177 +38,187 @@ func TestAddIP(t *testing.T) {
}
func TestRun(t *testing.T) {
p := NewPinger()
for _, network := range []string{"ip", "udp"} {
p := NewPinger()
p.Network(network)
if err := p.AddIP("127.0.0.1"); err != nil {
t.Fatalf("AddIP failed: %v", err)
}
if err := p.AddIP("127.0.0.100"); err != nil {
t.Fatalf("AddIP failed: %v", err)
}
if err := p.AddIP("::1"); err != nil {
t.Fatalf("AddIP failed: %v", err)
}
found1, found100, foundv6 := false, false, false
called, idle := false, false
p.OnRecv = func(ip *net.IPAddr, d time.Duration) {
called = true
if ip.String() == "127.0.0.1" {
found1 = true
} else if ip.String() == "127.0.0.100" {
found100 = true
} else if ip.String() == "::1" {
foundv6 = true
if err := p.AddIP("127.0.0.1"); err != nil {
t.Fatalf("AddIP failed: %v", err)
}
}
p.OnIdle = func() {
idle = true
}
if err := p.AddIP("127.0.0.100"); err != nil {
t.Fatalf("AddIP failed: %v", err)
}
err := p.Run()
if err != nil {
t.Fatalf("Pinger returns error: %v", err)
}
if !called {
t.Fatalf("Pinger didn't get any responses")
}
if !idle {
t.Fatalf("Pinger didn't call OnIdle function")
}
if !found1 {
t.Fatalf("Pinger `127.0.0.1` didn't respond")
}
if found100 {
t.Fatalf("Pinger `127.0.0.100` responded")
}
if !foundv6 {
t.Fatalf("Pinger `::1` didn't responded")
if err := p.AddIP("::1"); err != nil {
t.Fatalf("AddIP failed: %v", err)
}
found1, found100, foundv6 := false, false, false
called, idle := false, false
p.OnRecv = func(ip *net.IPAddr, d time.Duration) {
called = true
if ip.String() == "127.0.0.1" {
found1 = true
} else if ip.String() == "127.0.0.100" {
found100 = true
} else if ip.String() == "::1" {
foundv6 = true
}
}
p.OnIdle = func() {
idle = true
}
err := p.Run()
if err != nil {
t.Fatalf("Pinger returns error: %v", err)
}
if !called {
t.Fatalf("Pinger didn't get any responses")
}
if !idle {
t.Fatalf("Pinger didn't call OnIdle function")
}
if !found1 {
t.Fatalf("Pinger `127.0.0.1` didn't respond")
}
if found100 {
t.Fatalf("Pinger `127.0.0.100` responded")
}
if !foundv6 {
t.Fatalf("Pinger `::1` didn't responded")
}
}
}
func TestMultiRun(t *testing.T) {
p1 := NewPinger()
p2 := NewPinger()
for _, network := range []string{"ip", "udp"} {
p1 := NewPinger()
p1.Network(network)
p2 := NewPinger()
p2.Network(network)
if err := p1.AddIP("127.0.0.1"); err != nil {
t.Fatalf("AddIP 1 failed: %v", err)
}
if err := p2.AddIP("127.0.0.1"); err != nil {
t.Fatalf("AddIP 2 failed: %v", err)
}
var mu sync.Mutex
res1 := 0
p1.OnRecv = func(*net.IPAddr, time.Duration) {
mu.Lock()
res1++
mu.Unlock()
}
res2 := 0
p2.OnRecv = func(*net.IPAddr, time.Duration) {
mu.Lock()
res2++
mu.Unlock()
}
p1.MaxRTT, p2.MaxRTT = time.Millisecond*100, time.Millisecond*100
if err := p1.Run(); err != nil {
t.Fatalf("Pinger 1 returns error: %v", err)
}
if res1 == 0 {
t.Fatalf("Pinger 1 didn't get any responses")
}
if res2 > 0 {
t.Fatalf("Pinger 2 got response")
}
res1, res2 = 0, 0
if err := p2.Run(); err != nil {
t.Fatalf("Pinger 2 returns error: %v", err)
}
if res1 > 0 {
t.Fatalf("Pinger 1 got response")
}
if res2 == 0 {
t.Fatalf("Pinger 2 didn't get any responses")
}
res1, res2 = 0, 0
errch1, errch2 := make(chan error), make(chan error)
go func(ch chan error) {
err := p1.Run()
if err != nil {
ch <- err
if err := p1.AddIP("127.0.0.1"); err != nil {
t.Fatalf("AddIP 1 failed: %v", err)
}
}(errch1)
go func(ch chan error) {
err := p2.Run()
if err != nil {
ch <- err
if err := p2.AddIP("127.0.0.1"); err != nil {
t.Fatalf("AddIP 2 failed: %v", err)
}
var mu sync.Mutex
res1 := 0
p1.OnRecv = func(*net.IPAddr, time.Duration) {
mu.Lock()
res1++
mu.Unlock()
}
res2 := 0
p2.OnRecv = func(*net.IPAddr, time.Duration) {
mu.Lock()
res2++
mu.Unlock()
}
p1.MaxRTT, p2.MaxRTT = time.Millisecond*100, time.Millisecond*100
if err := p1.Run(); err != nil {
t.Fatalf("Pinger 1 returns error: %v", err)
}
if res1 == 0 {
t.Fatalf("Pinger 1 didn't get any responses")
}
if res2 > 0 {
t.Fatalf("Pinger 2 got response")
}
res1, res2 = 0, 0
if err := p2.Run(); err != nil {
t.Fatalf("Pinger 2 returns error: %v", err)
}
if res1 > 0 {
t.Fatalf("Pinger 1 got response")
}
if res2 == 0 {
t.Fatalf("Pinger 2 didn't get any responses")
}
res1, res2 = 0, 0
errch1, errch2 := make(chan error), make(chan error)
go func(ch chan error) {
err := p1.Run()
if err != nil {
ch <- err
}
}(errch1)
go func(ch chan error) {
err := p2.Run()
if err != nil {
ch <- err
}
}(errch2)
ticker := time.NewTicker(time.Millisecond * 200)
select {
case err := <-errch1:
t.Fatalf("Pinger 1 returns error: %v", err)
case err := <-errch2:
t.Fatalf("Pinger 2 returns error: %v", err)
case <-ticker.C:
break
}
mu.Lock()
defer mu.Unlock()
if res1 != 1 {
t.Fatalf("Pinger 1 didn't get correct response")
}
if res2 != 1 {
t.Fatalf("Pinger 2 didn't get correct response")
}
}(errch2)
ticker := time.NewTicker(time.Millisecond * 200)
select {
case err := <-errch1:
t.Fatalf("Pinger 1 returns error: %v", err)
case err := <-errch2:
t.Fatalf("Pinger 2 returns error: %v", err)
case <-ticker.C:
break
}
mu.Lock()
defer mu.Unlock()
if res1 != 1 {
t.Fatalf("Pinger 1 didn't get correct response")
}
if res2 != 1 {
t.Fatalf("Pinger 2 didn't get correct response")
}
}
func TestRunLoop(t *testing.T) {
p := NewPinger()
for _, network := range []string{"ip", "udp"} {
p := NewPinger()
p.Network(network)
if err := p.AddIP("127.0.0.1"); err != nil {
t.Fatalf("AddIP failed: %v", err)
}
p.MaxRTT = time.Millisecond * 100
recvCount, idleCount := 0, 0
p.OnRecv = func(*net.IPAddr, time.Duration) {
recvCount++
}
p.OnIdle = func() {
idleCount++
}
var err error
p.RunLoop()
ticker := time.NewTicker(time.Millisecond * 250)
select {
case <-p.Done():
if err = p.Err(); err != nil {
t.Fatalf("Pinger returns error %v", err)
if err := p.AddIP("127.0.0.1"); err != nil {
t.Fatalf("AddIP failed: %v", err)
}
case <-ticker.C:
break
}
ticker.Stop()
p.Stop()
p.MaxRTT = time.Millisecond * 100
if recvCount < 2 {
t.Fatalf("Pinger receive count less than 2")
}
if idleCount < 2 {
t.Fatalf("Pinger idle count less than 2")
recvCount, idleCount := 0, 0
p.OnRecv = func(*net.IPAddr, time.Duration) {
recvCount++
}
p.OnIdle = func() {
idleCount++
}
var err error
p.RunLoop()
ticker := time.NewTicker(time.Millisecond * 250)
select {
case <-p.Done():
if err = p.Err(); err != nil {
t.Fatalf("Pinger returns error %v", err)
}
case <-ticker.C:
break
}
ticker.Stop()
p.Stop()
if recvCount < 2 {
t.Fatalf("Pinger receive count less than 2")
}
if idleCount < 2 {
t.Fatalf("Pinger idle count less than 2")
}
}
}