mirror of
https://github.com/taigrr/go-selfupdate
synced 2025-01-18 04:33:12 -08:00
236 lines
5.4 KiB
Go
236 lines
5.4 KiB
Go
package download
|
|
|
|
import (
|
|
"bytes"
|
|
"compress/gzip"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"os"
|
|
"runtime"
|
|
)
|
|
|
|
// roundTripper adapts a plain function to the http.RoundTripper interface so
// that a closure can be installed as the Transport of an http.Client.
type roundTripper struct {
	// RoundTripFn is invoked for every request handed to RoundTrip.
	RoundTripFn func(*http.Request) (*http.Response, error)
}

// Compile-time check that *roundTripper can be used as a Transport.
var _ http.RoundTripper = (*roundTripper)(nil)

// RoundTrip implements http.RoundTripper by forwarding req to RoundTripFn and
// handing back whatever that function returns.
func (rt *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	resp, err := rt.RoundTripFn(req)
	return resp, err
}
|
|
|
|
// Download encapsulates the state and parameters to download content
|
|
// from a URL which:
|
|
//
|
|
// - Publishes the percentage of the download completed to a channel.
|
|
// - May resume a previous download that was partially completed.
|
|
//
|
|
// Create an instance with the New() factory function.
|
|
type Download struct {
|
|
// net/http.Client to use when downloading the update.
|
|
// If nil, a default http.Client is used
|
|
HttpClient *http.Client
|
|
|
|
// As bytes are downloaded, they are written to Target.
|
|
// Download also uses the Target's Seek method to determine
|
|
// the size of partial-downloads so that it may properly
|
|
// request the remaining bytes to resume the download.
|
|
Target Target
|
|
|
|
// Progress returns the percentage of the download
|
|
// completed as an integer between 0 and 100
|
|
Progress chan (int)
|
|
|
|
// HTTP Method to use in the download request. Default is "GET"
|
|
Method string
|
|
|
|
// HTTP URL to issue the download request to
|
|
Url string
|
|
}
|
|
|
|
// New initializes a new Download object which will download
|
|
// the content from url into target.
|
|
func New(url string, target Target, httpClient *http.Client) *Download {
|
|
return &Download{
|
|
HttpClient: httpClient,
|
|
Progress: make(chan int),
|
|
Method: "GET",
|
|
Url: url,
|
|
Target: target,
|
|
}
|
|
}
|
|
|
|
// Get() downloads the content of a url to a target destination.
|
|
//
|
|
// Only HTTP/1.1 servers that implement the Range header support resuming a
|
|
// partially completed download.
|
|
//
|
|
// On success, the server must return 200 and the content, or 206 when resuming a partial download.
|
|
// If the HTTP server returns a 3XX redirect, it will be followed according to d.HttpClient's redirect policy.
|
|
//
|
|
func (d *Download) Get() (err error) {
|
|
// Close the progress channel whenever this function completes
|
|
defer close(d.Progress)
|
|
|
|
// determine the size of the download target to determine if we're resuming a partial download
|
|
offset, err := d.Target.Size()
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
// create the download request
|
|
req, err := http.NewRequest(d.Method, d.Url, nil)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
// create an http client if one does not exist
|
|
if d.HttpClient == nil {
|
|
d.HttpClient = http.DefaultClient
|
|
}
|
|
|
|
// we have to add headers like this so they get used across redirects
|
|
trans := d.HttpClient.Transport
|
|
if trans == nil {
|
|
trans = http.DefaultTransport
|
|
}
|
|
|
|
d.HttpClient.Transport = &roundTripper{
|
|
RoundTripFn: func(r *http.Request) (*http.Response, error) {
|
|
// add header for download continuation
|
|
if offset > 0 {
|
|
r.Header.Add("Range", fmt.Sprintf("%d-", offset))
|
|
}
|
|
|
|
// ask for gzipped content so that net/http won't unzip it for us
|
|
// and destroy the content length header we need for progress calculations
|
|
r.Header.Add("Accept-Encoding", "gzip")
|
|
|
|
return trans.RoundTrip(r)
|
|
},
|
|
}
|
|
|
|
// issue the download request
|
|
resp, err := d.HttpClient.Do(req)
|
|
if err != nil {
|
|
return
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
switch resp.StatusCode {
|
|
// ok
|
|
case 200, 206:
|
|
|
|
// server error
|
|
default:
|
|
err = fmt.Errorf("Non 2XX response when downloading update: %s", resp.Status)
|
|
return
|
|
}
|
|
|
|
// Determine how much we have to download
|
|
// net/http sets this to -1 when it is unknown
|
|
clength := resp.ContentLength
|
|
|
|
// Read the content from the response body
|
|
rd := resp.Body
|
|
|
|
// meter the rate at which we download content for
|
|
// progress reporting if we know how much to expect
|
|
if clength > 0 {
|
|
rd = &meteredReader{rd: rd, totalSize: clength, progress: d.Progress}
|
|
}
|
|
|
|
// Decompress the content if necessary
|
|
if resp.Header.Get("Content-Encoding") == "gzip" {
|
|
rd, err = gzip.NewReader(rd)
|
|
if err != nil {
|
|
return
|
|
}
|
|
}
|
|
|
|
// Download the update
|
|
_, err = io.Copy(d.Target, rd)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
// meteredReader wraps a ReadCloser. Calls to a meteredReader's Read() method
// publish updates to a progress channel with the percentage read so far.
type meteredReader struct {
	rd        io.ReadCloser // underlying stream being metered
	totalSize int64         // number of bytes the stream is expected to deliver
	progress  chan int      // receives percentages; may be nil (nothing is published)
	totalRead int64         // bytes read from rd so far
	percent   int64         // highest percentage handed to publish so far
}

// Close closes the underlying ReadCloser.
func (m *meteredReader) Close() error {
	return m.rd.Close()
}

// Read fills b from the underlying reader and publishes the percentage of
// totalSize read so far each time that percentage advances.
//
// The underlying reader is read in slices of about 1% of totalSize so that a
// single large Read still produces fine-grained progress updates. Read keeps
// going until b is full or the underlying reader returns an error; it returns
// the number of bytes placed in b together with that error, if any.
func (m *meteredReader) Read(b []byte) (n int, err error) {
	chunkSize := m.totalSize/100 + 1
	if chunkSize < 1 {
		// Only reachable with a nonsensical negative totalSize; keep the
		// slice expression below valid instead of panicking.
		chunkSize = 1
	}
	lenB := int64(len(b))

	for start := int64(0); start < lenB; {
		end := start + chunkSize
		if end > lenB {
			end = lenB
		}

		var nChunk int
		nChunk, err = m.rd.Read(b[start:end])

		start += int64(nChunk)
		n += nChunk
		m.totalRead += int64(nChunk)
		m.publish()

		if err != nil {
			return n, err
		}
	}

	return n, nil
}

// publish sends the current completion percentage on the progress channel if
// it has advanced since the last call.
func (m *meteredReader) publish() {
	if m.totalSize <= 0 {
		return
	}

	pct := m.totalRead * 100 / m.totalSize
	if pct > 100 {
		// More data than announced; never report beyond 100%.
		pct = 100
	}
	if pct <= m.percent {
		return
	}
	m.percent = pct

	// Try to send on the channel, but don't block if nobody is receiving:
	// progress is best-effort and must never stall the download.
	select {
	case m.progress <- int(pct):
	default:
	}

	// Give the progress channel consumer a chance to run.
	runtime.Gosched()
}
|
|
|
|
// Target is the destination a Download writes into. It is an io.Writer with
// an additional Size method, which lets a Download "resume" an interrupted
// download from where the previous attempt stopped.
type Target interface {
	// Size reports how many bytes the target already holds.
	Size() (size int, err error)

	io.Writer
}
|
|
|
|
// FileTarget is a Target backed by an *os.File. Writes go to the embedded
// file.
type FileTarget struct {
	*os.File
}

// Size returns the current size of the underlying file as reported by Stat,
// or the error from Stat if the file cannot be inspected.
func (t *FileTarget) Size() (int, error) {
	info, err := t.File.Stat()
	if err != nil {
		return 0, err
	}
	return int(info.Size()), nil
}
|
|
|
|
// MemoryTarget is a Target that keeps the downloaded bytes in memory in an
// embedded bytes.Buffer.
type MemoryTarget struct {
	bytes.Buffer
}

// Size returns the number of unread bytes currently held in the buffer. It
// never returns an error.
func (t *MemoryTarget) Size() (int, error) {
	held := t.Len()
	return held, nil
}
|