diff --git a/kernel/mem/vmm/walk.go b/kernel/mem/vmm/walk.go
new file mode 100644
index 0000000..8c38469
--- /dev/null
+++ b/kernel/mem/vmm/walk.go
@@ -0,0 +1,56 @@
+package vmm
+
+import (
+	"unsafe"
+
+	"github.com/achilleasa/gopher-os/kernel/mem"
+)
+
+var (
+	// ptePtrFn returns a pointer to the supplied entry address. It is
+	// declared as a variable so tests can override it and intercept the
+	// generated page table entry pointers so walk() can be properly
+	// tested. When compiling the kernel it is automatically inlined.
+	ptePtrFn = func(entryAddr uintptr) unsafe.Pointer {
+		return unsafe.Pointer(entryAddr)
+	}
+)
+
+// pageTableWalker is a function that can be passed to the walk method. The
+// function receives the current page level and page table entry as its
+// arguments. If the function returns false, then the page walk is aborted.
+type pageTableWalker func(pteLevel uint8, pte *pageTableEntry) bool
+
+// walk performs a page table walk for the given virtual address. It calls the
+// supplied walkFn with the page table entry that corresponds to each page
+// table level. If walkFn returns false then the walk is aborted and any
+// remaining page table levels are not visited.
+func walk(virtAddr uintptr, walkFn pageTableWalker) {
+	var (
+		level                            uint8
+		tableAddr, entryAddr, entryIndex uintptr
+		ok                               bool
+	)
+
+	// tableAddr is initially set to the recursively mapped virtual address for the
+	// last entry in the top-most page table. Dereferencing a pointer to this address
+	// will allow us to access the entries of the top-most table itself.
+	for level, tableAddr = uint8(0), pdtVirtualAddr; level < pageLevels; level, tableAddr = level+1, entryAddr {
+		// Extract the bits from virtual address that correspond to the
+		// index in this level's page table
+		entryIndex = (virtAddr >> pageLevelShifts[level]) & ((1 << pageLevelBits[level]) - 1)
+
+		// By shifting the table virtual address left by pageLevelShifts[level] we add
+		// a new level of indirection to our recursive mapping allowing us to access
+		// the table pointed to by the page entry
+		entryAddr = tableAddr + (entryIndex << mem.PointerShift)
+
+		if ok = walkFn(level, (*pageTableEntry)(ptePtrFn(entryAddr))); !ok {
+			return
+		}
+
+		// Shift left by the number of bits for this paging level to get
+		// the virtual address of the table pointed to by entryAddr
+		entryAddr <<= pageLevelBits[level]
+	}
+}
diff --git a/kernel/mem/vmm/walk_test.go b/kernel/mem/vmm/walk_test.go
new file mode 100644
index 0000000..ee235eb
--- /dev/null
+++ b/kernel/mem/vmm/walk_test.go
@@ -0,0 +1,76 @@
+package vmm
+
+import (
+	"runtime"
+	"testing"
+	"unsafe"
+
+	"github.com/achilleasa/gopher-os/kernel/mem"
+)
+
+func TestPtePtrFn(t *testing.T) {
+	// Dummy test to keep coverage happy
+	if exp, got := unsafe.Pointer(uintptr(123)), ptePtrFn(uintptr(123)); exp != got {
+		t.Fatalf("expected ptePtrFn to return %v; got %v", exp, got)
+	}
+}
+
+func TestWalkAmd64(t *testing.T) {
+	if runtime.GOARCH != "amd64" {
+		t.Skip("test requires amd64 runtime; skipping")
+	}
+
+	defer func(origPtePtr func(uintptr) unsafe.Pointer) {
+		ptePtrFn = origPtePtr
+	}(ptePtrFn)
+
+	// This address breaks down to:
+	// p4 index: 1
+	// p3 index: 2
+	// p2 index: 3
+	// p1 index: 4
+	// offset : 1024
+	targetAddr := uintptr(0x8080604400)
+
+	sizeofPteEntry := uintptr(unsafe.Sizeof(pageTableEntry(0)))
+	expEntryAddrBits := [pageLevels][pageLevels + 1]uintptr{
+		{511, 511, 511, 511, 1 * sizeofPteEntry},
+		{511, 511, 511, 1, 2 * sizeofPteEntry},
+		{511, 511, 1, 2, 3 * sizeofPteEntry},
+		{511, 1, 2, 3, 4 * sizeofPteEntry},
+	}
+
+	pteCallCount := 0
+	ptePtrFn = func(entry uintptr) unsafe.Pointer {
+		if pteCallCount >= pageLevels {
+			t.Fatalf("unexpected call to ptePtrFn; already called %d times", pageLevels)
+		}
+
+		for i := 0; i < pageLevels; i++ {
+			pteIndex := (entry >> pageLevelShifts[i]) & ((1 << pageLevelBits[i]) - 1)
+			if pteIndex != expEntryAddrBits[pteCallCount][i] {
+				t.Errorf("[ptePtrFn call %d] expected pte entry for level %d to use offset %d; got %d", pteCallCount, i, expEntryAddrBits[pteCallCount][i], pteIndex)
+			}
+		}
+
+		// Check the page offset
+		pteIndex := entry & ((1 << mem.PageShift) - 1)
+		if pteIndex != expEntryAddrBits[pteCallCount][pageLevels] {
+			t.Errorf("[ptePtrFn call %d] expected pte offset to be %d; got %d", pteCallCount, expEntryAddrBits[pteCallCount][pageLevels], pteIndex)
+		}
+
+		pteCallCount++
+
+		return unsafe.Pointer(uintptr(0xf00))
+	}
+
+	walkFnCallCount := 0
+	walk(targetAddr, func(level uint8, entry *pageTableEntry) bool {
+		walkFnCallCount++
+		return walkFnCallCount != pageLevels
+	})
+
+	if pteCallCount != pageLevels {
+		t.Errorf("expected ptePtrFn to be called %d times; got %d", pageLevels, pteCallCount)
+	}
+}