mirror of
https://github.com/taigrr/go-selfupdate
synced 2025-01-18 04:33:12 -08:00
last version adapted
This commit is contained in:
parent
a88bde6abc
commit
c2f128aaaf
@ -1,35 +1,36 @@
|
|||||||
package selfupdate
|
package selfupdate
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bitbucket.org/kardianos/osext"
|
"bitbucket.org/kardianos/osext"
|
||||||
"bytes"
|
"bytes"
|
||||||
"compress/gzip"
|
"compress/gzip"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/inconshreveable/go-update"
|
"github.com/inconshreveable/go-update"
|
||||||
"github.com/kr/binarydist"
|
"github.com/kr/binarydist"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"log"
|
"log"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"runtime"
|
"path/filepath"
|
||||||
"time"
|
"runtime"
|
||||||
"path/filepath"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
upcktimePath = "cktime"
|
upcktimePath = "cktime"
|
||||||
plat = runtime.GOOS + "-" + runtime.GOARCH
|
plat = runtime.GOOS + "-" + runtime.GOARCH
|
||||||
)
|
)
|
||||||
|
|
||||||
const devValidTime = 7 * 24 * time.Hour
|
const devValidTime = 7 * 24 * time.Hour
|
||||||
|
|
||||||
var ErrHashMismatch = errors.New("new file hash mismatch after patch")
|
var ErrHashMismatch = errors.New("new file hash mismatch after patch")
|
||||||
|
var up = update.New()
|
||||||
|
|
||||||
// Update protocol.
|
// Update protocol.
|
||||||
//
|
//
|
||||||
// GET hk.heroku.com/hk/linux-amd64.json
|
// GET hk.heroku.com/hk/linux-amd64.json
|
||||||
@ -54,210 +55,210 @@ var ErrHashMismatch = errors.New("new file hash mismatch after patch")
|
|||||||
// 200 ok
|
// 200 ok
|
||||||
// [gzipped executable data]
|
// [gzipped executable data]
|
||||||
type Updater struct {
|
type Updater struct {
|
||||||
CurrentVersion string
|
CurrentVersion string
|
||||||
ApiURL string
|
ApiURL string
|
||||||
CmdName string
|
CmdName string
|
||||||
BinURL string
|
BinURL string
|
||||||
DiffURL string
|
DiffURL string
|
||||||
Dir string
|
Dir string
|
||||||
Info struct {
|
Info struct {
|
||||||
Version string
|
Version string
|
||||||
Sha256 []byte
|
Sha256 []byte
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *Updater) getExecRelativeDir(dir string) string {
|
func (u *Updater) getExecRelativeDir(dir string) string {
|
||||||
filename, _ := osext.Executable()
|
filename, _ := osext.Executable()
|
||||||
path := filepath.Join(filepath.Dir(filename), dir)
|
path := filepath.Join(filepath.Dir(filename), dir)
|
||||||
fmt.Println(path)
|
fmt.Println(path)
|
||||||
return path
|
return path
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *Updater) BackgroundRun() {
|
func (u *Updater) BackgroundRun() {
|
||||||
os.MkdirAll(u.getExecRelativeDir(u.Dir), 0777)
|
os.MkdirAll(u.getExecRelativeDir(u.Dir), 0777)
|
||||||
if u.wantUpdate() {
|
if u.wantUpdate() {
|
||||||
if err := update.SanityCheck(); err != nil {
|
if err := up.CanUpdate(); err != nil {
|
||||||
// fail
|
// fail
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
//self, err := osext.Executable()
|
//self, err := osext.Executable()
|
||||||
//if err != nil {
|
//if err != nil {
|
||||||
// fail update, couldn't figure out path to self
|
// fail update, couldn't figure out path to self
|
||||||
//return
|
//return
|
||||||
//}
|
//}
|
||||||
// TODO(bgentry): logger isn't on Windows. Replace w/ proper error reports.
|
// TODO(bgentry): logger isn't on Windows. Replace w/ proper error reports.
|
||||||
if err := u.update(); err != nil {
|
if err := u.update(); err != nil {
|
||||||
log.Println(err)
|
log.Println(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *Updater) wantUpdate() bool {
|
func (u *Updater) wantUpdate() bool {
|
||||||
path := u.getExecRelativeDir(u.Dir + upcktimePath)
|
path := u.getExecRelativeDir(u.Dir + upcktimePath)
|
||||||
if u.CurrentVersion == "dev" || readTime(path).After(time.Now()) {
|
if u.CurrentVersion == "dev" || readTime(path).After(time.Now()) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
wait := 24*time.Hour + randDuration(24*time.Hour)
|
wait := 24*time.Hour + randDuration(24*time.Hour)
|
||||||
return writeTime(path, time.Now().Add(wait))
|
return writeTime(path, time.Now().Add(wait))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *Updater) update() error {
|
func (u *Updater) update() error {
|
||||||
path, err := osext.Executable()
|
path, err := osext.Executable()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
old, err := os.Open(path)
|
old, err := os.Open(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer old.Close()
|
defer old.Close()
|
||||||
|
|
||||||
err = u.fetchInfo()
|
err = u.fetchInfo()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if u.Info.Version == u.CurrentVersion {
|
if u.Info.Version == u.CurrentVersion {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
bin, err := u.fetchAndVerifyPatch(old)
|
bin, err := u.fetchAndVerifyPatch(old)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == ErrHashMismatch {
|
if err == ErrHashMismatch {
|
||||||
log.Println("update: hash mismatch from patched binary")
|
log.Println("update: hash mismatch from patched binary")
|
||||||
} else {
|
} else {
|
||||||
log.Println("update: patching binary,", err)
|
log.Println("update: patching binary,", err)
|
||||||
}
|
}
|
||||||
bin, err = u.fetchAndVerifyFullBin()
|
bin, err = u.fetchAndVerifyFullBin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == ErrHashMismatch {
|
if err == ErrHashMismatch {
|
||||||
log.Println("update: hash mismatch from full binary")
|
log.Println("update: hash mismatch from full binary")
|
||||||
} else {
|
} else {
|
||||||
log.Println("update: fetching full binary,", err)
|
log.Println("update: fetching full binary,", err)
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// close the old binary before installing because on windows
|
// close the old binary before installing because on windows
|
||||||
// it can't be renamed if a handle to the file is still open
|
// it can't be renamed if a handle to the file is still open
|
||||||
old.Close()
|
old.Close()
|
||||||
|
|
||||||
err, errRecover := update.FromStream(bytes.NewBuffer(bin))
|
err, errRecover := up.FromStream(bytes.NewBuffer(bin))
|
||||||
if errRecover != nil {
|
if errRecover != nil {
|
||||||
return fmt.Errorf("update and recovery errors: %q %q", err, errRecover)
|
return fmt.Errorf("update and recovery errors: %q %q", err, errRecover)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *Updater) fetchInfo() error {
|
func (u *Updater) fetchInfo() error {
|
||||||
fmt.Println(u.ApiURL)
|
fmt.Println(u.ApiURL)
|
||||||
fmt.Println(plat)
|
fmt.Println(plat)
|
||||||
r, err := fetch(u.ApiURL + u.CmdName + "/" + plat + ".json")
|
r, err := fetch(u.ApiURL + u.CmdName + "/" + plat + ".json")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer r.Close()
|
defer r.Close()
|
||||||
err = json.NewDecoder(r).Decode(&u.Info)
|
err = json.NewDecoder(r).Decode(&u.Info)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if len(u.Info.Sha256) != sha256.Size {
|
if len(u.Info.Sha256) != sha256.Size {
|
||||||
return errors.New("bad cmd hash in info")
|
return errors.New("bad cmd hash in info")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *Updater) fetchAndVerifyPatch(old io.Reader) ([]byte, error) {
|
func (u *Updater) fetchAndVerifyPatch(old io.Reader) ([]byte, error) {
|
||||||
bin, err := u.fetchAndApplyPatch(old)
|
bin, err := u.fetchAndApplyPatch(old)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if !verifySha(bin, u.Info.Sha256) {
|
if !verifySha(bin, u.Info.Sha256) {
|
||||||
return nil, ErrHashMismatch
|
return nil, ErrHashMismatch
|
||||||
}
|
}
|
||||||
return bin, nil
|
return bin, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *Updater) fetchAndApplyPatch(old io.Reader) ([]byte, error) {
|
func (u *Updater) fetchAndApplyPatch(old io.Reader) ([]byte, error) {
|
||||||
r, err := fetch(u.DiffURL + u.CmdName + "/" + u.CurrentVersion + "/" + u.Info.Version + "/" + plat)
|
r, err := fetch(u.DiffURL + u.CmdName + "/" + u.CurrentVersion + "/" + u.Info.Version + "/" + plat)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer r.Close()
|
defer r.Close()
|
||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
err = binarydist.Patch(old, &buf, r)
|
err = binarydist.Patch(old, &buf, r)
|
||||||
return buf.Bytes(), err
|
return buf.Bytes(), err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *Updater) fetchAndVerifyFullBin() ([]byte, error) {
|
func (u *Updater) fetchAndVerifyFullBin() ([]byte, error) {
|
||||||
bin, err := u.fetchBin()
|
bin, err := u.fetchBin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
verified := verifySha(bin, u.Info.Sha256)
|
verified := verifySha(bin, u.Info.Sha256)
|
||||||
if !verified {
|
if !verified {
|
||||||
return nil, ErrHashMismatch
|
return nil, ErrHashMismatch
|
||||||
}
|
}
|
||||||
return bin, nil
|
return bin, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *Updater) fetchBin() ([]byte, error) {
|
func (u *Updater) fetchBin() ([]byte, error) {
|
||||||
r, err := fetch(u.BinURL + u.CmdName + "/" + u.Info.Version + "/" + plat + ".gz")
|
r, err := fetch(u.BinURL + u.CmdName + "/" + u.Info.Version + "/" + plat + ".gz")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer r.Close()
|
defer r.Close()
|
||||||
buf := new(bytes.Buffer)
|
buf := new(bytes.Buffer)
|
||||||
gz, err := gzip.NewReader(r)
|
gz, err := gzip.NewReader(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if _, err = io.Copy(buf, gz); err != nil {
|
if _, err = io.Copy(buf, gz); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return buf.Bytes(), nil
|
return buf.Bytes(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// returns a random duration in [0,n).
|
// returns a random duration in [0,n).
|
||||||
func randDuration(n time.Duration) time.Duration {
|
func randDuration(n time.Duration) time.Duration {
|
||||||
return time.Duration(rand.Int63n(int64(n)))
|
return time.Duration(rand.Int63n(int64(n)))
|
||||||
}
|
}
|
||||||
|
|
||||||
func fetch(url string) (io.ReadCloser, error) {
|
func fetch(url string) (io.ReadCloser, error) {
|
||||||
resp, err := http.Get(url)
|
resp, err := http.Get(url)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if resp.StatusCode != 200 {
|
if resp.StatusCode != 200 {
|
||||||
return nil, fmt.Errorf("bad http status from %s: %v", url, resp.Status)
|
return nil, fmt.Errorf("bad http status from %s: %v", url, resp.Status)
|
||||||
}
|
}
|
||||||
return resp.Body, nil
|
return resp.Body, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func readTime(path string) time.Time {
|
func readTime(path string) time.Time {
|
||||||
p, err := ioutil.ReadFile(path)
|
p, err := ioutil.ReadFile(path)
|
||||||
if os.IsNotExist(err) {
|
if os.IsNotExist(err) {
|
||||||
return time.Time{}
|
return time.Time{}
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return time.Now().Add(1000 * time.Hour)
|
return time.Now().Add(1000 * time.Hour)
|
||||||
}
|
}
|
||||||
t, err := time.Parse(time.RFC3339, string(p))
|
t, err := time.Parse(time.RFC3339, string(p))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return time.Now().Add(1000 * time.Hour)
|
return time.Now().Add(1000 * time.Hour)
|
||||||
}
|
}
|
||||||
return t
|
return t
|
||||||
}
|
}
|
||||||
|
|
||||||
func verifySha(bin []byte, sha []byte) bool {
|
func verifySha(bin []byte, sha []byte) bool {
|
||||||
h := sha256.New()
|
h := sha256.New()
|
||||||
h.Write(bin)
|
h.Write(bin)
|
||||||
return bytes.Equal(h.Sum(nil), sha)
|
return bytes.Equal(h.Sum(nil), sha)
|
||||||
}
|
}
|
||||||
|
|
||||||
func writeTime(path string, t time.Time) bool {
|
func writeTime(path string, t time.Time) bool {
|
||||||
return ioutil.WriteFile(path, []byte(t.Format(time.RFC3339)), 0644) == nil
|
return ioutil.WriteFile(path, []byte(t.Format(time.RFC3339)), 0644) == nil
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user