package selfupdate import ( "bitbucket.org/kardianos/osext" "bytes" "compress/gzip" "crypto/sha256" "encoding/json" "errors" "fmt" "github.com/inconshreveable/go-update" "github.com/kr/binarydist" "io" "io/ioutil" "log" "math/rand" "net/http" "os" "path/filepath" "runtime" "time" ) const ( upcktimePath = "cktime" plat = runtime.GOOS + "-" + runtime.GOARCH ) const devValidTime = 7 * 24 * time.Hour var ErrHashMismatch = errors.New("new file hash mismatch after patch") var up = update.New() // Update protocol. // // GET hk.heroku.com/hk/linux-amd64.json // // 200 ok // { // "Version": "2", // "Sha256": "..." // base64 // } // // then // // GET hkpatch.s3.amazonaws.com/hk/1/2/linux-amd64 // // 200 ok // [bsdiff data] // // or // // GET hkdist.s3.amazonaws.com/hk/2/linux-amd64.gz // // 200 ok // [gzipped executable data] type Updater struct { CurrentVersion string ApiURL string CmdName string BinURL string DiffURL string Dir string Info struct { Version string Sha256 []byte } } func (u *Updater) getExecRelativeDir(dir string) string { filename, _ := osext.Executable() path := filepath.Join(filepath.Dir(filename), dir) fmt.Println(path) return path } func (u *Updater) BackgroundRun() { os.MkdirAll(u.getExecRelativeDir(u.Dir), 0777) if u.wantUpdate() { if err := up.CanUpdate(); err != nil { // fail return } //self, err := osext.Executable() //if err != nil { // fail update, couldn't figure out path to self //return //} // TODO(bgentry): logger isn't on Windows. Replace w/ proper error reports. if err := u.update(); err != nil { log.Println(err) } } } func (u *Updater) wantUpdate() bool { path := u.getExecRelativeDir(u.Dir + upcktimePath) if u.CurrentVersion == "dev" || readTime(path).After(time.Now()) { return false } wait := 24*time.Hour + randDuration(24*time.Hour) return writeTime(path, time.Now().Add(wait)) } func (u *Updater) update() error { path, err := osext.Executable() if err != nil { return err } old, err := os.Open(path) if err != nil { return err } defer old.Close() err = u.fetchInfo() if err != nil { return err } if u.Info.Version == u.CurrentVersion { return nil } bin, err := u.fetchAndVerifyPatch(old) if err != nil { if err == ErrHashMismatch { log.Println("update: hash mismatch from patched binary") } else { log.Println("update: patching binary,", err) } bin, err = u.fetchAndVerifyFullBin() if err != nil { if err == ErrHashMismatch { log.Println("update: hash mismatch from full binary") } else { log.Println("update: fetching full binary,", err) } return err } } // close the old binary before installing because on windows // it can't be renamed if a handle to the file is still open old.Close() err, errRecover := up.FromStream(bytes.NewBuffer(bin)) if errRecover != nil { return fmt.Errorf("update and recovery errors: %q %q", err, errRecover) } if err != nil { return err } return nil } func (u *Updater) fetchInfo() error { fmt.Println(u.ApiURL) fmt.Println(plat) r, err := fetch(u.ApiURL + u.CmdName + "/" + plat + ".json") if err != nil { return err } defer r.Close() err = json.NewDecoder(r).Decode(&u.Info) if err != nil { return err } if len(u.Info.Sha256) != sha256.Size { return errors.New("bad cmd hash in info") } return nil } func (u *Updater) fetchAndVerifyPatch(old io.Reader) ([]byte, error) { bin, err := u.fetchAndApplyPatch(old) if err != nil { return nil, err } if !verifySha(bin, u.Info.Sha256) { return nil, ErrHashMismatch } return bin, nil } func (u *Updater) fetchAndApplyPatch(old io.Reader) ([]byte, error) { r, err := fetch(u.DiffURL + u.CmdName + "/" + u.CurrentVersion + "/" + u.Info.Version + "/" + plat) if err != nil { return nil, err } defer r.Close() var buf bytes.Buffer err = binarydist.Patch(old, &buf, r) return buf.Bytes(), err } func (u *Updater) fetchAndVerifyFullBin() ([]byte, error) { bin, err := u.fetchBin() if err != nil { return nil, err } verified := verifySha(bin, u.Info.Sha256) if !verified { return nil, ErrHashMismatch } return bin, nil } func (u *Updater) fetchBin() ([]byte, error) { r, err := fetch(u.BinURL + u.CmdName + "/" + u.Info.Version + "/" + plat + ".gz") if err != nil { return nil, err } defer r.Close() buf := new(bytes.Buffer) gz, err := gzip.NewReader(r) if err != nil { return nil, err } if _, err = io.Copy(buf, gz); err != nil { return nil, err } return buf.Bytes(), nil } // returns a random duration in [0,n). func randDuration(n time.Duration) time.Duration { return time.Duration(rand.Int63n(int64(n))) } func fetch(url string) (io.ReadCloser, error) { resp, err := http.Get(url) if err != nil { return nil, err } if resp.StatusCode != 200 { return nil, fmt.Errorf("bad http status from %s: %v", url, resp.Status) } return resp.Body, nil } func readTime(path string) time.Time { p, err := ioutil.ReadFile(path) if os.IsNotExist(err) { return time.Time{} } if err != nil { return time.Now().Add(1000 * time.Hour) } t, err := time.Parse(time.RFC3339, string(p)) if err != nil { return time.Now().Add(1000 * time.Hour) } return t } func verifySha(bin []byte, sha []byte) bool { h := sha256.New() h.Write(bin) return bytes.Equal(h.Sum(nil), sha) } func writeTime(path string, t time.Time) bool { return ioutil.WriteFile(path, []byte(t.Format(time.RFC3339)), 0644) == nil }