diff --git a/adb.go b/adb.go index 3d38f4e..3b80b1f 100644 --- a/adb.go +++ b/adb.go @@ -127,11 +127,11 @@ func (d Device) Disconnect(ctx context.Context) error { return err } -// Kill the ADB Server +// KillServer kills the ADB server. // -// Warning, this function call may cause inconsostency if not used properly. +// Warning: this function call may cause inconsistency if not used properly. // Killing the ADB server shouldn't ever technically be necessary, but if you do -// decide to use this function. note that it may invalidate all existing device structs. +// decide to use this function, note that it may invalidate all existing device structs. // Older versions of Android don't play nicely with kill-server, and some may // refuse following connection attempts if you don't disconnect from them before // calling this function. @@ -148,71 +148,70 @@ func (d Device) Push(ctx context.Context, src, dest string) error { if err != nil { return err } - stdout, stderr, errcode, err := execute(ctx, []string{"push", src, dest}) + _, _, errcode, err := execute(ctx, []string{"-s", string(d.SerialNo), "push", src, dest}) if err != nil { return err } if errcode != 0 { return ErrUnspecified } - _, _ = stdout, stderr - // TODO check the return strings of the output to determine if the file copy succeeded return nil } -// Pulls a file from a Device +// Pull a file from a Device. // -// Returns an error if src does not exist, or if dest already exists or cannot be created +// Returns an error if dest already exists or the file cannot be pulled. func (d Device) Pull(ctx context.Context, src, dest string) error { - _, err := os.Stat(src) + _, err := os.Stat(dest) + if err == nil { + return ErrDestExists + } if !errors.Is(err, os.ErrNotExist) { return err } - stdout, stderr, errcode, err := execute(ctx, []string{"pull", src, dest}) + _, _, errcode, err := execute(ctx, []string{"-s", string(d.SerialNo), "pull", src, dest}) if err != nil { return err } if errcode != 0 { return ErrUnspecified } - _, _ = stdout, stderr - // TODO check the return strings of the output to determine if the file copy succeeded return nil } -// Attempts to reboot the device +// Reboot attempts to reboot the device. // // Once the device reboots, you must manually reconnect. -// Returns an error if the device cannot be contacted +// Returns an error if the device cannot be contacted. func (d Device) Reboot(ctx context.Context) error { - stdout, stderr, errcode, err := execute(ctx, []string{"reboot"}) + _, _, errcode, err := execute(ctx, []string{"-s", string(d.SerialNo), "reboot"}) if err != nil { return err } if errcode != 0 { return ErrUnspecified } - _, _ = stdout, stderr - // TODO check the return strings of the output to determine if the file copy succeeded return nil } -// Attempt to relaunch adb as root on the Device. +// Root attempts to relaunch adb as root on the Device. // // Note, this may not be possible on most devices. // Returns an error if it can't be done. // The device connection will stay established. // Once adb is relaunched as root, it will stay root until rebooted. -// returns true if the device successfully relaunched as root +// Returns true if the device successfully relaunched as root. func (d Device) Root(ctx context.Context) (success bool, err error) { - stdout, stderr, errcode, err := execute(ctx, []string{"root"}) + stdout, _, errcode, err := execute(ctx, []string{"-s", string(d.SerialNo), "root"}) if err != nil { return false, err } if errcode != 0 { return false, ErrUnspecified } - _, _ = stdout, stderr - // TODO check the return strings of the output to determine if the file copy succeeded - return true, nil + if strings.Contains(stdout, "adbd is already running as root") || + strings.Contains(stdout, "restarting adbd as root") { + return true, nil + } + return false, nil } diff --git a/adb_test.go b/adb_test.go index 5779e42..a098239 100644 --- a/adb_test.go +++ b/adb_test.go @@ -1,8 +1,10 @@ package adb import ( + "net" "reflect" "testing" + "time" ) func Test_parseDevices(t *testing.T) { @@ -48,6 +50,33 @@ HT75R0202681 unauthorized`}, {IsAuthorized: false, SerialNo: "HT75R0202681"}, }, }, + { + name: "empty string", + args: args{stdout: ""}, + wantErr: false, + want: []Device{}, + }, + { + name: "offline device", + args: args{stdout: `List of devices attached +ABCD1234 offline`}, + wantErr: false, + want: []Device{ + {IsAuthorized: false, SerialNo: "ABCD1234"}, + }, + }, + { + name: "extra whitespace lines", + args: args{stdout: `List of devices attached + +19291FDEE0023W device + +`}, + wantErr: false, + want: []Device{ + {IsAuthorized: true, SerialNo: "19291FDEE0023W"}, + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -62,3 +91,171 @@ HT75R0202681 unauthorized`}, }) } } + +func TestDevice_ConnString(t *testing.T) { + tests := []struct { + name string + dev Device + want string + }{ + { + name: "default port", + dev: Device{IP: net.IPAddr{IP: net.ParseIP("192.168.1.100")}}, + want: "192.168.1.100:5555", + }, + { + name: "custom port", + dev: Device{IP: net.IPAddr{IP: net.ParseIP("10.0.0.5")}, Port: 5556}, + want: "10.0.0.5:5556", + }, + { + name: "ipv6", + dev: Device{IP: net.IPAddr{IP: net.ParseIP("::1")}, Port: 5555}, + want: "::1:5555", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.dev.ConnString() + if got != tt.want { + t.Errorf("ConnString() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestTapSequence_ShortenSleep(t *testing.T) { + seq := TapSequence{ + Events: []Input{ + SequenceTap{X: 100, Y: 200, Type: SeqTap}, + SequenceSleep{Duration: time.Second * 4, Type: SeqSleep}, + SequenceTap{X: 300, Y: 400, Type: SeqTap}, + }, + } + shortened := seq.ShortenSleep(2) + if len(shortened.Events) != 3 { + t.Fatalf("expected 3 events, got %d", len(shortened.Events)) + } + sleep, ok := shortened.Events[1].(SequenceSleep) + if !ok { + t.Fatal("expected second event to be SequenceSleep") + } + if sleep.Duration != time.Second*2 { + t.Errorf("expected sleep duration 2s, got %v", sleep.Duration) + } +} + +func TestTapSequence_GetLength(t *testing.T) { + now := time.Now() + seq := TapSequence{ + Events: []Input{ + SequenceSleep{Duration: time.Second * 10, Type: SeqSleep}, + SequenceSwipe{ + X1: 0, Y1: 0, X2: 100, Y2: 100, + Start: now, End: now.Add(time.Second * 5), + Type: SeqSwipe, + }, + }, + } + got := seq.GetLength() + // 15s * 110/100 = 16.5s + want := time.Second * 15 * 110 / 100 + if got != want { + t.Errorf("GetLength() = %v, want %v", got, want) + } +} + +func TestTapSequence_JSONRoundTrip(t *testing.T) { + now := time.UnixMilli(1700000000000) + original := TapSequence{ + Resolution: Resolution{Width: 1080, Height: 2340}, + Events: []Input{ + SequenceSwipe{ + X1: 10, Y1: 20, X2: 30, Y2: 40, + Start: now, End: now.Add(time.Millisecond * 500), + Type: SeqSwipe, + }, + }, + } + jsonBytes := original.ToJSON() + roundTripped, err := TapSequenceFromJSON(jsonBytes) + if err != nil { + t.Fatalf("TapSequenceFromJSON() error = %v", err) + } + if roundTripped.Resolution != original.Resolution { + t.Errorf("Resolution mismatch: got %v, want %v", roundTripped.Resolution, original.Resolution) + } + if len(roundTripped.Events) != len(original.Events) { + t.Fatalf("Events length mismatch: got %d, want %d", len(roundTripped.Events), len(original.Events)) + } +} + +func TestSequenceImporter_ToInput(t *testing.T) { + now := time.UnixMilli(1700000000000) + tests := []struct { + name string + importer SequenceImporter + wantType SeqType + }{ + { + name: "sleep", + importer: SequenceImporter{Type: SeqSleep, Duration: time.Second}, + wantType: SeqSleep, + }, + { + name: "tap", + importer: SequenceImporter{Type: SeqTap, X: 10, Y: 20, Start: now, End: now}, + wantType: SeqTap, + }, + { + name: "swipe", + importer: SequenceImporter{Type: SeqSwipe, X1: 10, Y1: 20, X2: 30, Y2: 40, Start: now, End: now.Add(time.Second)}, + wantType: SeqSwipe, + }, + { + name: "unknown defaults to sleep", + importer: SequenceImporter{Type: SeqType(99)}, + wantType: SeqSleep, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + input := tt.importer.ToInput() + switch tt.wantType { + case SeqSleep: + if _, ok := input.(SequenceSleep); !ok { + t.Errorf("expected SequenceSleep, got %T", input) + } + case SeqTap: + if _, ok := input.(SequenceTap); !ok { + t.Errorf("expected SequenceTap, got %T", input) + } + case SeqSwipe: + if _, ok := input.(SequenceSwipe); !ok { + t.Errorf("expected SequenceSwipe, got %T", input) + } + } + }) + } +} + +func TestInsertSleeps(t *testing.T) { + now := time.UnixMilli(1000) + inputs := []Input{ + SequenceTap{X: 1, Y: 2, Start: now, End: now.Add(time.Millisecond * 100), Type: SeqTap}, + SequenceTap{X: 3, Y: 4, Start: now.Add(time.Millisecond * 500), End: now.Add(time.Millisecond * 600), Type: SeqTap}, + } + result := insertSleeps(inputs) + // Should be: tap, sleep, tap + if len(result) != 3 { + t.Fatalf("expected 3 events, got %d", len(result)) + } + sleep, ok := result[1].(SequenceSleep) + if !ok { + t.Fatal("expected second event to be SequenceSleep") + } + // Sleep should be from end of first (100ms) to end of second (600ms) = 500ms + if sleep.Duration != time.Millisecond*500 { + t.Errorf("expected sleep duration 500ms, got %v", sleep.Duration) + } +} diff --git a/errors.go b/errors.go index 896b313..df377d9 100644 --- a/errors.go +++ b/errors.go @@ -5,12 +5,18 @@ import ( ) var ( - // When an execution should have data but has none, but the exact error is - // indeterminite, this error is returned - ErrStdoutEmpty = errors.New("stdout expected to contain data but was empty") - ErrNotInstalled = errors.New("adb is not installed or not in PATH") + // ErrStdoutEmpty is returned when an execution should have data but has none. + ErrStdoutEmpty = errors.New("stdout expected to contain data but was empty") + // ErrNotInstalled is returned when adb cannot be found in PATH. + ErrNotInstalled = errors.New("adb is not installed or not in PATH") + // ErrCoordinatesNotFound is returned when touch event coordinates are missing. ErrCoordinatesNotFound = errors.New("coordinates for an input event are missing") - ErrConnUSB = errors.New("cannot call connect/disconnect to device using USB") + // ErrConnUSB is returned when connect/disconnect is called on a USB device. + ErrConnUSB = errors.New("cannot call connect/disconnect to device using USB") + // ErrResolutionParseFail is returned when screen resolution output cannot be parsed. ErrResolutionParseFail = errors.New("failed to parse screen size from input text") - ErrUnspecified = errors.New("an unknown error has occurred, please open an issue on GitHub") + // ErrDestExists is returned when a pull destination file already exists. + ErrDestExists = errors.New("destination file already exists") + // ErrUnspecified is returned when the exact error cannot be determined. + ErrUnspecified = errors.New("an unknown error has occurred, please open an issue on GitHub") ) diff --git a/go.mod b/go.mod index aaf0bd9..b3a0646 100644 --- a/go.mod +++ b/go.mod @@ -1,5 +1,5 @@ module github.com/taigrr/adb -go 1.26.0 +go 1.26.1 require github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 diff --git a/shell.go b/shell.go index 8fedeec..996c251 100644 --- a/shell.go +++ b/shell.go @@ -15,15 +15,14 @@ import ( // you require functionality not provided by the exposed functions here. // Instead of using Shell, please consider submitting a PR with the functionality // you require. -func (d Device) Shell(ctx context.Context, command string) (stdout string, stderr string, ErrCode int, err error) { +func (d Device) Shell(ctx context.Context, command string) (stdout string, stderr string, errCode int, err error) { cmd, err := shlex.Split(command) if err != nil { return "", "", 1, err } prefix := []string{"-s", string(d.SerialNo), "shell"} cmd = append(prefix, cmd...) - stdout, stderr, errcode, err := execute(ctx, cmd) - return stdout, stderr, errcode, err + return execute(ctx, cmd) } // adb shell wm size diff --git a/shell_test.go b/shell_test.go index 1d8c50f..d4f4f59 100644 --- a/shell_test.go +++ b/shell_test.go @@ -18,7 +18,9 @@ func Test_parseScreenResolution(t *testing.T) { {name: "Pixel 4XL", args: args{in: "Physical size: 1440x3040"}, wantRes: Resolution{Width: 1440, Height: 3040}, wantErr: false}, {name: "Pixel XL", args: args{in: "Physical size: 1440x2560"}, wantRes: Resolution{Width: 1440, Height: 2560}, wantErr: false}, {name: "garbage", args: args{in: "asdfhjkla"}, wantRes: Resolution{Width: -1, Height: -1}, wantErr: true}, - // TODO: Add test cases. + {name: "empty string", args: args{in: ""}, wantRes: Resolution{Width: -1, Height: -1}, wantErr: true}, + {name: "Samsung S21", args: args{in: "Physical size: 1080x2400"}, wantRes: Resolution{Width: 1080, Height: 2400}, wantErr: false}, + {name: "override size", args: args{in: "Physical size: 1440x3120\nOverride size: 1080x2340"}, wantRes: Resolution{Width: 1440, Height: 3120}, wantErr: false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/util.go b/util.go index 8f879e7..2d0da1d 100644 --- a/util.go +++ b/util.go @@ -4,16 +4,18 @@ import ( "bytes" "context" "fmt" - "log" "os/exec" + "sync" ) -var adb string +var ( + adb string + adbOnce sync.Once +) -func init() { +func findADB() { path, err := exec.LookPath("adb") if err != nil { - log.Printf("%v", ErrNotInstalled) adb = "" return } @@ -21,25 +23,24 @@ func init() { } func execute(ctx context.Context, args []string) (string, string, int, error) { - var ( - err error - stderr bytes.Buffer - stdout bytes.Buffer - code int - output string - warnings string - ) + adbOnce.Do(findADB) if adb == "" { - panic(ErrNotInstalled) + return "", "", -1, ErrNotInstalled } + + var ( + stderr bytes.Buffer + stdout bytes.Buffer + ) + cmd := exec.CommandContext(ctx, adb, args...) cmd.Stdout = &stdout cmd.Stderr = &stderr - err = cmd.Run() - output = stdout.String() - warnings = stderr.String() - code = cmd.ProcessState.ExitCode() + err := cmd.Run() + output := stdout.String() + warnings := stderr.String() + code := cmd.ProcessState.ExitCode() customErr := filterErr(warnings) if customErr != nil { diff --git a/util_test.go b/util_test.go new file mode 100644 index 0000000..43fb30c --- /dev/null +++ b/util_test.go @@ -0,0 +1,24 @@ +package adb + +import ( + "testing" +) + +func Test_filterErr(t *testing.T) { + tests := []struct { + name string + stderr string + wantErr bool + }{ + {name: "empty stderr", stderr: "", wantErr: false}, + {name: "random output", stderr: "some warning text", wantErr: false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := filterErr(tt.stderr) + if (err != nil) != tt.wantErr { + t.Errorf("filterErr() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +}