1
0
mirror of https://github.com/taigrr/go-selfupdate synced 2025-01-18 04:33:12 -08:00

Update to allow developers to customize the fetching of the update

information, diffs and binaries
This commit is contained in:
chrisprobinson 2015-10-16 20:31:01 -07:00
parent 463b28194b
commit 8376705435
3 changed files with 155 additions and 18 deletions

57
selfupdate/requester.go Normal file
View File

@ -0,0 +1,57 @@
package selfupdate
import (
"fmt"
"io"
"net/http"
)
// Requester interface allows developers to customize the method in which
// requests are made to retrieve the version and binary
type Requester interface {
Fetch(url string) (io.ReadCloser, error)
}
// HTTPRequester is the normal requester that is used and does an HTTP
// to the url location requested to retrieve the specified data.
type HTTPRequester struct {
}
// Fetch will return an HTTP request to the specified url and return
// the body of the result. An error will occur for a non 200 status code.
func (httpRequester *HTTPRequester) Fetch(url string) (io.ReadCloser, error) {
resp, err := http.Get(url)
if err != nil {
return nil, err
}
if resp.StatusCode != 200 {
return nil, fmt.Errorf("bad http status from %s: %v", url, resp.Status)
}
return resp.Body, nil
}
// mockRequester used for some mock testing to ensure the requester contract
// works as specified.
type mockRequester struct {
currentIndex int
fetches []func(string) (io.ReadCloser, error)
}
func (mr *mockRequester) handleRequest(requestHandler func(string) (io.ReadCloser, error)) {
if mr.fetches == nil {
mr.fetches = []func(string) (io.ReadCloser, error){}
}
mr.fetches = append(mr.fetches, requestHandler)
}
func (mr *mockRequester) Fetch(url string) (io.ReadCloser, error) {
if len(mr.fetches) <= mr.currentIndex {
return nil, fmt.Errorf("No for currentIndex %d to mock", mr.currentIndex)
}
current := mr.fetches[mr.currentIndex]
mr.currentIndex++
return current(url)
}

View File

@ -36,15 +36,14 @@ import (
"io/ioutil" "io/ioutil"
"log" "log"
"math/rand" "math/rand"
"net/http"
"os" "os"
"path/filepath" "path/filepath"
"runtime" "runtime"
"time" "time"
"github.com/kardianos/osext" "github.com/kardianos/osext"
"gopkg.in/inconshreveable/go-update.v0"
"github.com/kr/binarydist" "github.com/kr/binarydist"
"gopkg.in/inconshreveable/go-update.v0"
) )
const ( const (
@ -56,6 +55,7 @@ const devValidTime = 7 * 24 * time.Hour
var ErrHashMismatch = errors.New("new file hash mismatch after patch") var ErrHashMismatch = errors.New("new file hash mismatch after patch")
var up = update.New() var up = update.New()
var defaultHTTPRequester = HTTPRequester{}
// Updater is the configuration and runtime data for doing an update. // Updater is the configuration and runtime data for doing an update.
// //
@ -81,6 +81,7 @@ type Updater struct {
BinURL string // Base URL for full binary downloads. BinURL string // Base URL for full binary downloads.
DiffURL string // Base URL for diff downloads. DiffURL string // Base URL for diff downloads.
Dir string // Directory to store selfupdate state. Dir string // Directory to store selfupdate state.
Requester Requester //Optional parameter to override existing http request handler
Info struct { Info struct {
Version string Version string
Sha256 []byte Sha256 []byte
@ -177,7 +178,7 @@ func (u *Updater) update() error {
} }
func (u *Updater) fetchInfo() error { func (u *Updater) fetchInfo() error {
r, err := fetch(u.ApiURL + u.CmdName + "/" + plat + ".json") r, err := u.fetch(u.ApiURL + u.CmdName + "/" + plat + ".json")
if err != nil { if err != nil {
return err return err
} }
@ -204,7 +205,7 @@ func (u *Updater) fetchAndVerifyPatch(old io.Reader) ([]byte, error) {
} }
func (u *Updater) fetchAndApplyPatch(old io.Reader) ([]byte, error) { func (u *Updater) fetchAndApplyPatch(old io.Reader) ([]byte, error) {
r, err := fetch(u.DiffURL + u.CmdName + "/" + u.CurrentVersion + "/" + u.Info.Version + "/" + plat) r, err := u.fetch(u.DiffURL + u.CmdName + "/" + u.CurrentVersion + "/" + u.Info.Version + "/" + plat)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -227,7 +228,7 @@ func (u *Updater) fetchAndVerifyFullBin() ([]byte, error) {
} }
func (u *Updater) fetchBin() ([]byte, error) { func (u *Updater) fetchBin() ([]byte, error) {
r, err := fetch(u.BinURL + u.CmdName + "/" + u.Info.Version + "/" + plat + ".gz") r, err := u.fetch(u.BinURL + u.CmdName + "/" + u.Info.Version + "/" + plat + ".gz")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -249,15 +250,21 @@ func randDuration(n time.Duration) time.Duration {
return time.Duration(rand.Int63n(int64(n))) return time.Duration(rand.Int63n(int64(n)))
} }
func fetch(url string) (io.ReadCloser, error) { func (u *Updater) fetch(url string) (io.ReadCloser, error) {
resp, err := http.Get(url) if u.Requester == nil {
return defaultHTTPRequester.Fetch(url)
}
readCloser, err := u.Requester.Fetch(url)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if resp.StatusCode != 200 {
return nil, fmt.Errorf("bad http status from %s: %v", url, resp.Status) if readCloser == nil {
return nil, fmt.Errorf("Fetch was expected to return non-nil ReadCloser")
} }
return resp.Body, nil
return readCloser, nil
} }
func readTime(path string) time.Time { func readTime(path string) time.Time {

View File

@ -1,6 +1,79 @@
package selfupdate package selfupdate
import "testing" import (
"bytes"
"crypto/sha256"
"fmt"
"io"
"testing"
)
var testHash = sha256.New()
func TestUpdaterFetchMustReturnNonNilReaderCloser(t *testing.T) {
mr := &mockRequester{}
mr.handleRequest(
func(url string) (io.ReadCloser, error) {
return nil, nil
})
updater := createUpdater(mr)
err := updater.BackgroundRun()
if err != nil {
equals(t, "Fetch was expected to return non-nil ReadCloser", err.Error())
} else {
t.Log("Expected an error")
t.Fail()
}
func TestUpdater(t *testing.T) { }
func TestUpdaterWithEmptyPaloadNoErrorNoUpdate(t *testing.T) {
mr := &mockRequester{}
mr.handleRequest(
func(url string) (io.ReadCloser, error) {
equals(t, "http://updates.yourdomain.com/myapp/darwin-amd64.json", url)
return newTestReaderCloser("{}"), nil
})
updater := createUpdater(mr)
err := updater.BackgroundRun()
if err != nil {
t.Errorf("Error occured: %#v", err)
}
}
func createUpdater(mr *mockRequester) *Updater {
return &Updater{
CurrentVersion: "1.2",
ApiURL: "http://updates.yourdomain.com/",
BinURL: "http://updates.yourdownmain.com/",
DiffURL: "http://updates.yourdomain.com/",
Dir: "update/",
CmdName: "myapp", // app name
Requester: mr,
}
}
func equals(t *testing.T, expected, actual interface{}) {
if expected != actual {
t.Log(fmt.Sprintf("Expected: %#v %#v\n", expected, actual))
t.Fail()
}
}
type testReadCloser struct {
buffer *bytes.Buffer
}
func newTestReaderCloser(payload string) io.ReadCloser {
return &testReadCloser{buffer: bytes.NewBufferString(payload)}
}
func (trc *testReadCloser) Read(p []byte) (n int, err error) {
return trc.buffer.Read(p)
}
func (trc *testReadCloser) Close() error {
return nil
} }