mirror of
https://github.com/taigrr/go-selfupdate
synced 2025-01-18 04:33:12 -08:00
Update to allow developers to customize the fetching of the update
information, diffs and binaries
This commit is contained in:
parent
463b28194b
commit
8376705435
57
selfupdate/requester.go
Normal file
57
selfupdate/requester.go
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
package selfupdate
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Requester interface allows developers to customize the method in which
|
||||||
|
// requests are made to retrieve the version and binary
|
||||||
|
type Requester interface {
|
||||||
|
Fetch(url string) (io.ReadCloser, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// HTTPRequester is the normal requester that is used and does an HTTP
|
||||||
|
// to the url location requested to retrieve the specified data.
|
||||||
|
type HTTPRequester struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fetch will return an HTTP request to the specified url and return
|
||||||
|
// the body of the result. An error will occur for a non 200 status code.
|
||||||
|
func (httpRequester *HTTPRequester) Fetch(url string) (io.ReadCloser, error) {
|
||||||
|
resp, err := http.Get(url)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != 200 {
|
||||||
|
return nil, fmt.Errorf("bad http status from %s: %v", url, resp.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp.Body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// mockRequester used for some mock testing to ensure the requester contract
|
||||||
|
// works as specified.
|
||||||
|
type mockRequester struct {
|
||||||
|
currentIndex int
|
||||||
|
fetches []func(string) (io.ReadCloser, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mr *mockRequester) handleRequest(requestHandler func(string) (io.ReadCloser, error)) {
|
||||||
|
if mr.fetches == nil {
|
||||||
|
mr.fetches = []func(string) (io.ReadCloser, error){}
|
||||||
|
}
|
||||||
|
mr.fetches = append(mr.fetches, requestHandler)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mr *mockRequester) Fetch(url string) (io.ReadCloser, error) {
|
||||||
|
if len(mr.fetches) <= mr.currentIndex {
|
||||||
|
return nil, fmt.Errorf("No for currentIndex %d to mock", mr.currentIndex)
|
||||||
|
}
|
||||||
|
current := mr.fetches[mr.currentIndex]
|
||||||
|
mr.currentIndex++
|
||||||
|
|
||||||
|
return current(url)
|
||||||
|
}
|
@ -36,15 +36,14 @@ import (
|
|||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"log"
|
"log"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net/http"
|
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/kardianos/osext"
|
"github.com/kardianos/osext"
|
||||||
"gopkg.in/inconshreveable/go-update.v0"
|
|
||||||
"github.com/kr/binarydist"
|
"github.com/kr/binarydist"
|
||||||
|
"gopkg.in/inconshreveable/go-update.v0"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -56,6 +55,7 @@ const devValidTime = 7 * 24 * time.Hour
|
|||||||
|
|
||||||
var ErrHashMismatch = errors.New("new file hash mismatch after patch")
|
var ErrHashMismatch = errors.New("new file hash mismatch after patch")
|
||||||
var up = update.New()
|
var up = update.New()
|
||||||
|
var defaultHTTPRequester = HTTPRequester{}
|
||||||
|
|
||||||
// Updater is the configuration and runtime data for doing an update.
|
// Updater is the configuration and runtime data for doing an update.
|
||||||
//
|
//
|
||||||
@ -81,6 +81,7 @@ type Updater struct {
|
|||||||
BinURL string // Base URL for full binary downloads.
|
BinURL string // Base URL for full binary downloads.
|
||||||
DiffURL string // Base URL for diff downloads.
|
DiffURL string // Base URL for diff downloads.
|
||||||
Dir string // Directory to store selfupdate state.
|
Dir string // Directory to store selfupdate state.
|
||||||
|
Requester Requester //Optional parameter to override existing http request handler
|
||||||
Info struct {
|
Info struct {
|
||||||
Version string
|
Version string
|
||||||
Sha256 []byte
|
Sha256 []byte
|
||||||
@ -177,7 +178,7 @@ func (u *Updater) update() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (u *Updater) fetchInfo() error {
|
func (u *Updater) fetchInfo() error {
|
||||||
r, err := fetch(u.ApiURL + u.CmdName + "/" + plat + ".json")
|
r, err := u.fetch(u.ApiURL + u.CmdName + "/" + plat + ".json")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -204,7 +205,7 @@ func (u *Updater) fetchAndVerifyPatch(old io.Reader) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (u *Updater) fetchAndApplyPatch(old io.Reader) ([]byte, error) {
|
func (u *Updater) fetchAndApplyPatch(old io.Reader) ([]byte, error) {
|
||||||
r, err := fetch(u.DiffURL + u.CmdName + "/" + u.CurrentVersion + "/" + u.Info.Version + "/" + plat)
|
r, err := u.fetch(u.DiffURL + u.CmdName + "/" + u.CurrentVersion + "/" + u.Info.Version + "/" + plat)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -227,7 +228,7 @@ func (u *Updater) fetchAndVerifyFullBin() ([]byte, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (u *Updater) fetchBin() ([]byte, error) {
|
func (u *Updater) fetchBin() ([]byte, error) {
|
||||||
r, err := fetch(u.BinURL + u.CmdName + "/" + u.Info.Version + "/" + plat + ".gz")
|
r, err := u.fetch(u.BinURL + u.CmdName + "/" + u.Info.Version + "/" + plat + ".gz")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -249,15 +250,21 @@ func randDuration(n time.Duration) time.Duration {
|
|||||||
return time.Duration(rand.Int63n(int64(n)))
|
return time.Duration(rand.Int63n(int64(n)))
|
||||||
}
|
}
|
||||||
|
|
||||||
func fetch(url string) (io.ReadCloser, error) {
|
func (u *Updater) fetch(url string) (io.ReadCloser, error) {
|
||||||
resp, err := http.Get(url)
|
if u.Requester == nil {
|
||||||
|
return defaultHTTPRequester.Fetch(url)
|
||||||
|
}
|
||||||
|
|
||||||
|
readCloser, err := u.Requester.Fetch(url)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if resp.StatusCode != 200 {
|
|
||||||
return nil, fmt.Errorf("bad http status from %s: %v", url, resp.Status)
|
if readCloser == nil {
|
||||||
|
return nil, fmt.Errorf("Fetch was expected to return non-nil ReadCloser")
|
||||||
}
|
}
|
||||||
return resp.Body, nil
|
|
||||||
|
return readCloser, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func readTime(path string) time.Time {
|
func readTime(path string) time.Time {
|
||||||
|
@ -1,6 +1,79 @@
|
|||||||
package selfupdate
|
package selfupdate
|
||||||
|
|
||||||
import "testing"
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/sha256"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
var testHash = sha256.New()
|
||||||
|
|
||||||
|
func TestUpdaterFetchMustReturnNonNilReaderCloser(t *testing.T) {
|
||||||
|
mr := &mockRequester{}
|
||||||
|
mr.handleRequest(
|
||||||
|
func(url string) (io.ReadCloser, error) {
|
||||||
|
return nil, nil
|
||||||
|
})
|
||||||
|
updater := createUpdater(mr)
|
||||||
|
err := updater.BackgroundRun()
|
||||||
|
if err != nil {
|
||||||
|
equals(t, "Fetch was expected to return non-nil ReadCloser", err.Error())
|
||||||
|
} else {
|
||||||
|
t.Log("Expected an error")
|
||||||
|
t.Fail()
|
||||||
|
}
|
||||||
|
|
||||||
func TestUpdater(t *testing.T) {
|
}
|
||||||
|
|
||||||
|
func TestUpdaterWithEmptyPaloadNoErrorNoUpdate(t *testing.T) {
|
||||||
|
mr := &mockRequester{}
|
||||||
|
mr.handleRequest(
|
||||||
|
func(url string) (io.ReadCloser, error) {
|
||||||
|
equals(t, "http://updates.yourdomain.com/myapp/darwin-amd64.json", url)
|
||||||
|
return newTestReaderCloser("{}"), nil
|
||||||
|
})
|
||||||
|
updater := createUpdater(mr)
|
||||||
|
|
||||||
|
err := updater.BackgroundRun()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Error occured: %#v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func createUpdater(mr *mockRequester) *Updater {
|
||||||
|
return &Updater{
|
||||||
|
CurrentVersion: "1.2",
|
||||||
|
ApiURL: "http://updates.yourdomain.com/",
|
||||||
|
BinURL: "http://updates.yourdownmain.com/",
|
||||||
|
DiffURL: "http://updates.yourdomain.com/",
|
||||||
|
Dir: "update/",
|
||||||
|
CmdName: "myapp", // app name
|
||||||
|
Requester: mr,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func equals(t *testing.T, expected, actual interface{}) {
|
||||||
|
if expected != actual {
|
||||||
|
t.Log(fmt.Sprintf("Expected: %#v %#v\n", expected, actual))
|
||||||
|
t.Fail()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type testReadCloser struct {
|
||||||
|
buffer *bytes.Buffer
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTestReaderCloser(payload string) io.ReadCloser {
|
||||||
|
return &testReadCloser{buffer: bytes.NewBufferString(payload)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (trc *testReadCloser) Read(p []byte) (n int, err error) {
|
||||||
|
return trc.buffer.Read(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (trc *testReadCloser) Close() error {
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user