1
0
mirror of https://github.com/taigrr/go-selfupdate synced 2025-01-18 04:33:12 -08:00
2019-03-29 17:12:33 -07:00

236 lines
5.4 KiB
Go

package download
import (
"bytes"
"compress/gzip"
"fmt"
"io"
"net/http"
"os"
"runtime"
)
// roundTripper adapts a plain function to the http.RoundTripper interface,
// so that a closure can be installed as the Transport of an http.Client.
type roundTripper struct {
	// RoundTripFn is invoked for every request handed to RoundTrip.
	RoundTripFn func(*http.Request) (*http.Response, error)
}

// RoundTrip implements http.RoundTripper by passing req to RoundTripFn and
// returning its result unchanged.
func (rt *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	resp, err := rt.RoundTripFn(req)
	return resp, err
}
// Download holds the parameters and state of a single download of the
// content at a URL. A Download:
//
//   - publishes the percentage of the download completed on a channel, and
//   - may resume a previous download that was only partially completed.
//
// Use the New factory function to create one.
type Download struct {
	// HttpClient is the net/http client used to issue the download request.
	// When nil, a default http.Client is used.
	HttpClient *http.Client

	// Target receives the downloaded bytes as they arrive. Its Size method
	// is consulted to learn how much of a partial download is already
	// present, so that only the remaining bytes are requested.
	Target Target

	// Progress carries the percentage of the download completed, as an
	// integer between 0 and 100.
	Progress chan int

	// Method is the HTTP method of the download request. New sets it to "GET".
	Method string

	// Url is the HTTP URL the download request is issued to.
	Url string
}
// New returns a Download that will fetch the content at url into target
// using httpClient. The request method is set to "GET" and a fresh,
// unbuffered Progress channel is allocated.
func New(url string, target Target, httpClient *http.Client) *Download {
	d := new(Download)
	d.Url = url
	d.Target = target
	d.HttpClient = httpClient
	d.Method = "GET"
	d.Progress = make(chan int)
	return d
}
// Get downloads the content of d.Url into d.Target.
//
// If d.Target already holds data (its Size is greater than zero), Get asks
// the server for the remaining bytes only, by sending a
// "Range: bytes=<size>-" header. Only servers that implement the Range
// header support resuming a partially completed download.
//
// On success the server must answer 200 with the content, or 206 when
// resuming a partial download. A 200 answer to a resume request means the
// server ignored the Range header; Get then returns an error instead of
// appending the complete content to the partial data already in d.Target.
// 3XX redirects are followed according to the redirect policy of
// d.HttpClient.
//
// When the response declares its length, the percentage received is
// published on d.Progress. d.Progress is closed when Get returns, so Get
// must not be called twice on the same Download.
//
// Get leaves d.HttpClient (and http.DefaultClient, which is used when
// d.HttpClient is nil) unmodified: the extra request headers are injected
// through a private copy of the client.
func (d *Download) Get() (err error) {
	// Close the progress channel whenever this function completes.
	defer close(d.Progress)

	// The current size of the target tells us whether we are resuming a
	// partial download.
	offset, err := d.Target.Size()
	if err != nil {
		return err
	}

	req, err := http.NewRequest(d.Method, d.Url, nil)
	if err != nil {
		return err
	}

	base := d.HttpClient
	if base == nil {
		base = http.DefaultClient
	}
	trans := base.Transport
	if trans == nil {
		trans = http.DefaultTransport
	}

	// Work on a shallow copy so that the caller's client (possibly the
	// process-wide http.DefaultClient) keeps its own Transport, and so that
	// repeated downloads do not stack one wrapper on top of another.
	client := *base

	// The headers are added in the transport so they are also sent on the
	// requests issued when following redirects.
	client.Transport = &roundTripper{
		RoundTripFn: func(r *http.Request) (*http.Response, error) {
			// Set (not Add): the same header set may pass through here more
			// than once and must not accumulate duplicate values.
			if offset > 0 {
				// Ask only for the bytes we do not have yet.
				r.Header.Set("Range", fmt.Sprintf("bytes=%d-", offset))
			}
			// Ask for gzipped content explicitly so that net/http won't
			// unzip it for us and destroy the Content-Length header we need
			// for progress calculations.
			r.Header.Set("Accept-Encoding", "gzip")
			return trans.RoundTrip(r)
		},
	}

	resp, err := client.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	switch resp.StatusCode {
	case http.StatusOK:
		if offset > 0 {
			// The whole content is coming back; writing it after the bytes
			// already in the target would corrupt the download.
			return fmt.Errorf("server ignored range request, cannot resume download at byte %d: %s", offset, resp.Status)
		}
	case http.StatusPartialContent:
		// ok, the server is sending the requested remainder
	default:
		return fmt.Errorf("Non 2XX response when downloading update: %s", resp.Status)
	}

	var rd io.ReadCloser = resp.Body

	// Meter the download for progress reporting if we know how much to
	// expect; net/http sets ContentLength to -1 when it is unknown.
	if clength := resp.ContentLength; clength > 0 {
		rd = &meteredReader{rd: rd, totalSize: clength, progress: d.Progress}
	}

	// Decompress the content if necessary. The meter sits below the gzip
	// reader, so progress is measured in bytes received from the server.
	if resp.Header.Get("Content-Encoding") == "gzip" {
		zr, zerr := gzip.NewReader(rd)
		if zerr != nil {
			return zerr
		}
		rd = zr
	}

	// Download the update.
	_, err = io.Copy(d.Target, rd)
	return err
}
// meteredReader wraps an io.ReadCloser whose total length is known. Calls to
// its Read method publish the percentage read so far on a progress channel.
type meteredReader struct {
	rd        io.ReadCloser // underlying stream
	totalSize int64         // number of bytes rd is expected to deliver
	progress  chan int      // receives percentages in the range 1..100
	totalRead int64         // bytes delivered by rd so far
	ticks     int64         // highest percentage already offered to progress
}

// Close closes the underlying reader.
func (m *meteredReader) Close() error {
	return m.rd.Close()
}

// Read fills b from the underlying reader and returns the number of bytes
// read together with the first error encountered, if any.
//
// The buffer is filled in slices of about one percent of totalSize so that
// progress can be reported at that granularity even when b is large. Each
// time the integer percentage totalRead*100/totalSize grows, the new value
// (never more than 100) is offered to the progress channel. The offer never
// blocks: if no receiver is ready the value is dropped.
//
// No progress is published when totalSize is not positive.
func (m *meteredReader) Read(b []byte) (n int, err error) {
	lenB := int64(len(b))

	// One slice is roughly one percent of the total. Without a known total
	// there is nothing to meter, so read the buffer in one piece.
	chunkSize := lenB
	if m.totalSize > 0 {
		chunkSize = (m.totalSize / 100) + 1
	}

	for start := int64(0); start < lenB; {
		end := start + chunkSize
		if end > lenB {
			end = lenB
		}

		var nChunk int
		nChunk, err = m.rd.Read(b[start:end])
		n += nChunk
		start += int64(nChunk)
		m.totalRead += int64(nChunk)

		if m.totalSize > 0 {
			pct := m.totalRead * 100 / m.totalSize
			if pct > 100 {
				pct = 100
			}
			if pct > m.ticks {
				m.ticks = pct
				// try to send on channel, but don't block if it's full
				select {
				case m.progress <- int(pct):
				default:
				}
				// give the progress channel consumer a chance to run
				runtime.Gosched()
			}
		}

		if err != nil {
			return n, err
		}
	}
	return n, nil
}
// Target is the destination supplied to a Download: an io.Writer that can
// also report its size, which lets a Download resume an interrupted
// download.
type Target interface {
	io.Writer

	// Size returns the number of bytes the target currently holds.
	Size() (int, error)
}
// FileTarget is a Target backed by a file: writes go to the embedded
// *os.File.
type FileTarget struct {
	*os.File
}

// Size reports the current size of the file in bytes as given by Stat.
// If Stat fails, Size returns 0 and the error.
func (t *FileTarget) Size() (int, error) {
	info, err := t.File.Stat()
	if err != nil {
		return 0, err
	}
	return int(info.Size()), nil
}
// MemoryTarget is a Target that keeps the downloaded content in memory, in
// the embedded bytes.Buffer.
type MemoryTarget struct {
	bytes.Buffer
}

// Size reports the length of the buffer in bytes. The error is always nil.
func (t *MemoryTarget) Size() (int, error) {
	size := t.Len()
	return size, nil
}