From f906cb7b7d0ade0b974c7ddc1c3ca06824f7d08b Mon Sep 17 00:00:00 2001 From: vagrant Date: Thu, 14 Nov 2013 17:36:10 -0800 Subject: [PATCH] initial --- main.go | 1 + selfupdate/selfupdate.go | 263 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 264 insertions(+) create mode 100644 selfupdate/selfupdate.go diff --git a/main.go b/main.go index be11f09..fe1229a 100644 --- a/main.go +++ b/main.go @@ -101,6 +101,7 @@ func main() { if file.Name() == version { continue } + os.Mkdir(filepath.Join(genDir, file.Name(), version), 0755) fName := filepath.Join(genDir, file.Name(), plat + ".gz") diff --git a/selfupdate/selfupdate.go b/selfupdate/selfupdate.go new file mode 100644 index 0000000..b7020f7 --- /dev/null +++ b/selfupdate/selfupdate.go @@ -0,0 +1,263 @@ +package selfupdate + +import ( + "bitbucket.org/kardianos/osext" + "bytes" + "compress/gzip" + "crypto/sha256" + "encoding/json" + "errors" + "fmt" + "github.com/inconshreveable/go-update" + "github.com/kr/binarydist" + "io" + "io/ioutil" + "log" + "math/rand" + "net/http" + "os" + "runtime" + "time" + "path/filepath" +) + +const ( + upcktimePath = "cktime" + plat = runtime.GOOS + "-" + runtime.GOARCH +) + +const devValidTime = 7 * 24 * time.Hour + +var ErrHashMismatch = errors.New("new file hash mismatch after patch") + +// Update protocol. +// +// GET hk.heroku.com/hk/linux-amd64.json +// +// 200 ok +// { +// "Version": "2", +// "Sha256": "..." // base64 +// } +// +// then +// +// GET hkpatch.s3.amazonaws.com/hk/1/2/linux-amd64 +// +// 200 ok +// [bsdiff data] +// +// or +// +// GET hkdist.s3.amazonaws.com/hk/2/linux-amd64.gz +// +// 200 ok +// [gzipped executable data] +type Updater struct { + CurrentVersion string + ApiURL string + CmdName string + BinURL string + DiffURL string + Dir string + Info struct { + Version string + Sha256 []byte + } +} + +func (u *Updater) getExecRelativeDir(dir string) string { + filename, _ := osext.Executable() + path := filepath.Join(filepath.Dir(filename), dir) + fmt.Println(path) + return path +} + +func (u *Updater) BackgroundRun() { + os.MkdirAll(u.getExecRelativeDir(u.Dir), 0777) + if u.wantUpdate() { + if err := update.SanityCheck(); err != nil { + // fail + return + } + //self, err := osext.Executable() + //if err != nil { + // fail update, couldn't figure out path to self + //return + //} + // TODO(bgentry): logger isn't on Windows. Replace w/ proper error reports. + if err := u.update(); err != nil { + log.Fatal(err) + } + } +} + +func (u *Updater) wantUpdate() bool { + path := u.getExecRelativeDir(u.Dir + upcktimePath) + if u.CurrentVersion == "dev" || readTime(path).After(time.Now()) { + return false + } + wait := 24*time.Hour + randDuration(24*time.Hour) + return writeTime(path, time.Now().Add(wait)) +} + +func (u *Updater) update() error { + path, err := osext.Executable() + if err != nil { + return err + } + old, err := os.Open(path) + if err != nil { + return err + } + defer old.Close() + + err = u.fetchInfo() + if err != nil { + return err + } + if u.Info.Version == u.CurrentVersion { + return nil + } + bin, err := u.fetchAndVerifyPatch(old) + if err != nil { + if err == ErrHashMismatch { + log.Println("update: hash mismatch from patched binary") + } else { + log.Println("update: patching binary,", err) + } + bin, err = u.fetchAndVerifyFullBin() + if err != nil { + if err == ErrHashMismatch { + log.Println("update: hash mismatch from full binary") + } else { + log.Println("update: fetching full binary,", err) + } + return err + } + } + + // close the old binary before installing because on windows + // it can't be renamed if a handle to the file is still open + old.Close() + + err, errRecover := update.FromStream(bytes.NewBuffer(bin)) + if errRecover != nil { + return fmt.Errorf("update and recovery errors: %q %q", err, errRecover) + } + if err != nil { + return err + } + return nil +} + +func (u *Updater) fetchInfo() error { + fmt.Println(u.ApiURL) + fmt.Println(plat) + r, err := fetch(u.ApiURL + u.CmdName + "/" + plat + ".json") + if err != nil { + return err + } + defer r.Close() + err = json.NewDecoder(r).Decode(&u.Info) + if err != nil { + return err + } + if len(u.Info.Sha256) != sha256.Size { + return errors.New("bad cmd hash in info") + } + return nil +} + +func (u *Updater) fetchAndVerifyPatch(old io.Reader) ([]byte, error) { + bin, err := u.fetchAndApplyPatch(old) + if err != nil { + return nil, err + } + if !verifySha(bin, u.Info.Sha256) { + return nil, ErrHashMismatch + } + return bin, nil +} + +func (u *Updater) fetchAndApplyPatch(old io.Reader) ([]byte, error) { + r, err := fetch(u.DiffURL + u.CmdName + "/" + u.CurrentVersion + "/" + u.Info.Version + "/" + plat) + if err != nil { + return nil, err + } + defer r.Close() + var buf bytes.Buffer + err = binarydist.Patch(old, &buf, r) + return buf.Bytes(), err +} + +func (u *Updater) fetchAndVerifyFullBin() ([]byte, error) { + bin, err := u.fetchBin() + if err != nil { + return nil, err + } + verified := verifySha(bin, u.Info.Sha256) + if !verified { + return nil, ErrHashMismatch + } + return bin, nil +} + +func (u *Updater) fetchBin() ([]byte, error) { + r, err := fetch(u.BinURL + u.CmdName + "/" + u.Info.Version + "/" + plat + ".gz") + if err != nil { + return nil, err + } + defer r.Close() + buf := new(bytes.Buffer) + gz, err := gzip.NewReader(r) + if err != nil { + return nil, err + } + if _, err = io.Copy(buf, gz); err != nil { + return nil, err + } + + return buf.Bytes(), nil +} + +// returns a random duration in [0,n). +func randDuration(n time.Duration) time.Duration { + return time.Duration(rand.Int63n(int64(n))) +} + +func fetch(url string) (io.ReadCloser, error) { + resp, err := http.Get(url) + if err != nil { + return nil, err + } + if resp.StatusCode != 200 { + return nil, fmt.Errorf("bad http status from %s: %v", url, resp.Status) + } + return resp.Body, nil +} + +func readTime(path string) time.Time { + p, err := ioutil.ReadFile(path) + if os.IsNotExist(err) { + return time.Time{} + } + if err != nil { + return time.Now().Add(1000 * time.Hour) + } + t, err := time.Parse(time.RFC3339, string(p)) + if err != nil { + return time.Now().Add(1000 * time.Hour) + } + return t +} + +func verifySha(bin []byte, sha []byte) bool { + h := sha256.New() + h.Write(bin) + return bytes.Equal(h.Sum(nil), sha) +} + +func writeTime(path string, t time.Time) bool { + return ioutil.WriteFile(path, []byte(t.Format(time.RFC3339)), 0644) == nil +}