1
0
mirror of https://github.com/taigrr/go-selfupdate synced 2025-01-18 04:33:12 -08:00

Ran go fmt

This commit is contained in:
Mark Sanborn 2015-01-05 09:25:43 -08:00
parent a88bde6abc
commit 614f2ab48b
2 changed files with 202 additions and 203 deletions

View File

@ -1,35 +1,35 @@
package selfupdate package selfupdate
import ( import (
"bitbucket.org/kardianos/osext" "bitbucket.org/kardianos/osext"
"bytes" "bytes"
"compress/gzip" "compress/gzip"
"crypto/sha256" "crypto/sha256"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/inconshreveable/go-update" "github.com/inconshreveable/go-update"
"github.com/kr/binarydist" "github.com/kr/binarydist"
"io" "io"
"io/ioutil" "io/ioutil"
"log" "log"
"math/rand" "math/rand"
"net/http" "net/http"
"os" "os"
"runtime" "path/filepath"
"time" "runtime"
"path/filepath" "time"
) )
const ( const (
upcktimePath = "cktime" upcktimePath = "cktime"
plat = runtime.GOOS + "-" + runtime.GOARCH plat = runtime.GOOS + "-" + runtime.GOARCH
) )
const devValidTime = 7 * 24 * time.Hour const devValidTime = 7 * 24 * time.Hour
var ErrHashMismatch = errors.New("new file hash mismatch after patch") var ErrHashMismatch = errors.New("new file hash mismatch after patch")
// Update protocol. // Update protocol.
// //
// GET hk.heroku.com/hk/linux-amd64.json // GET hk.heroku.com/hk/linux-amd64.json
@ -54,210 +54,210 @@ var ErrHashMismatch = errors.New("new file hash mismatch after patch")
// 200 ok // 200 ok
// [gzipped executable data] // [gzipped executable data]
type Updater struct { type Updater struct {
CurrentVersion string CurrentVersion string
ApiURL string ApiURL string
CmdName string CmdName string
BinURL string BinURL string
DiffURL string DiffURL string
Dir string Dir string
Info struct { Info struct {
Version string Version string
Sha256 []byte Sha256 []byte
} }
} }
func (u *Updater) getExecRelativeDir(dir string) string { func (u *Updater) getExecRelativeDir(dir string) string {
filename, _ := osext.Executable() filename, _ := osext.Executable()
path := filepath.Join(filepath.Dir(filename), dir) path := filepath.Join(filepath.Dir(filename), dir)
fmt.Println(path) fmt.Println(path)
return path return path
} }
func (u *Updater) BackgroundRun() { func (u *Updater) BackgroundRun() {
os.MkdirAll(u.getExecRelativeDir(u.Dir), 0777) os.MkdirAll(u.getExecRelativeDir(u.Dir), 0777)
if u.wantUpdate() { if u.wantUpdate() {
if err := update.SanityCheck(); err != nil { if err := update.SanityCheck(); err != nil {
// fail // fail
return return
} }
//self, err := osext.Executable() //self, err := osext.Executable()
//if err != nil { //if err != nil {
// fail update, couldn't figure out path to self // fail update, couldn't figure out path to self
//return //return
//} //}
// TODO(bgentry): logger isn't on Windows. Replace w/ proper error reports. // TODO(bgentry): logger isn't on Windows. Replace w/ proper error reports.
if err := u.update(); err != nil { if err := u.update(); err != nil {
log.Println(err) log.Println(err)
} }
} }
} }
func (u *Updater) wantUpdate() bool { func (u *Updater) wantUpdate() bool {
path := u.getExecRelativeDir(u.Dir + upcktimePath) path := u.getExecRelativeDir(u.Dir + upcktimePath)
if u.CurrentVersion == "dev" || readTime(path).After(time.Now()) { if u.CurrentVersion == "dev" || readTime(path).After(time.Now()) {
return false return false
} }
wait := 24*time.Hour + randDuration(24*time.Hour) wait := 24*time.Hour + randDuration(24*time.Hour)
return writeTime(path, time.Now().Add(wait)) return writeTime(path, time.Now().Add(wait))
} }
func (u *Updater) update() error { func (u *Updater) update() error {
path, err := osext.Executable() path, err := osext.Executable()
if err != nil { if err != nil {
return err return err
} }
old, err := os.Open(path) old, err := os.Open(path)
if err != nil { if err != nil {
return err return err
} }
defer old.Close() defer old.Close()
err = u.fetchInfo() err = u.fetchInfo()
if err != nil { if err != nil {
return err return err
} }
if u.Info.Version == u.CurrentVersion { if u.Info.Version == u.CurrentVersion {
return nil return nil
} }
bin, err := u.fetchAndVerifyPatch(old) bin, err := u.fetchAndVerifyPatch(old)
if err != nil { if err != nil {
if err == ErrHashMismatch { if err == ErrHashMismatch {
log.Println("update: hash mismatch from patched binary") log.Println("update: hash mismatch from patched binary")
} else { } else {
log.Println("update: patching binary,", err) log.Println("update: patching binary,", err)
} }
bin, err = u.fetchAndVerifyFullBin() bin, err = u.fetchAndVerifyFullBin()
if err != nil { if err != nil {
if err == ErrHashMismatch { if err == ErrHashMismatch {
log.Println("update: hash mismatch from full binary") log.Println("update: hash mismatch from full binary")
} else { } else {
log.Println("update: fetching full binary,", err) log.Println("update: fetching full binary,", err)
} }
return err return err
} }
} }
// close the old binary before installing because on windows // close the old binary before installing because on windows
// it can't be renamed if a handle to the file is still open // it can't be renamed if a handle to the file is still open
old.Close() old.Close()
err, errRecover := update.FromStream(bytes.NewBuffer(bin)) err, errRecover := update.FromStream(bytes.NewBuffer(bin))
if errRecover != nil { if errRecover != nil {
return fmt.Errorf("update and recovery errors: %q %q", err, errRecover) return fmt.Errorf("update and recovery errors: %q %q", err, errRecover)
} }
if err != nil { if err != nil {
return err return err
} }
return nil return nil
} }
func (u *Updater) fetchInfo() error { func (u *Updater) fetchInfo() error {
fmt.Println(u.ApiURL) fmt.Println(u.ApiURL)
fmt.Println(plat) fmt.Println(plat)
r, err := fetch(u.ApiURL + u.CmdName + "/" + plat + ".json") r, err := fetch(u.ApiURL + u.CmdName + "/" + plat + ".json")
if err != nil { if err != nil {
return err return err
} }
defer r.Close() defer r.Close()
err = json.NewDecoder(r).Decode(&u.Info) err = json.NewDecoder(r).Decode(&u.Info)
if err != nil { if err != nil {
return err return err
} }
if len(u.Info.Sha256) != sha256.Size { if len(u.Info.Sha256) != sha256.Size {
return errors.New("bad cmd hash in info") return errors.New("bad cmd hash in info")
} }
return nil return nil
} }
func (u *Updater) fetchAndVerifyPatch(old io.Reader) ([]byte, error) { func (u *Updater) fetchAndVerifyPatch(old io.Reader) ([]byte, error) {
bin, err := u.fetchAndApplyPatch(old) bin, err := u.fetchAndApplyPatch(old)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if !verifySha(bin, u.Info.Sha256) { if !verifySha(bin, u.Info.Sha256) {
return nil, ErrHashMismatch return nil, ErrHashMismatch
} }
return bin, nil return bin, nil
} }
func (u *Updater) fetchAndApplyPatch(old io.Reader) ([]byte, error) { func (u *Updater) fetchAndApplyPatch(old io.Reader) ([]byte, error) {
r, err := fetch(u.DiffURL + u.CmdName + "/" + u.CurrentVersion + "/" + u.Info.Version + "/" + plat) r, err := fetch(u.DiffURL + u.CmdName + "/" + u.CurrentVersion + "/" + u.Info.Version + "/" + plat)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer r.Close() defer r.Close()
var buf bytes.Buffer var buf bytes.Buffer
err = binarydist.Patch(old, &buf, r) err = binarydist.Patch(old, &buf, r)
return buf.Bytes(), err return buf.Bytes(), err
} }
func (u *Updater) fetchAndVerifyFullBin() ([]byte, error) { func (u *Updater) fetchAndVerifyFullBin() ([]byte, error) {
bin, err := u.fetchBin() bin, err := u.fetchBin()
if err != nil { if err != nil {
return nil, err return nil, err
} }
verified := verifySha(bin, u.Info.Sha256) verified := verifySha(bin, u.Info.Sha256)
if !verified { if !verified {
return nil, ErrHashMismatch return nil, ErrHashMismatch
} }
return bin, nil return bin, nil
} }
func (u *Updater) fetchBin() ([]byte, error) { func (u *Updater) fetchBin() ([]byte, error) {
r, err := fetch(u.BinURL + u.CmdName + "/" + u.Info.Version + "/" + plat + ".gz") r, err := fetch(u.BinURL + u.CmdName + "/" + u.Info.Version + "/" + plat + ".gz")
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer r.Close() defer r.Close()
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
gz, err := gzip.NewReader(r) gz, err := gzip.NewReader(r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if _, err = io.Copy(buf, gz); err != nil { if _, err = io.Copy(buf, gz); err != nil {
return nil, err return nil, err
} }
return buf.Bytes(), nil return buf.Bytes(), nil
} }
// returns a random duration in [0,n). // returns a random duration in [0,n).
func randDuration(n time.Duration) time.Duration { func randDuration(n time.Duration) time.Duration {
return time.Duration(rand.Int63n(int64(n))) return time.Duration(rand.Int63n(int64(n)))
} }
func fetch(url string) (io.ReadCloser, error) { func fetch(url string) (io.ReadCloser, error) {
resp, err := http.Get(url) resp, err := http.Get(url)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if resp.StatusCode != 200 { if resp.StatusCode != 200 {
return nil, fmt.Errorf("bad http status from %s: %v", url, resp.Status) return nil, fmt.Errorf("bad http status from %s: %v", url, resp.Status)
} }
return resp.Body, nil return resp.Body, nil
} }
func readTime(path string) time.Time { func readTime(path string) time.Time {
p, err := ioutil.ReadFile(path) p, err := ioutil.ReadFile(path)
if os.IsNotExist(err) { if os.IsNotExist(err) {
return time.Time{} return time.Time{}
} }
if err != nil { if err != nil {
return time.Now().Add(1000 * time.Hour) return time.Now().Add(1000 * time.Hour)
} }
t, err := time.Parse(time.RFC3339, string(p)) t, err := time.Parse(time.RFC3339, string(p))
if err != nil { if err != nil {
return time.Now().Add(1000 * time.Hour) return time.Now().Add(1000 * time.Hour)
} }
return t return t
} }
func verifySha(bin []byte, sha []byte) bool { func verifySha(bin []byte, sha []byte) bool {
h := sha256.New() h := sha256.New()
h.Write(bin) h.Write(bin)
return bytes.Equal(h.Sum(nil), sha) return bytes.Equal(h.Sum(nil), sha)
} }
func writeTime(path string, t time.Time) bool { func writeTime(path string, t time.Time) bool {
return ioutil.WriteFile(path, []byte(t.Format(time.RFC3339)), 0644) == nil return ioutil.WriteFile(path, []byte(t.Format(time.RFC3339)), 0644) == nil
} }

View File

@ -1,9 +1,8 @@
package selfupdate package selfupdate
import( import (
"testing" "testing"
) )
func TestUpdater() { func TestUpdater() {
} }