diff --git a/selfupdate/selfupdate.go b/selfupdate/selfupdate.go index 33d2b6a..f4ca576 100644 --- a/selfupdate/selfupdate.go +++ b/selfupdate/selfupdate.go @@ -1,35 +1,35 @@ package selfupdate import ( - "bitbucket.org/kardianos/osext" - "bytes" - "compress/gzip" - "crypto/sha256" - "encoding/json" - "errors" - "fmt" - "github.com/inconshreveable/go-update" - "github.com/kr/binarydist" - "io" - "io/ioutil" - "log" - "math/rand" - "net/http" - "os" - "runtime" - "time" - "path/filepath" + "bitbucket.org/kardianos/osext" + "bytes" + "compress/gzip" + "crypto/sha256" + "encoding/json" + "errors" + "fmt" + "github.com/inconshreveable/go-update" + "github.com/kr/binarydist" + "io" + "io/ioutil" + "log" + "math/rand" + "net/http" + "os" + "path/filepath" + "runtime" + "time" ) - + const ( - upcktimePath = "cktime" - plat = runtime.GOOS + "-" + runtime.GOARCH + upcktimePath = "cktime" + plat = runtime.GOOS + "-" + runtime.GOARCH ) - + const devValidTime = 7 * 24 * time.Hour - + var ErrHashMismatch = errors.New("new file hash mismatch after patch") - + // Update protocol. // // GET hk.heroku.com/hk/linux-amd64.json @@ -54,210 +54,210 @@ var ErrHashMismatch = errors.New("new file hash mismatch after patch") // 200 ok // [gzipped executable data] type Updater struct { - CurrentVersion string - ApiURL string - CmdName string - BinURL string - DiffURL string - Dir string - Info struct { - Version string - Sha256 []byte - } + CurrentVersion string + ApiURL string + CmdName string + BinURL string + DiffURL string + Dir string + Info struct { + Version string + Sha256 []byte + } } func (u *Updater) getExecRelativeDir(dir string) string { - filename, _ := osext.Executable() - path := filepath.Join(filepath.Dir(filename), dir) - fmt.Println(path) - return path + filename, _ := osext.Executable() + path := filepath.Join(filepath.Dir(filename), dir) + fmt.Println(path) + return path } - + func (u *Updater) BackgroundRun() { - os.MkdirAll(u.getExecRelativeDir(u.Dir), 0777) - if u.wantUpdate() { - if err := update.SanityCheck(); err != nil { - // fail - return - } - //self, err := osext.Executable() - //if err != nil { - // fail update, couldn't figure out path to self - //return - //} - // TODO(bgentry): logger isn't on Windows. Replace w/ proper error reports. - if err := u.update(); err != nil { - log.Println(err) - } - } + os.MkdirAll(u.getExecRelativeDir(u.Dir), 0777) + if u.wantUpdate() { + if err := update.SanityCheck(); err != nil { + // fail + return + } + //self, err := osext.Executable() + //if err != nil { + // fail update, couldn't figure out path to self + //return + //} + // TODO(bgentry): logger isn't on Windows. Replace w/ proper error reports. + if err := u.update(); err != nil { + log.Println(err) + } + } } - + func (u *Updater) wantUpdate() bool { - path := u.getExecRelativeDir(u.Dir + upcktimePath) - if u.CurrentVersion == "dev" || readTime(path).After(time.Now()) { - return false - } - wait := 24*time.Hour + randDuration(24*time.Hour) - return writeTime(path, time.Now().Add(wait)) + path := u.getExecRelativeDir(u.Dir + upcktimePath) + if u.CurrentVersion == "dev" || readTime(path).After(time.Now()) { + return false + } + wait := 24*time.Hour + randDuration(24*time.Hour) + return writeTime(path, time.Now().Add(wait)) } - + func (u *Updater) update() error { - path, err := osext.Executable() - if err != nil { - return err - } - old, err := os.Open(path) - if err != nil { - return err - } - defer old.Close() - - err = u.fetchInfo() - if err != nil { - return err - } - if u.Info.Version == u.CurrentVersion { - return nil - } - bin, err := u.fetchAndVerifyPatch(old) - if err != nil { - if err == ErrHashMismatch { - log.Println("update: hash mismatch from patched binary") - } else { - log.Println("update: patching binary,", err) - } - bin, err = u.fetchAndVerifyFullBin() - if err != nil { - if err == ErrHashMismatch { - log.Println("update: hash mismatch from full binary") - } else { - log.Println("update: fetching full binary,", err) - } - return err - } - } - - // close the old binary before installing because on windows - // it can't be renamed if a handle to the file is still open - old.Close() - - err, errRecover := update.FromStream(bytes.NewBuffer(bin)) - if errRecover != nil { - return fmt.Errorf("update and recovery errors: %q %q", err, errRecover) - } - if err != nil { - return err - } - return nil + path, err := osext.Executable() + if err != nil { + return err + } + old, err := os.Open(path) + if err != nil { + return err + } + defer old.Close() + + err = u.fetchInfo() + if err != nil { + return err + } + if u.Info.Version == u.CurrentVersion { + return nil + } + bin, err := u.fetchAndVerifyPatch(old) + if err != nil { + if err == ErrHashMismatch { + log.Println("update: hash mismatch from patched binary") + } else { + log.Println("update: patching binary,", err) + } + bin, err = u.fetchAndVerifyFullBin() + if err != nil { + if err == ErrHashMismatch { + log.Println("update: hash mismatch from full binary") + } else { + log.Println("update: fetching full binary,", err) + } + return err + } + } + + // close the old binary before installing because on windows + // it can't be renamed if a handle to the file is still open + old.Close() + + err, errRecover := update.FromStream(bytes.NewBuffer(bin)) + if errRecover != nil { + return fmt.Errorf("update and recovery errors: %q %q", err, errRecover) + } + if err != nil { + return err + } + return nil } - + func (u *Updater) fetchInfo() error { - fmt.Println(u.ApiURL) - fmt.Println(plat) - r, err := fetch(u.ApiURL + u.CmdName + "/" + plat + ".json") - if err != nil { - return err - } - defer r.Close() - err = json.NewDecoder(r).Decode(&u.Info) - if err != nil { - return err - } - if len(u.Info.Sha256) != sha256.Size { - return errors.New("bad cmd hash in info") - } - return nil + fmt.Println(u.ApiURL) + fmt.Println(plat) + r, err := fetch(u.ApiURL + u.CmdName + "/" + plat + ".json") + if err != nil { + return err + } + defer r.Close() + err = json.NewDecoder(r).Decode(&u.Info) + if err != nil { + return err + } + if len(u.Info.Sha256) != sha256.Size { + return errors.New("bad cmd hash in info") + } + return nil } - + func (u *Updater) fetchAndVerifyPatch(old io.Reader) ([]byte, error) { - bin, err := u.fetchAndApplyPatch(old) - if err != nil { - return nil, err - } - if !verifySha(bin, u.Info.Sha256) { - return nil, ErrHashMismatch - } - return bin, nil + bin, err := u.fetchAndApplyPatch(old) + if err != nil { + return nil, err + } + if !verifySha(bin, u.Info.Sha256) { + return nil, ErrHashMismatch + } + return bin, nil } - + func (u *Updater) fetchAndApplyPatch(old io.Reader) ([]byte, error) { - r, err := fetch(u.DiffURL + u.CmdName + "/" + u.CurrentVersion + "/" + u.Info.Version + "/" + plat) - if err != nil { - return nil, err - } - defer r.Close() - var buf bytes.Buffer - err = binarydist.Patch(old, &buf, r) - return buf.Bytes(), err + r, err := fetch(u.DiffURL + u.CmdName + "/" + u.CurrentVersion + "/" + u.Info.Version + "/" + plat) + if err != nil { + return nil, err + } + defer r.Close() + var buf bytes.Buffer + err = binarydist.Patch(old, &buf, r) + return buf.Bytes(), err } - + func (u *Updater) fetchAndVerifyFullBin() ([]byte, error) { - bin, err := u.fetchBin() - if err != nil { - return nil, err - } - verified := verifySha(bin, u.Info.Sha256) - if !verified { - return nil, ErrHashMismatch - } - return bin, nil + bin, err := u.fetchBin() + if err != nil { + return nil, err + } + verified := verifySha(bin, u.Info.Sha256) + if !verified { + return nil, ErrHashMismatch + } + return bin, nil } - + func (u *Updater) fetchBin() ([]byte, error) { - r, err := fetch(u.BinURL + u.CmdName + "/" + u.Info.Version + "/" + plat + ".gz") - if err != nil { - return nil, err - } - defer r.Close() - buf := new(bytes.Buffer) - gz, err := gzip.NewReader(r) - if err != nil { - return nil, err - } - if _, err = io.Copy(buf, gz); err != nil { - return nil, err - } - - return buf.Bytes(), nil + r, err := fetch(u.BinURL + u.CmdName + "/" + u.Info.Version + "/" + plat + ".gz") + if err != nil { + return nil, err + } + defer r.Close() + buf := new(bytes.Buffer) + gz, err := gzip.NewReader(r) + if err != nil { + return nil, err + } + if _, err = io.Copy(buf, gz); err != nil { + return nil, err + } + + return buf.Bytes(), nil } - + // returns a random duration in [0,n). func randDuration(n time.Duration) time.Duration { - return time.Duration(rand.Int63n(int64(n))) + return time.Duration(rand.Int63n(int64(n))) } - + func fetch(url string) (io.ReadCloser, error) { - resp, err := http.Get(url) - if err != nil { - return nil, err - } - if resp.StatusCode != 200 { - return nil, fmt.Errorf("bad http status from %s: %v", url, resp.Status) - } - return resp.Body, nil + resp, err := http.Get(url) + if err != nil { + return nil, err + } + if resp.StatusCode != 200 { + return nil, fmt.Errorf("bad http status from %s: %v", url, resp.Status) + } + return resp.Body, nil } - + func readTime(path string) time.Time { - p, err := ioutil.ReadFile(path) - if os.IsNotExist(err) { - return time.Time{} - } - if err != nil { - return time.Now().Add(1000 * time.Hour) - } - t, err := time.Parse(time.RFC3339, string(p)) - if err != nil { - return time.Now().Add(1000 * time.Hour) - } - return t + p, err := ioutil.ReadFile(path) + if os.IsNotExist(err) { + return time.Time{} + } + if err != nil { + return time.Now().Add(1000 * time.Hour) + } + t, err := time.Parse(time.RFC3339, string(p)) + if err != nil { + return time.Now().Add(1000 * time.Hour) + } + return t } - + func verifySha(bin []byte, sha []byte) bool { - h := sha256.New() - h.Write(bin) - return bytes.Equal(h.Sum(nil), sha) + h := sha256.New() + h.Write(bin) + return bytes.Equal(h.Sum(nil), sha) } - + func writeTime(path string, t time.Time) bool { - return ioutil.WriteFile(path, []byte(t.Format(time.RFC3339)), 0644) == nil + return ioutil.WriteFile(path, []byte(t.Format(time.RFC3339)), 0644) == nil } diff --git a/selfupdate/selfupdate_test.go b/selfupdate/selfupdate_test.go index cdc24c3..a679db3 100644 --- a/selfupdate/selfupdate_test.go +++ b/selfupdate/selfupdate_test.go @@ -1,9 +1,8 @@ package selfupdate -import( - "testing" +import ( + "testing" ) - func TestUpdater() { }