diff --git a/fastping.go b/fastping.go index fcf5ec7..119d55c 100644 --- a/fastping.go +++ b/fastping.go @@ -98,6 +98,7 @@ type Pinger struct { // key string is IPAddr.String() addrs map[string]*net.IPAddr ctx *context + mu sync.Mutex // Number of (nano,milli)seconds of an idle timeout. Once it passed, // the library calls an idle callback function. It is also used for an // interval time of RunLoop() method @@ -127,13 +128,17 @@ func (p *Pinger) AddIP(ipaddr string) error { if addr == nil { return errors.New(fmt.Sprintf("%s is not a valid textual representation of an IP address", ipaddr)) } + p.mu.Lock() p.addrs[addr.String()] = &net.IPAddr{IP: addr} + p.mu.Unlock() return nil } // Add an IP address to Pinger. ip arg should be a net.IPAddr pointer. func (p *Pinger) AddIPAddr(ip *net.IPAddr) { + p.mu.Lock() p.addrs[ip.String()] = ip + p.mu.Unlock() } // Add event handler to Pinger. event arg should be "receive" or "idle" string. @@ -155,14 +160,18 @@ func (p *Pinger) AddHandler(event string, handler interface{}) error { switch event { case "receive": if hdl, ok := handler.(func(*net.IPAddr, time.Duration)); ok { + p.mu.Lock() p.handlers[event] = hdl + p.mu.Unlock() return nil } else { return errors.New(fmt.Sprintf("Receive event handler should be `func(*net.IPAddr, time.Duration)`")) } case "idle": if hdl, ok := handler.(func()); ok { + p.mu.Lock() p.handlers[event] = hdl + p.mu.Unlock() return nil } else { return errors.New(fmt.Sprintf("Idle event handler should be `func()`")) @@ -178,8 +187,12 @@ func (p *Pinger) AddHandler(event string, handler interface{}) error { // an error value. It means it blocks until MaxRTT seconds passed. For the // purpose of sending/receiving packets over and over, use RunLoop(). func (p *Pinger) Run() error { + p.mu.Lock() p.ctx = newContext() + p.mu.Unlock() p.run(true) + p.mu.Lock() + defer p.mu.Unlock() return p.ctx.err } @@ -207,7 +220,9 @@ func (p *Pinger) Run() error { // // For more details, please see "cmd/ping/ping.go". func (p *Pinger) RunLoop() { + p.mu.Lock() p.ctx = newContext() + p.mu.Unlock() go p.run(false) } @@ -228,6 +243,8 @@ func (p *Pinger) Stop() { // Return an error that is set by RunLoop(). It must be called after RunLoop(). // If not, it causes panic. func (p *Pinger) Err() error { + p.mu.Lock() + defer p.mu.Unlock() return p.ctx.err } @@ -235,7 +252,9 @@ func (p *Pinger) run(once bool) { p.debugln("Run(): Start") conn, err := net.ListenIP("ip4:icmp", &net.IPAddr{IP: net.IPv4zero}) if err != nil { + p.mu.Lock() p.ctx.err = err + p.mu.Unlock() p.debugln("Run(): close(p.ctx.done)") close(p.ctx.done) return @@ -261,10 +280,15 @@ mainloop: break mainloop case <-recvCtx.done: p.debugln("Run(): <-recvCtx.done") + p.mu.Lock() err = recvCtx.err + p.mu.Unlock() break mainloop case <-ticker.C: - if handler, ok := p.handlers["idle"]; ok && handler != nil { + p.mu.Lock() + handler, ok := p.handlers["idle"] + p.mu.Unlock() + if ok && handler != nil { if hdl, ok := handler.(func()); ok { hdl() } @@ -286,7 +310,11 @@ mainloop: close(recvCtx.stop) p.debugln("Run(): <-recvCtx.done") <-recvCtx.done + + p.mu.Lock() p.ctx.err = err + p.mu.Unlock() + p.debugln("Run(): close(p.ctx.done)") close(p.ctx.done) p.debugln("Run(): End") @@ -294,11 +322,14 @@ mainloop: func (p *Pinger) sendICMP4(conn *net.IPConn) (map[string]*net.IPAddr, error) { p.debugln("sendICMP4(): Start") + p.mu.Lock() p.id = rand.Intn(0xffff) p.seq = rand.Intn(0xffff) + p.mu.Unlock() queue := make(map[string]*net.IPAddr) var wg sync.WaitGroup for k, v := range p.addrs { + p.mu.Lock() bytes, err := (&icmpMessage{ Type: icmpv4EchoRequest, Code: 0, Body: &icmpEcho{ @@ -306,6 +337,7 @@ func (p *Pinger) sendICMP4(conn *net.IPConn) (map[string]*net.IPAddr, error) { Data: timeToBytes(time.Now()), }, }).Marshal() + p.mu.Unlock() if err != nil { wg.Wait() return queue, err @@ -359,7 +391,9 @@ func (p *Pinger) recvICMP4(conn *net.IPConn, recv chan<- *packet, ctx *context) continue } else { p.debugln("recvICMP4(): OpError happen", err) + p.mu.Lock() ctx.err = err + p.mu.Unlock() close(ctx.done) return } @@ -372,9 +406,12 @@ func (p *Pinger) recvICMP4(conn *net.IPConn, recv chan<- *packet, ctx *context) func (p *Pinger) procRecv(recv *packet, queue map[string]*net.IPAddr) { addr := recv.addr.String() + p.mu.Lock() if _, ok := p.addrs[addr]; !ok { + p.mu.Unlock() return } + p.mu.Unlock() bytes := ipv4Payload(recv.bytes) var m *icmpMessage @@ -390,16 +427,21 @@ func (p *Pinger) procRecv(recv *packet, queue map[string]*net.IPAddr) { var rtt time.Duration switch pkt := m.Body.(type) { case *icmpEcho: + p.mu.Lock() if pkt.ID == p.id && pkt.Seq == p.seq { rtt = time.Since(bytesToTime(pkt.Data)) } + p.mu.Unlock() default: return } if _, ok := queue[addr]; ok { delete(queue, addr) - if handler, ok := p.handlers["receive"]; ok { + p.mu.Lock() + handler, ok := p.handlers["receive"] + p.mu.Unlock() + if ok && handler != nil { if hdl, ok := handler.(func(*net.IPAddr, time.Duration)); ok { hdl(recv.addr, rtt) } @@ -408,12 +450,16 @@ func (p *Pinger) procRecv(recv *packet, queue map[string]*net.IPAddr) { } func (p *Pinger) debugln(args ...interface{}) { + p.mu.Lock() + defer p.mu.Unlock() if p.Debug { log.Println(args...) } } func (p *Pinger) debugf(format string, args ...interface{}) { + p.mu.Lock() + defer p.mu.Unlock() if p.Debug { log.Printf(format, args...) }