From 8376705435689383736853e9eef0d29d59eb72d3 Mon Sep 17 00:00:00 2001 From: chrisprobinson Date: Fri, 16 Oct 2015 20:31:01 -0700 Subject: [PATCH] Update to allow developers to customize the fetching of the update information, diffs and binaries --- selfupdate/requester.go | 57 ++++++++++++++++++++++++++ selfupdate/selfupdate.go | 39 ++++++++++-------- selfupdate/selfupdate_test.go | 77 ++++++++++++++++++++++++++++++++++- 3 files changed, 155 insertions(+), 18 deletions(-) create mode 100644 selfupdate/requester.go diff --git a/selfupdate/requester.go b/selfupdate/requester.go new file mode 100644 index 0000000..4554fac --- /dev/null +++ b/selfupdate/requester.go @@ -0,0 +1,57 @@ +package selfupdate + +import ( + "fmt" + "io" + "net/http" +) + +// Requester interface allows developers to customize the method in which +// requests are made to retrieve the version and binary +type Requester interface { + Fetch(url string) (io.ReadCloser, error) +} + +// HTTPRequester is the normal requester that is used and does an HTTP +// to the url location requested to retrieve the specified data. +type HTTPRequester struct { +} + +// Fetch will return an HTTP request to the specified url and return +// the body of the result. An error will occur for a non 200 status code. +func (httpRequester *HTTPRequester) Fetch(url string) (io.ReadCloser, error) { + resp, err := http.Get(url) + if err != nil { + return nil, err + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("bad http status from %s: %v", url, resp.Status) + } + + return resp.Body, nil +} + +// mockRequester used for some mock testing to ensure the requester contract +// works as specified. +type mockRequester struct { + currentIndex int + fetches []func(string) (io.ReadCloser, error) +} + +func (mr *mockRequester) handleRequest(requestHandler func(string) (io.ReadCloser, error)) { + if mr.fetches == nil { + mr.fetches = []func(string) (io.ReadCloser, error){} + } + mr.fetches = append(mr.fetches, requestHandler) +} + +func (mr *mockRequester) Fetch(url string) (io.ReadCloser, error) { + if len(mr.fetches) <= mr.currentIndex { + return nil, fmt.Errorf("No for currentIndex %d to mock", mr.currentIndex) + } + current := mr.fetches[mr.currentIndex] + mr.currentIndex++ + + return current(url) +} diff --git a/selfupdate/selfupdate.go b/selfupdate/selfupdate.go index ab3a3c7..86646f3 100644 --- a/selfupdate/selfupdate.go +++ b/selfupdate/selfupdate.go @@ -36,15 +36,14 @@ import ( "io/ioutil" "log" "math/rand" - "net/http" "os" "path/filepath" "runtime" "time" "github.com/kardianos/osext" - "gopkg.in/inconshreveable/go-update.v0" "github.com/kr/binarydist" + "gopkg.in/inconshreveable/go-update.v0" ) const ( @@ -56,6 +55,7 @@ const devValidTime = 7 * 24 * time.Hour var ErrHashMismatch = errors.New("new file hash mismatch after patch") var up = update.New() +var defaultHTTPRequester = HTTPRequester{} // Updater is the configuration and runtime data for doing an update. // @@ -75,12 +75,13 @@ var up = update.New() // go updater.BackgroundRun() // } type Updater struct { - CurrentVersion string // Currently running version. - ApiURL string // Base URL for API requests (json files). - CmdName string // Command name is appended to the ApiURL like http://apiurl/CmdName/. This represents one binary. - BinURL string // Base URL for full binary downloads. - DiffURL string // Base URL for diff downloads. - Dir string // Directory to store selfupdate state. + CurrentVersion string // Currently running version. + ApiURL string // Base URL for API requests (json files). + CmdName string // Command name is appended to the ApiURL like http://apiurl/CmdName/. This represents one binary. + BinURL string // Base URL for full binary downloads. + DiffURL string // Base URL for diff downloads. + Dir string // Directory to store selfupdate state. + Requester Requester //Optional parameter to override existing http request handler Info struct { Version string Sha256 []byte @@ -177,7 +178,7 @@ func (u *Updater) update() error { } func (u *Updater) fetchInfo() error { - r, err := fetch(u.ApiURL + u.CmdName + "/" + plat + ".json") + r, err := u.fetch(u.ApiURL + u.CmdName + "/" + plat + ".json") if err != nil { return err } @@ -204,7 +205,7 @@ func (u *Updater) fetchAndVerifyPatch(old io.Reader) ([]byte, error) { } func (u *Updater) fetchAndApplyPatch(old io.Reader) ([]byte, error) { - r, err := fetch(u.DiffURL + u.CmdName + "/" + u.CurrentVersion + "/" + u.Info.Version + "/" + plat) + r, err := u.fetch(u.DiffURL + u.CmdName + "/" + u.CurrentVersion + "/" + u.Info.Version + "/" + plat) if err != nil { return nil, err } @@ -227,7 +228,7 @@ func (u *Updater) fetchAndVerifyFullBin() ([]byte, error) { } func (u *Updater) fetchBin() ([]byte, error) { - r, err := fetch(u.BinURL + u.CmdName + "/" + u.Info.Version + "/" + plat + ".gz") + r, err := u.fetch(u.BinURL + u.CmdName + "/" + u.Info.Version + "/" + plat + ".gz") if err != nil { return nil, err } @@ -249,15 +250,21 @@ func randDuration(n time.Duration) time.Duration { return time.Duration(rand.Int63n(int64(n))) } -func fetch(url string) (io.ReadCloser, error) { - resp, err := http.Get(url) +func (u *Updater) fetch(url string) (io.ReadCloser, error) { + if u.Requester == nil { + return defaultHTTPRequester.Fetch(url) + } + + readCloser, err := u.Requester.Fetch(url) if err != nil { return nil, err } - if resp.StatusCode != 200 { - return nil, fmt.Errorf("bad http status from %s: %v", url, resp.Status) + + if readCloser == nil { + return nil, fmt.Errorf("Fetch was expected to return non-nil ReadCloser") } - return resp.Body, nil + + return readCloser, nil } func readTime(path string) time.Time { diff --git a/selfupdate/selfupdate_test.go b/selfupdate/selfupdate_test.go index 068081a..9d8ada2 100644 --- a/selfupdate/selfupdate_test.go +++ b/selfupdate/selfupdate_test.go @@ -1,6 +1,79 @@ package selfupdate -import "testing" +import ( + "bytes" + "crypto/sha256" + "fmt" + "io" + "testing" +) + +var testHash = sha256.New() + +func TestUpdaterFetchMustReturnNonNilReaderCloser(t *testing.T) { + mr := &mockRequester{} + mr.handleRequest( + func(url string) (io.ReadCloser, error) { + return nil, nil + }) + updater := createUpdater(mr) + err := updater.BackgroundRun() + if err != nil { + equals(t, "Fetch was expected to return non-nil ReadCloser", err.Error()) + } else { + t.Log("Expected an error") + t.Fail() + } -func TestUpdater(t *testing.T) { +} + +func TestUpdaterWithEmptyPaloadNoErrorNoUpdate(t *testing.T) { + mr := &mockRequester{} + mr.handleRequest( + func(url string) (io.ReadCloser, error) { + equals(t, "http://updates.yourdomain.com/myapp/darwin-amd64.json", url) + return newTestReaderCloser("{}"), nil + }) + updater := createUpdater(mr) + + err := updater.BackgroundRun() + if err != nil { + t.Errorf("Error occured: %#v", err) + } + +} + +func createUpdater(mr *mockRequester) *Updater { + return &Updater{ + CurrentVersion: "1.2", + ApiURL: "http://updates.yourdomain.com/", + BinURL: "http://updates.yourdownmain.com/", + DiffURL: "http://updates.yourdomain.com/", + Dir: "update/", + CmdName: "myapp", // app name + Requester: mr, + } +} + +func equals(t *testing.T, expected, actual interface{}) { + if expected != actual { + t.Log(fmt.Sprintf("Expected: %#v %#v\n", expected, actual)) + t.Fail() + } +} + +type testReadCloser struct { + buffer *bytes.Buffer +} + +func newTestReaderCloser(payload string) io.ReadCloser { + return &testReadCloser{buffer: bytes.NewBufferString(payload)} +} + +func (trc *testReadCloser) Read(p []byte) (n int, err error) { + return trc.buffer.Read(p) +} + +func (trc *testReadCloser) Close() error { + return nil }