diff --git a/arch/x86_64/asm/rt0_32.s b/arch/x86_64/asm/rt0_32.s
index b07a0e1..4e6fc3f 100644
--- a/arch/x86_64/asm/rt0_32.s
+++ b/arch/x86_64/asm/rt0_32.s
@@ -247,6 +247,12 @@ _rt0_populate_initial_page_tables:
 	or eax, PAGE_PRESENT | PAGE_WRITABLE
 	mov ebx, page_table_l4 - PAGE_OFFSET
 	mov [ebx], eax
+
+	; Recursively map the last P4 entry to itself. This allows us to use
+	; specially crafted memory addresses to access the page tables themselves.
+	mov ecx, ebx
+	or ecx, PAGE_PRESENT | PAGE_WRITABLE
+	mov [ebx + 511*8], ecx
 
 	; Also map the addresses starting at PAGE_OFFSET to the same P3 table.
 	; To find the P4 index for PAGE_OFFSET we need to extract bits 39-47
diff --git a/kernel/mem/vmm/constants_amd64.go b/kernel/mem/vmm/constants_amd64.go
new file mode 100644
index 0000000..5e3b95d
--- /dev/null
+++ b/kernel/mem/vmm/constants_amd64.go
@@ -0,0 +1,85 @@
+// +build amd64
+
+package vmm
+
+import "math"
+
+const (
+	// pageLevels indicates the number of page levels supported by the amd64 architecture.
+	pageLevels = 4
+
+	// ptePhysPageMask is a mask that allows us to extract the physical memory
+	// address pointed to by a page table entry. For this particular architecture,
+	// bits 12-51 contain the physical memory address.
+	ptePhysPageMask = uintptr(0x000ffffffffff000)
+
+	// tempMappingAddr is a reserved virtual page address used for
+	// temporary physical page mappings (e.g. when mapping inactive PDT
+	// pages). For amd64 this address uses the following table indices:
+	// 510, 511, 511, 511.
+	tempMappingAddr = uintptr(0xffffff7ffffff000)
+)
+
+var (
+	// pdtVirtualAddr is a special virtual address that exploits the
+	// recursive mapping used in the last PDT entry for each page directory
+	// to allow accessing the PDT (P4) table using the system's MMU address
+	// translation mechanism. By setting all page-level index bits to 1, the
+	// MMU keeps following the last P4 entry for every page level, finally
+	// landing on the P4 table itself.
+	pdtVirtualAddr = uintptr(math.MaxUint64 &^ ((1 << 12) - 1))
+
+	// pageLevelBits defines the number of virtual address bits that correspond to each
+	// page level. For the amd64 architecture each page level uses 9 bits which amounts to
+	// 512 entries for each page level.
+	pageLevelBits = [pageLevels]uint8{
+		9,
+		9,
+		9,
+		9,
+	}
+
+	// pageLevelShifts defines the shift required to access each page table component
+	// of a virtual address.
+	pageLevelShifts = [pageLevels]uint8{
+		39,
+		30,
+		21,
+		12,
+	}
+)
+
+const (
+	// FlagPresent is set when the page is available in memory and not swapped out.
+	FlagPresent PageTableEntryFlag = 1 << iota
+
+	// FlagRW is set if the page can be written to.
+	FlagRW
+
+	// FlagUserAccessible is set if user-mode processes can access this page. If
+	// not set only kernel code can access this page.
+	FlagUserAccessible
+
+	// FlagWriteThroughCaching implies write-through caching when set and write-back
+	// caching if cleared.
+	FlagWriteThroughCaching
+
+	// FlagDoNotCache prevents this page from being cached if set.
+	FlagDoNotCache
+
+	// FlagAccessed is set by the CPU when this page is accessed.
+	FlagAccessed
+
+	// FlagDirty is set by the CPU when this page is modified.
+	FlagDirty
+
+	// FlagHugePage is set when using 2MB pages instead of 4KB pages.
+	FlagHugePage
+
+	// FlagGlobal, if set, prevents the TLB from flushing the cached address of
+	// this page when the page tables are swapped by updating the CR3 register.
+	FlagGlobal
+
+	// FlagNoExecute, if set, marks the contents of this page as non-executable.
+	FlagNoExecute = 1 << 63
+)
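The index arithmetic behind pdtVirtualAddr and tempMappingAddr is easy to verify outside the kernel. The sketch below is not part of the patch; vaddrForIndices is a made-up helper that assembles a canonical amd64 virtual address from its four 9-bit table indices and a 12-bit page offset:

```go
package main

import "fmt"

// vaddrForIndices builds an amd64 virtual address from four 9-bit page-table
// indices and a 12-bit page offset, sign-extending bit 47 into bits 48-63 as
// the architecture requires for canonical addresses.
func vaddrForIndices(p4, p3, p2, p1, offset uint64) uint64 {
	addr := p4<<39 | p3<<30 | p2<<21 | p1<<12 | offset
	if addr&(1<<47) != 0 {
		addr |= 0xffff << 48 // canonical (sign-extended) high bits
	}
	return addr
}

func main() {
	// Indices 511,511,511,511 keep following the recursive entry and land
	// on the P4 table itself (pdtVirtualAddr).
	fmt.Printf("0x%x\n", vaddrForIndices(511, 511, 511, 511, 0)) // 0xfffffffffffff000
	// Indices 510,511,511,511 yield the temporary mapping page (tempMappingAddr).
	fmt.Printf("0x%x\n", vaddrForIndices(510, 511, 511, 511, 0)) // 0xffffff7ffffff000
}
```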
diff --git a/kernel/mem/vmm/map.go b/kernel/mem/vmm/map.go
new file mode 100644
index 0000000..bd8781c
--- /dev/null
+++ b/kernel/mem/vmm/map.go
@@ -0,0 +1,117 @@
+package vmm
+
+import (
+	"unsafe"
+
+	"github.com/achilleasa/gopher-os/kernel"
+	"github.com/achilleasa/gopher-os/kernel/mem"
+	"github.com/achilleasa/gopher-os/kernel/mem/pmm"
+)
+
+var (
+	// nextAddrFn is used by tests to override the nextTableAddr
+	// calculations used by Map. When compiling the kernel this function
+	// will be automatically inlined.
+	nextAddrFn = func(entryAddr uintptr) uintptr {
+		return entryAddr
+	}
+
+	// flushTLBEntryFn is used by tests to override calls to flushTLBEntry
+	// which will cause a fault if called in user-mode.
+	flushTLBEntryFn = flushTLBEntry
+
+	errNoHugePageSupport = &kernel.Error{Module: "vmm", Message: "huge pages are not supported"}
+)
+
+// FrameAllocator is a function that can allocate physical frames of the specified order.
+type FrameAllocator func(mem.PageOrder) (pmm.Frame, *kernel.Error)
+
+// Map establishes a mapping between a virtual page and a physical memory frame
+// using the currently active page directory table. Calls to Map will use the
+// supplied physical frame allocator to initialize missing page tables at each
+// paging level supported by the MMU.
+func Map(page Page, frame pmm.Frame, flags PageTableEntryFlag, allocFn FrameAllocator) *kernel.Error {
+	var err *kernel.Error
+
+	walk(page.Address(), func(pteLevel uint8, pte *pageTableEntry) bool {
+		// If we reached the last level all we need to do is to map the
+		// frame in place, flag it as present and flush its TLB entry
+		if pteLevel == pageLevels-1 {
+			*pte = 0
+			pte.SetFrame(frame)
+			pte.SetFlags(FlagPresent | flags)
+			flushTLBEntryFn(page.Address())
+			return true
+		}
+
+		if pte.HasFlags(FlagHugePage) {
+			err = errNoHugePageSupport
+			return false
+		}
+
+		// Next table does not yet exist; we need to allocate a
+		// physical frame for it, map it and clear its contents.
+		if !pte.HasFlags(FlagPresent) {
+			var newTableFrame pmm.Frame
+			newTableFrame, err = allocFn(mem.PageOrder(0))
+			if err != nil {
+				return false
+			}
+
+			*pte = 0
+			pte.SetFrame(newTableFrame)
+			pte.SetFlags(FlagPresent | FlagRW)
+
+			// The entry for the next table is now in place but we need
+			// to make sure that the newly allocated page is properly cleared
+			nextTableAddr := (uintptr(unsafe.Pointer(pte)) << pageLevelBits[pteLevel+1])
+			mem.Memset(nextAddrFn(nextTableAddr), 0, mem.PageSize)
+		}
+
+		return true
+	})
+
+	return err
+}
+
+// MapTemporary establishes a temporary RW mapping of a physical memory frame
+// to a fixed virtual address overwriting any previous mapping. The temporary
+// mapping mechanism is primarily used by the kernel to access and initialize
+// inactive page tables.
+func MapTemporary(frame pmm.Frame, allocFn FrameAllocator) (Page, *kernel.Error) {
+	if err := Map(PageFromAddress(tempMappingAddr), frame, FlagRW, allocFn); err != nil {
+		return 0, err
+	}
+
+	return PageFromAddress(tempMappingAddr), nil
+}
+
+// Unmap removes a mapping previously installed via a call to Map or MapTemporary.
+func Unmap(page Page) *kernel.Error {
+	var err *kernel.Error
+
+	walk(page.Address(), func(pteLevel uint8, pte *pageTableEntry) bool {
+		// If we reached the last level all we need to do is to set the
+		// page as non-present and flush its TLB entry
+		if pteLevel == pageLevels-1 {
+			pte.ClearFlags(FlagPresent)
+			flushTLBEntryFn(page.Address())
+			return true
+		}
+
+		// Next table is not present; this is an invalid mapping
+		if !pte.HasFlags(FlagPresent) {
+			err = ErrInvalidMapping
+			return false
+		}
+
+		if pte.HasFlags(FlagHugePage) {
+			err = errNoHugePageSupport
+			return false
+		}
+
+		return true
+	})
+
+	return err
+}
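For context, here is a rough sketch of how a caller might wire Map to a frame allocator. mapLAPICRegisters and its addresses are hypothetical placeholders, not part of this patch; only the vmm/pmm/mem identifiers come from the code above:

```go
package kmain

import (
	"github.com/achilleasa/gopher-os/kernel"
	"github.com/achilleasa/gopher-os/kernel/mem"
	"github.com/achilleasa/gopher-os/kernel/mem/pmm"
	"github.com/achilleasa/gopher-os/kernel/mem/vmm"
)

// mapLAPICRegisters is a hypothetical example: it maps the physical frame that
// contains the local APIC registers (0xfee00000) to a fixed virtual page. Any
// page tables missing along the walk are allocated on demand via allocFrame.
func mapLAPICRegisters(allocFrame vmm.FrameAllocator) *kernel.Error {
	page := vmm.PageFromAddress(uintptr(0xffff800000a00000))
	frame := pmm.Frame(uintptr(0xfee00000) >> mem.PageShift)

	// Device registers should bypass the cache, hence FlagDoNotCache.
	return vmm.Map(page, frame, vmm.FlagRW|vmm.FlagDoNotCache, allocFrame)
}
```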
diff --git a/kernel/mem/vmm/map_test.go b/kernel/mem/vmm/map_test.go
new file mode 100644
index 0000000..ee03a3e
--- /dev/null
+++ b/kernel/mem/vmm/map_test.go
@@ -0,0 +1,251 @@
+package vmm
+
+import (
+	"runtime"
+	"testing"
+	"unsafe"
+
+	"github.com/achilleasa/gopher-os/kernel"
+	"github.com/achilleasa/gopher-os/kernel/mem"
+	"github.com/achilleasa/gopher-os/kernel/mem/pmm"
+)
+
+func TestNextAddrFn(t *testing.T) {
+	// Dummy test to keep coverage happy
+	if exp, got := uintptr(123), nextAddrFn(uintptr(123)); exp != got {
+		t.Fatalf("expected nextAddrFn to return %v; got %v", exp, got)
+	}
+}
+
+func TestMapTemporaryAmd64(t *testing.T) {
+	if runtime.GOARCH != "amd64" {
+		t.Skip("test requires amd64 runtime; skipping")
+	}
+
+	defer func(origPtePtr func(uintptr) unsafe.Pointer, origNextAddrFn func(uintptr) uintptr, origFlushTLBEntryFn func(uintptr)) {
+		ptePtrFn = origPtePtr
+		nextAddrFn = origNextAddrFn
+		flushTLBEntryFn = origFlushTLBEntryFn
+	}(ptePtrFn, nextAddrFn, flushTLBEntryFn)
+
+	var physPages [pageLevels][mem.PageSize >> mem.PointerShift]pageTableEntry
+	nextPhysPage := 0
+
+	// allocFn returns pages from index 1; we keep index 0 for the P4 entry
+	allocFn := func(_ mem.PageOrder) (pmm.Frame, *kernel.Error) {
+		nextPhysPage++
+		pageAddr := unsafe.Pointer(&physPages[nextPhysPage][0])
+		return pmm.Frame(uintptr(pageAddr) >> mem.PageShift), nil
+	}
+
+	pteCallCount := 0
+	ptePtrFn = func(entry uintptr) unsafe.Pointer {
+		pteCallCount++
+		// The last 12 bits encode the page table offset in bytes
+		// which we need to convert to a uint64 entry
+		pteIndex := (entry & uintptr(mem.PageSize-1)) >> mem.PointerShift
+		return unsafe.Pointer(&physPages[pteCallCount-1][pteIndex])
+	}
+
+	nextAddrFn = func(entry uintptr) uintptr {
+		return uintptr(unsafe.Pointer(&physPages[nextPhysPage][0]))
+	}
+
+	flushTLBEntryCallCount := 0
+	flushTLBEntryFn = func(uintptr) {
+		flushTLBEntryCallCount++
+	}
+
+	// The temporary mapping address breaks down to:
+	// p4 index: 510
+	// p3 index: 511
+	// p2 index: 511
+	// p1 index: 511
+	frame := pmm.Frame(123)
+	levelIndices := []uint{510, 511, 511, 511}
+
+	page, err := MapTemporary(frame, allocFn)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if got := page.Address(); got != tempMappingAddr {
+		t.Fatalf("expected temp mapping virtual address to be %x; got %x", tempMappingAddr, got)
+	}
+
+	for level, physPage := range physPages {
+		pte := physPage[levelIndices[level]]
+		if !pte.HasFlags(FlagPresent | FlagRW) {
+			t.Errorf("[pte at level %d] expected entry to have FlagPresent and FlagRW set", level)
+		}
+
+		switch {
+		case level < pageLevels-1:
+			if exp, got := pmm.Frame(uintptr(unsafe.Pointer(&physPages[level+1][0]))>>mem.PageShift), pte.Frame(); got != exp {
+				t.Errorf("[pte at level %d] expected entry frame to be %d; got %d", level, exp, got)
+			}
+		default:
+			// The last pte entry should point to frame
+			if got := pte.Frame(); got != frame {
+				t.Errorf("[pte at level %d] expected entry frame to be %d; got %d", level, frame, got)
+			}
+		}
+	}
+
+	if exp := 1; flushTLBEntryCallCount != exp {
+		t.Errorf("expected flushTLBEntry to be called %d times; got %d", exp, flushTLBEntryCallCount)
+	}
+}
+
+func TestMapTemporaryErrorsAmd64(t *testing.T) {
+	if runtime.GOARCH != "amd64" {
+		t.Skip("test requires amd64 runtime; skipping")
+	}
+
+	defer func(origPtePtr func(uintptr) unsafe.Pointer, origNextAddrFn func(uintptr) uintptr, origFlushTLBEntryFn func(uintptr)) {
+		ptePtrFn = origPtePtr
+		nextAddrFn = origNextAddrFn
+		flushTLBEntryFn = origFlushTLBEntryFn
+	}(ptePtrFn, nextAddrFn, flushTLBEntryFn)
+
+	var physPages [pageLevels][mem.PageSize >> mem.PointerShift]pageTableEntry
+
+	// The reserved virt address uses the following page level indices: 510, 511, 511, 511
+	p4Index := 510
+	frame := pmm.Frame(123)
+
+	t.Run("encounter huge page", func(t *testing.T) {
+		physPages[0][p4Index].SetFlags(FlagPresent | FlagHugePage)
+
+		ptePtrFn = func(entry uintptr) unsafe.Pointer {
+			// The last 12 bits encode the page table offset in bytes
+			// which we need to convert to a uint64 entry
+			pteIndex := (entry & uintptr(mem.PageSize-1)) >> mem.PointerShift
+			return unsafe.Pointer(&physPages[0][pteIndex])
+		}
+
+		if _, err := MapTemporary(frame, nil); err != errNoHugePageSupport {
+			t.Fatalf("expected to get errNoHugePageSupport; got %v", err)
+		}
+	})
+
+	t.Run("allocFn returns an error", func(t *testing.T) {
+		physPages[0][p4Index] = 0
+
+		expErr := &kernel.Error{Module: "test", Message: "out of memory"}
+
+		allocFn := func(_ mem.PageOrder) (pmm.Frame, *kernel.Error) {
+			return 0, expErr
+		}
+
+		if _, err := MapTemporary(frame, allocFn); err != expErr {
+			t.Fatalf("got unexpected error %v", err)
+		}
+	})
+}
+
+func TestUnmapAmd64(t *testing.T) {
+	if runtime.GOARCH != "amd64" {
+		t.Skip("test requires amd64 runtime; skipping")
+	}
+
+	defer func(origPtePtr func(uintptr) unsafe.Pointer, origFlushTLBEntryFn func(uintptr)) {
+		ptePtrFn = origPtePtr
+		flushTLBEntryFn = origFlushTLBEntryFn
+	}(ptePtrFn, flushTLBEntryFn)
+
+	var (
+		physPages [pageLevels][mem.PageSize >> mem.PointerShift]pageTableEntry
+		frame     = pmm.Frame(123)
+	)
+
+	// Emulate a page mapped to virtAddr 0 across all page levels
+	for level := 0; level < pageLevels; level++ {
+		physPages[level][0].SetFlags(FlagPresent | FlagRW)
+		if level < pageLevels-1 {
+			physPages[level][0].SetFrame(pmm.Frame(uintptr(unsafe.Pointer(&physPages[level+1][0])) >> mem.PageShift))
+		} else {
+			physPages[level][0].SetFrame(frame)
+		}
+	}
+
+	pteCallCount := 0
+	ptePtrFn = func(entry uintptr) unsafe.Pointer {
+		pteCallCount++
+		return unsafe.Pointer(&physPages[pteCallCount-1][0])
+	}
+
+	flushTLBEntryCallCount := 0
+	flushTLBEntryFn = func(uintptr) {
+		flushTLBEntryCallCount++
+	}
+
+	if err := Unmap(PageFromAddress(0)); err != nil {
+		t.Fatal(err)
+	}
+
+	for level, physPage := range physPages {
+		pte := physPage[0]
+
+		switch {
+		case level < pageLevels-1:
+			if !pte.HasFlags(FlagPresent) {
+				t.Errorf("[pte at level %d] expected entry to still have FlagPresent set", level)
+			}
+			if exp, got := pmm.Frame(uintptr(unsafe.Pointer(&physPages[level+1][0]))>>mem.PageShift), pte.Frame(); got != exp {
+				t.Errorf("[pte at level %d] expected entry frame to still be %d; got %d", level, exp, got)
+			}
+		default:
+			if pte.HasFlags(FlagPresent) {
+				t.Errorf("[pte at level %d] expected entry not to have FlagPresent set", level)
+			}
+
+			// The last pte entry should still point to frame
+			if got := pte.Frame(); got != frame {
+				t.Errorf("[pte at level %d] expected entry frame to be %d; got %d", level, frame, got)
+			}
+		}
+	}
+
+	if exp := 1; flushTLBEntryCallCount != exp {
+		t.Errorf("expected flushTLBEntry to be called %d times; got %d", exp, flushTLBEntryCallCount)
+	}
+}
+
+func TestUnmapErrorsAmd64(t *testing.T) {
+	if runtime.GOARCH != "amd64" {
+		t.Skip("test requires amd64 runtime; skipping")
+	}
+
+	defer func(origPtePtr func(uintptr) unsafe.Pointer, origNextAddrFn func(uintptr) uintptr, origFlushTLBEntryFn func(uintptr)) {
+		ptePtrFn = origPtePtr
+		nextAddrFn = origNextAddrFn
+		flushTLBEntryFn = origFlushTLBEntryFn
+	}(ptePtrFn, nextAddrFn, flushTLBEntryFn)
+
+	var physPages [pageLevels][mem.PageSize >> mem.PointerShift]pageTableEntry
+
+	t.Run("encounter huge page", func(t *testing.T) {
+		physPages[0][0].SetFlags(FlagPresent | FlagHugePage)
+
+		ptePtrFn = func(entry uintptr) unsafe.Pointer {
+			// The last 12 bits encode the page table offset in bytes
+			// which we need to convert to a uint64 entry
+			pteIndex := (entry & uintptr(mem.PageSize-1)) >> mem.PointerShift
+			return unsafe.Pointer(&physPages[0][pteIndex])
+		}
+
+		if err := Unmap(PageFromAddress(0)); err != errNoHugePageSupport {
+			t.Fatalf("expected to get errNoHugePageSupport; got %v", err)
+		}
+	})
+
+	t.Run("virtual address not mapped", func(t *testing.T) {
+		physPages[0][0].ClearFlags(FlagPresent)
+
+		if err := Unmap(PageFromAddress(0)); err != ErrInvalidMapping {
+			t.Fatalf("expected to get ErrInvalidMapping; got %v", err)
+		}
+	})
+}
diff --git a/kernel/mem/vmm/page.go b/kernel/mem/vmm/page.go
new file mode 100644
index 0000000..f1e9cb4
--- /dev/null
+++ b/kernel/mem/vmm/page.go
@@ -0,0 +1,19 @@
+package vmm
+
+import "github.com/achilleasa/gopher-os/kernel/mem"
+
+// Page describes a virtual memory page index.
+type Page uintptr
+
+// Address returns the virtual memory address pointed to by this Page.
+func (f Page) Address() uintptr {
+	return uintptr(f << mem.PageShift)
+}
+
+// PageFromAddress returns a Page that corresponds to the given virtual
+// address. This function can handle both page-aligned and unaligned virtual
+// addresses. In the latter case, the input address will be rounded down to the
+// page that contains it.
+func PageFromAddress(virtAddr uintptr) Page {
+	return Page((virtAddr & ^(uintptr(mem.PageSize - 1))) >> mem.PageShift)
+}
diff --git a/kernel/mem/vmm/page_test.go b/kernel/mem/vmm/page_test.go
new file mode 100644
index 0000000..8786e97
--- /dev/null
+++ b/kernel/mem/vmm/page_test.go
@@ -0,0 +1,35 @@
+package vmm
+
+import (
+	"testing"
+
+	"github.com/achilleasa/gopher-os/kernel/mem"
+)
+
+func TestPageMethods(t *testing.T) {
+	for pageIndex := uint64(0); pageIndex < 128; pageIndex++ {
+		page := Page(pageIndex)
+
+		if exp, got := uintptr(pageIndex<<mem.PageShift), page.Address(); got != exp {
+			t.Errorf("expected page %d call to Address() to return %x; got %x", pageIndex, exp, got)
+		}
+	}
+}
diff --git a/kernel/mem/vmm/pdt.go b/kernel/mem/vmm/pdt.go
new file mode 100644
--- /dev/null
+++ b/kernel/mem/vmm/pdt.go
+package vmm
+
+import (
+	"unsafe"
+
+	"github.com/achilleasa/gopher-os/kernel"
+	"github.com/achilleasa/gopher-os/kernel/mem"
+	"github.com/achilleasa/gopher-os/kernel/mem/pmm"
+)
+
+var (
+	// activePDTFn is used by tests to override calls to activePDT which
+	// will cause a fault if called in user-mode.
+	activePDTFn = activePDT
+
+	// switchPDTFn is used by tests to override calls to switchPDT which
+	// will cause a fault if called in user-mode.
+	switchPDTFn = switchPDT
+
+	// mapFn is used by tests to override calls to Map.
+	mapFn = Map
+
+	// mapTemporaryFn is used by tests to override calls to MapTemporary.
+	mapTemporaryFn = MapTemporary
+
+	// unmapFn is used by tests to override calls to Unmap.
+	unmapFn = Unmap
+)
+
+// PageDirectoryTable describes the top-most (P4) table in a multi-level paging scheme.
+type PageDirectoryTable struct {
+	pdtFrame pmm.Frame
+}
+
+// Init sets up the page directory table using the supplied physical frame. If
+// the supplied frame does not match the currently active PDT, Init
+// establishes a temporary mapping for the frame, clears its contents and
+// recursively maps its last entry to the frame itself before removing the
+// temporary mapping.
+func (pdt *PageDirectoryTable) Init(pdtFrame pmm.Frame, allocFn FrameAllocator) *kernel.Error {
+	pdt.pdtFrame = pdtFrame
+
+	// If this frame is the currently active PDT the recursive mapping is
+	// already in place and no bootstrapping is required.
+	if pdtFrame.Address() == activePDTFn() {
+		return nil
+	}
+
+	pdtPage, err := mapTemporaryFn(pdtFrame, allocFn)
+	if err != nil {
+		return err
+	}
+
+	// Clear the page contents and recursively map the last PDT entry to
+	// the PDT frame itself
+	mem.Memset(pdtPage.Address(), 0, mem.PageSize)
+
+	lastPdtEntry := (*pageTableEntry)(unsafe.Pointer(pdtPage.Address() + (((1 << pageLevelBits[0]) - 1) << mem.PointerShift)))
+	lastPdtEntry.SetFlags(FlagPresent | FlagRW)
+	lastPdtEntry.SetFrame(pdtFrame)
+
+	unmapFn(pdtPage)
+
+	return nil
+}
+
+// Map establishes a mapping between a virtual page and a physical memory frame
+// using this PDT. This method behaves in a similar fashion to the global Map()
+// function with the difference that it also supports inactive PDTs by
+// establishing a temporary mapping so that Map() can access the inactive PDT
+// entries.
+func (pdt PageDirectoryTable) Map(page Page, frame pmm.Frame, flags PageTableEntryFlag, allocFn FrameAllocator) *kernel.Error {
+	var (
+		activePdtFrame   = pmm.Frame(activePDTFn() >> mem.PageShift)
+		lastPdtEntryAddr uintptr
+		lastPdtEntry     *pageTableEntry
+	)
+	// If this table is not active we need to temporarily map it to the
+	// last entry in the active PDT so we can access it using the recursive
+	// virtual address scheme.
+	if activePdtFrame != pdt.pdtFrame {
+		lastPdtEntryAddr = activePdtFrame.Address() + (((1 << pageLevelBits[0]) - 1) << mem.PointerShift)
+		lastPdtEntry = (*pageTableEntry)(unsafe.Pointer(lastPdtEntryAddr))
+		lastPdtEntry.SetFrame(pdt.pdtFrame)
+		flushTLBEntryFn(lastPdtEntryAddr)
+	}
+
+	err := mapFn(page, frame, flags, allocFn)
+
+	if activePdtFrame != pdt.pdtFrame {
+		lastPdtEntry.SetFrame(activePdtFrame)
+		flushTLBEntryFn(lastPdtEntryAddr)
+	}
+
+	return err
+}
+
+// Unmap removes a mapping previously installed by a call to Map() on this PDT.
+// This method behaves in a similar fashion to the global Unmap() function with
+// the difference that it also supports inactive PDTs by establishing a
+// temporary mapping so that Unmap() can access the inactive PDT entries.
+func (pdt PageDirectoryTable) Unmap(page Page) *kernel.Error {
+	var (
+		activePdtFrame   = pmm.Frame(activePDTFn() >> mem.PageShift)
+		lastPdtEntryAddr uintptr
+		lastPdtEntry     *pageTableEntry
+	)
+	// If this table is not active we need to temporarily map it to the
+	// last entry in the active PDT so we can access it using the recursive
+	// virtual address scheme.
+	if activePdtFrame != pdt.pdtFrame {
+		lastPdtEntryAddr = activePdtFrame.Address() + (((1 << pageLevelBits[0]) - 1) << mem.PointerShift)
+		lastPdtEntry = (*pageTableEntry)(unsafe.Pointer(lastPdtEntryAddr))
+		lastPdtEntry.SetFrame(pdt.pdtFrame)
+		flushTLBEntryFn(lastPdtEntryAddr)
+	}
+
+	err := unmapFn(page)
+
+	if activePdtFrame != pdt.pdtFrame {
+		lastPdtEntry.SetFrame(activePdtFrame)
+		flushTLBEntryFn(lastPdtEntryAddr)
+	}
+
+	return err
+}
+
+// Activate enables this page directory table and flushes the TLB.
+func (pdt PageDirectoryTable) Activate() {
+	switchPDTFn(pdt.pdtFrame.Address())
+}
diff --git a/kernel/mem/vmm/pdt_test.go b/kernel/mem/vmm/pdt_test.go
new file mode 100644
index 0000000..64172b7
--- /dev/null
+++ b/kernel/mem/vmm/pdt_test.go
@@ -0,0 +1,331 @@
+package vmm
+
+import (
+	"runtime"
+	"testing"
+	"unsafe"
+
+	"github.com/achilleasa/gopher-os/kernel"
+	"github.com/achilleasa/gopher-os/kernel/mem"
+	"github.com/achilleasa/gopher-os/kernel/mem/pmm"
+)
+
+func TestPageDirectoryTableInitAmd64(t *testing.T) {
+	if runtime.GOARCH != "amd64" {
+		t.Skip("test requires amd64 runtime; skipping")
+	}
+
+	defer func(origFlushTLBEntry func(uintptr), origActivePDT func() uintptr, origMapTemporary func(pmm.Frame, FrameAllocator) (Page, *kernel.Error), origUnmap func(Page) *kernel.Error) {
+		flushTLBEntryFn = origFlushTLBEntry
+		activePDTFn = origActivePDT
+		mapTemporaryFn = origMapTemporary
+		unmapFn = origUnmap
+	}(flushTLBEntryFn, activePDTFn, mapTemporaryFn, unmapFn)
+
+	t.Run("already mapped PDT", func(t *testing.T) {
+		var (
+			pdt      PageDirectoryTable
+			pdtFrame = pmm.Frame(123)
+		)
+
+		activePDTFn = func() uintptr {
+			return pdtFrame.Address()
+		}
+
+		mapTemporaryFn = func(_ pmm.Frame, _ FrameAllocator) (Page, *kernel.Error) {
+			t.Fatal("unexpected call to MapTemporary")
+			return 0, nil
+		}
+
+		unmapFn = func(_ Page) *kernel.Error {
+			t.Fatal("unexpected call to Unmap")
+			return nil
+		}
+
+		if err := pdt.Init(pdtFrame, nil); err != nil {
+			t.Fatal(err)
+		}
+	})
+
+	t.Run("not mapped PDT", func(t *testing.T) {
+		var (
+			pdt      PageDirectoryTable
+			pdtFrame = pmm.Frame(123)
+			physPage [mem.PageSize >> mem.PointerShift]pageTableEntry
+		)
+
+		// Fill phys page with random junk
+		mem.Memset(uintptr(unsafe.Pointer(&physPage[0])), 0xf0, mem.PageSize)
+
+		activePDTFn = func() uintptr {
+			return 0
+		}
+
+		mapTemporaryFn = func(_ pmm.Frame, _ FrameAllocator) (Page, *kernel.Error) {
+			return PageFromAddress(uintptr(unsafe.Pointer(&physPage[0]))), nil
+		}
+
+		flushTLBEntryFn = func(_ uintptr) {}
+
+		unmapCallCount := 0
+		unmapFn = func(_ Page) *kernel.Error {
+			unmapCallCount++
+			return nil
+		}
+
+		if err := pdt.Init(pdtFrame, nil); err != nil {
+			t.Fatal(err)
+		}
+
+		if unmapCallCount != 1 {
+			t.Fatalf("expected Unmap to be called 1 time; called %d", unmapCallCount)
+		}
+
+		for i := 0; i < len(physPage)-1; i++ {
+			if physPage[i] != 0 {
+				t.Errorf("expected PDT entry %d to be cleared; got %x", i, physPage[i])
+			}
+		}
+
+		// The last page should be recursively mapped to the PDT
+		lastPdtEntry := physPage[len(physPage)-1]
+		if !lastPdtEntry.HasFlags(FlagPresent | FlagRW) {
+			t.Fatal("expected last PDT entry to have FlagPresent and FlagRW set")
+		}
+
+		if lastPdtEntry.Frame() != pdtFrame {
+			t.Fatalf("expected last PDT entry to be recursively mapped to physical frame %x; got %x", pdtFrame, lastPdtEntry.Frame())
+		}
+	})
+
+	t.Run("temporary mapping failure", func(t *testing.T) {
+		var (
+			pdt      PageDirectoryTable
+			pdtFrame = pmm.Frame(123)
+		)
+
+		activePDTFn = func() uintptr {
+			return 0
+		}
+
+		expErr := &kernel.Error{Module: "test", Message: "error mapping page"}
+
+		mapTemporaryFn = func(_ pmm.Frame, _ FrameAllocator) (Page, *kernel.Error) {
+			return 0, expErr
+		}
+
+		unmapFn = func(_ Page) *kernel.Error {
+			t.Fatal("unexpected call to Unmap")
+			return nil
+		}
+
+		if err := pdt.Init(pdtFrame, nil); err != expErr {
+			t.Fatalf("expected to get error: %v; got %v", *expErr, err)
+		}
+	})
+}
+
+func TestPageDirectoryTableMapAmd64(t *testing.T) {
+	if runtime.GOARCH != "amd64" {
+		t.Skip("test requires amd64 runtime; skipping")
+	}
+
+	defer func(origFlushTLBEntry func(uintptr), origActivePDT func() uintptr, origMap func(Page, pmm.Frame, PageTableEntryFlag, FrameAllocator) *kernel.Error) {
+		flushTLBEntryFn = origFlushTLBEntry
+		activePDTFn = origActivePDT
+		mapFn = origMap
+	}(flushTLBEntryFn, activePDTFn, mapFn)
+
+	t.Run("already mapped PDT", func(t *testing.T) {
+		var (
+			pdtFrame = pmm.Frame(123)
+			pdt      = PageDirectoryTable{pdtFrame: pdtFrame}
+			page     = PageFromAddress(uintptr(100 * mem.Mb))
+		)
+
+		activePDTFn = func() uintptr {
+			return pdtFrame.Address()
+		}
+
+		mapFn = func(_ Page, _ pmm.Frame, _ PageTableEntryFlag, _ FrameAllocator) *kernel.Error {
+			return nil
+		}
+
+		flushCallCount := 0
+		flushTLBEntryFn = func(_ uintptr) {
+			flushCallCount++
+		}
+
+		if err := pdt.Map(page, pmm.Frame(321), FlagRW, nil); err != nil {
+			t.Fatal(err)
+		}
+
+		if exp := 0; flushCallCount != exp {
+			t.Fatalf("expected flushTLBEntry to be called %d times; called %d", exp, flushCallCount)
+		}
+	})
+
+	t.Run("not mapped PDT", func(t *testing.T) {
+		var (
+			pdtFrame       = pmm.Frame(123)
+			pdt            = PageDirectoryTable{pdtFrame: pdtFrame}
+			page           = PageFromAddress(uintptr(100 * mem.Mb))
+			activePhysPage [mem.PageSize >> mem.PointerShift]pageTableEntry
+			activePdtFrame = pmm.Frame(uintptr(unsafe.Pointer(&activePhysPage[0])) >> mem.PageShift)
+		)
+
+		// Initially, activePhysPage is recursively mapped to itself
+		activePhysPage[len(activePhysPage)-1].SetFlags(FlagPresent | FlagRW)
+		activePhysPage[len(activePhysPage)-1].SetFrame(activePdtFrame)
+
+		activePDTFn = func() uintptr {
+			return activePdtFrame.Address()
+		}
+
+		mapFn = func(_ Page, _ pmm.Frame, _ PageTableEntryFlag, _ FrameAllocator) *kernel.Error {
+			return nil
+		}
+
+		flushCallCount := 0
+		flushTLBEntryFn = func(_ uintptr) {
+			switch flushCallCount {
+			case 0:
+				// the first time we flush the tlb entry, the last entry of
+				// the active pdt should be pointing to pdtFrame
+				if got := activePhysPage[len(activePhysPage)-1].Frame(); got != pdtFrame {
+					t.Fatalf("expected last PDT entry of active PDT to be re-mapped to frame %x; got %x", pdtFrame, got)
+				}
+			case 1:
+				// the second time we flush the tlb entry, the last entry of
+				// the active pdt should be pointing back to activePdtFrame
+				if got := activePhysPage[len(activePhysPage)-1].Frame(); got != activePdtFrame {
+					t.Fatalf("expected last PDT entry of active PDT to be mapped back to frame %x; got %x", activePdtFrame, got)
+				}
+			}
+			flushCallCount++
+		}
+
+		if err := pdt.Map(page, pmm.Frame(321), FlagRW, nil); err != nil {
+			t.Fatal(err)
+		}
+
+		if exp := 2; flushCallCount != exp {
+			t.Fatalf("expected flushTLBEntry to be called %d times; called %d", exp, flushCallCount)
+		}
+	})
+}
+
+func TestPageDirectoryTableUnmapAmd64(t *testing.T) {
+	if runtime.GOARCH != "amd64" {
+		t.Skip("test requires amd64 runtime; skipping")
+	}
+
+	defer func(origFlushTLBEntry func(uintptr), origActivePDT func() uintptr, origUnmap func(Page) *kernel.Error) {
+		flushTLBEntryFn = origFlushTLBEntry
+		activePDTFn = origActivePDT
+		unmapFn = origUnmap
+	}(flushTLBEntryFn, activePDTFn, unmapFn)
+
+	t.Run("already mapped PDT", func(t *testing.T) {
+		var (
+			pdtFrame = pmm.Frame(123)
+			pdt      = PageDirectoryTable{pdtFrame: pdtFrame}
+			page     = PageFromAddress(uintptr(100 * mem.Mb))
+		)
+
+		activePDTFn = func() uintptr {
+			return pdtFrame.Address()
+		}
+
+		unmapFn = func(_ Page) *kernel.Error {
+			return nil
+		}
+
+		flushCallCount := 0
+		flushTLBEntryFn = func(_ uintptr) {
+			flushCallCount++
+		}
+
+		if err := pdt.Unmap(page); err != nil {
+			t.Fatal(err)
+		}
+
+		if exp := 0; flushCallCount != exp {
+			t.Fatalf("expected flushTLBEntry to be called %d times; called %d", exp, flushCallCount)
+		}
+	})
+
+	t.Run("not mapped PDT", func(t *testing.T) {
+		var (
+			pdtFrame       = pmm.Frame(123)
+			pdt            = PageDirectoryTable{pdtFrame: pdtFrame}
+			page           = PageFromAddress(uintptr(100 * mem.Mb))
+			activePhysPage [mem.PageSize >> mem.PointerShift]pageTableEntry
+			activePdtFrame = pmm.Frame(uintptr(unsafe.Pointer(&activePhysPage[0])) >> mem.PageShift)
+		)
+
+		// Initially, activePhysPage is recursively mapped to itself
+		activePhysPage[len(activePhysPage)-1].SetFlags(FlagPresent | FlagRW)
+		activePhysPage[len(activePhysPage)-1].SetFrame(activePdtFrame)
+
+		activePDTFn = func() uintptr {
+			return activePdtFrame.Address()
+		}
+
+		unmapFn = func(_ Page) *kernel.Error {
+			return nil
+		}
+
+		flushCallCount := 0
+		flushTLBEntryFn = func(_ uintptr) {
+			switch flushCallCount {
+			case 0:
+				// the first time we flush the tlb entry, the last entry of
+				// the active pdt should be pointing to pdtFrame
+				if got := activePhysPage[len(activePhysPage)-1].Frame(); got != pdtFrame {
+					t.Fatalf("expected last PDT entry of active PDT to be re-mapped to frame %x; got %x", pdtFrame, got)
+				}
+			case 1:
+				// the second time we flush the tlb entry, the last entry of
+				// the active pdt should be pointing back to activePdtFrame
+				if got := activePhysPage[len(activePhysPage)-1].Frame(); got != activePdtFrame {
+					t.Fatalf("expected last PDT entry of active PDT to be mapped back to frame %x; got %x", activePdtFrame, got)
+				}
+			}
+			flushCallCount++
+		}
+
+		if err := pdt.Unmap(page); err != nil {
+			t.Fatal(err)
+		}
+
+		if exp := 2; flushCallCount != exp {
+			t.Fatalf("expected flushTLBEntry to be called %d times; called %d", exp, flushCallCount)
+		}
+	})
+}
+
+func TestPageDirectoryTableActivateAmd64(t *testing.T) {
+	if runtime.GOARCH != "amd64" {
+		t.Skip("test requires amd64 runtime; skipping")
+	}
+
+	defer func(origSwitchPDT func(uintptr)) {
+		switchPDTFn = origSwitchPDT
+	}(switchPDTFn)
+
+	var (
+		pdtFrame = pmm.Frame(123)
+		pdt      = PageDirectoryTable{pdtFrame: pdtFrame}
+	)
+
+	switchPDTCallCount := 0
+	switchPDTFn = func(_ uintptr) {
+		switchPDTCallCount++
+	}
+
+	pdt.Activate()
+
+	if exp := 1; switchPDTCallCount != exp {
+		t.Fatalf("expected switchPDT to be called %d times; called %d", exp, switchPDTCallCount)
+	}
+}
diff --git a/kernel/mem/vmm/pte.go b/kernel/mem/vmm/pte.go
new file mode 100644
index 0000000..d261a76
--- /dev/null
+++ b/kernel/mem/vmm/pte.go
@@ -0,0 +1,74 @@
+package vmm
+
+import (
+	"github.com/achilleasa/gopher-os/kernel"
+	"github.com/achilleasa/gopher-os/kernel/mem"
+	"github.com/achilleasa/gopher-os/kernel/mem/pmm"
+)
+
+var (
+	// ErrInvalidMapping is returned when trying to lookup a virtual memory address that is not yet mapped.
+	ErrInvalidMapping = &kernel.Error{Module: "vmm", Message: "virtual address does not point to a mapped physical page"}
+)
+
+// PageTableEntryFlag describes a flag that can be applied to a page table entry.
+type PageTableEntryFlag uintptr
+
+// pageTableEntry describes a page table entry. These entries encode
+// a physical frame address and a set of flags. The actual format
+// of the entry and flags is architecture-dependent.
+type pageTableEntry uintptr
+
+// HasFlags returns true if this entry has all the input flags set.
+func (pte pageTableEntry) HasFlags(flags PageTableEntryFlag) bool {
+	return (uintptr(pte) & uintptr(flags)) == uintptr(flags)
+}
+
+// HasAnyFlag returns true if this entry has at least one of the input flags set.
+func (pte pageTableEntry) HasAnyFlag(flags PageTableEntryFlag) bool {
+	return (uintptr(pte) & uintptr(flags)) != 0
+}
+
+// SetFlags sets the input list of flags on the page table entry.
+func (pte *pageTableEntry) SetFlags(flags PageTableEntryFlag) {
+	*pte = (pageTableEntry)(uintptr(*pte) | uintptr(flags))
+}
+
+// ClearFlags unsets the input list of flags from the page table entry.
+func (pte *pageTableEntry) ClearFlags(flags PageTableEntryFlag) {
+	*pte = (pageTableEntry)(uintptr(*pte) &^ uintptr(flags))
+}
+
+// Frame returns the physical page frame that this page table entry points to.
+func (pte pageTableEntry) Frame() pmm.Frame {
+	return pmm.Frame((uintptr(pte) & ptePhysPageMask) >> mem.PageShift)
+}
+
+// SetFrame updates the page table entry to point to the given physical frame.
+func (pte *pageTableEntry) SetFrame(frame pmm.Frame) {
+	*pte = (pageTableEntry)((uintptr(*pte) &^ ptePhysPageMask) | frame.Address())
+}
+
+// pteForAddress returns the final page table entry that corresponds to a
+// particular virtual address. The function performs a page table walk till it
+// reaches the final page table entry, returning ErrInvalidMapping if the page
+// is not present.
+func pteForAddress(virtAddr uintptr) (*pageTableEntry, *kernel.Error) {
+	var (
+		err   *kernel.Error
+		entry *pageTableEntry
+	)
+
+	walk(virtAddr, func(pteLevel uint8, pte *pageTableEntry) bool {
+		if !pte.HasFlags(FlagPresent) {
+			entry = nil
+			err = ErrInvalidMapping
+			return false
+		}
+
+		entry = pte
+		return true
+	})
+
+	return entry, err
+}
diff --git a/kernel/mem/vmm/pte_test.go b/kernel/mem/vmm/pte_test.go
new file mode 100644
index 0000000..108986c
--- /dev/null
+++ b/kernel/mem/vmm/pte_test.go
@@ -0,0 +1,61 @@
+package vmm
+
+import (
+	"testing"
+
+	"github.com/achilleasa/gopher-os/kernel/mem/pmm"
+)
+
+func TestPageTableEntryFlags(t *testing.T) {
+	var (
+		pte   pageTableEntry
+		flag1 = PageTableEntryFlag(1 << 10)
+		flag2 = PageTableEntryFlag(1 << 21)
+	)
+
+	if pte.HasAnyFlag(flag1 | flag2) {
+		t.Fatalf("expected HasAnyFlag to return false")
+	}
+
+	pte.SetFlags(flag1 | flag2)
+
+	if !pte.HasAnyFlag(flag1 | flag2) {
+		t.Fatalf("expected HasAnyFlag to return true")
+	}
+
+	if !pte.HasFlags(flag1 | flag2) {
+		t.Fatalf("expected HasFlags to return true")
+	}
+
+	pte.ClearFlags(flag1)
+
+	if !pte.HasAnyFlag(flag1 | flag2) {
+		t.Fatalf("expected HasAnyFlag to return true")
+	}
+
+	if pte.HasFlags(flag1 | flag2) {
+		t.Fatalf("expected HasFlags to return false")
+	}
+
+	pte.ClearFlags(flag1 | flag2)
+
+	if pte.HasAnyFlag(flag1 | flag2) {
+		t.Fatalf("expected HasAnyFlag to return false")
+	}
+
+	if pte.HasFlags(flag1 | flag2) {
+		t.Fatalf("expected HasFlags to return false")
+	}
+}
+
+func TestPageTableEntryFrameEncoding(t *testing.T) {
+	var (
+		pte       pageTableEntry
+		physFrame = pmm.Frame(123)
+	)
+
+	pte.SetFrame(physFrame)
+	if got := pte.Frame(); got != physFrame {
+		t.Fatalf("expected pte.Frame() to return %v; got %v", physFrame, got)
+	}
+}
diff --git a/kernel/mem/vmm/tlb.go b/kernel/mem/vmm/tlb.go
new file mode 100644
index 0000000..703e597
--- /dev/null
+++ b/kernel/mem/vmm/tlb.go
@@ -0,0 +1,11 @@
+package vmm
+
+// flushTLBEntry flushes the TLB entry for a particular virtual address.
+func flushTLBEntry(virtAddr uintptr)
+
+// switchPDT sets the root page table directory to point to the specified
+// physical address and flushes the TLB.
+func switchPDT(pdtPhysAddr uintptr)
+
+// activePDT returns the physical address of the currently active page table.
+func activePDT() uintptr
diff --git a/kernel/mem/vmm/tlb_amd64.s b/kernel/mem/vmm/tlb_amd64.s
new file mode 100644
index 0000000..0f12a43
--- /dev/null
+++ b/kernel/mem/vmm/tlb_amd64.s
@@ -0,0 +1,15 @@
+#include "textflag.h"
+
+TEXT ·flushTLBEntry(SB),NOSPLIT,$0
+	// INVLPG needs the virtual address value itself, so load it into a
+	// register first; invalidating virtAddr+0(FP) directly would flush
+	// the page holding the stack argument instead.
+	MOVQ virtAddr+0(FP), AX
+	INVLPG (AX)
+	RET
+
+TEXT ·switchPDT(SB),NOSPLIT,$0
+	// loading CR3 also triggers a TLB flush; CR3 can only be loaded
+	// from a register
+	MOVQ pdtPhysAddr+0(FP), AX
+	MOVQ AX, CR3
+	RET
+
+TEXT ·activePDT(SB),NOSPLIT,$0
+	MOVQ CR3, AX
+	MOVQ AX, ret+0(FP)
+	RET
diff --git a/kernel/mem/vmm/translate.go b/kernel/mem/vmm/translate.go
new file mode 100644
index 0000000..a5d9f51
--- /dev/null
+++ b/kernel/mem/vmm/translate.go
@@ -0,0 +1,19 @@
+package vmm
+
+import "github.com/achilleasa/gopher-os/kernel"
+
+// Translate returns the physical address that corresponds to the supplied
+// virtual address or ErrInvalidMapping if the virtual address does not
+// correspond to a mapped physical address.
+func Translate(virtAddr uintptr) (uintptr, *kernel.Error) {
+	pte, err := pteForAddress(virtAddr)
+	if err != nil {
+		return 0, err
+	}
+
+	// Calculate the physical address by taking the physical frame address and
+	// adding the page offset from the virtual address
+	physAddr := pte.Frame().Address() + (virtAddr & ((1 << pageLevelShifts[pageLevels-1]) - 1))
+
+	return physAddr, nil
+}
diff --git a/kernel/mem/vmm/translate_test.go b/kernel/mem/vmm/translate_test.go
new file mode 100644
index 0000000..2f9ad78
--- /dev/null
+++ b/kernel/mem/vmm/translate_test.go
@@ -0,0 +1,73 @@
+package vmm
+
+import (
+	"runtime"
+	"testing"
+	"unsafe"
+
+	"github.com/achilleasa/gopher-os/kernel/mem/pmm"
+)
+
+func TestTranslateAmd64(t *testing.T) {
+	if runtime.GOARCH != "amd64" {
+		t.Skip("test requires amd64 runtime; skipping")
+	}
+
+	defer func(origPtePtr func(uintptr) unsafe.Pointer) {
+		ptePtrFn = origPtePtr
+	}(ptePtrFn)
+
+	// the virtual address just contains the page offset
+	virtAddr := uintptr(1234)
+	expFrame := pmm.Frame(42)
+	expPhysAddr := expFrame.Address() + virtAddr
+	specs := [][pageLevels]bool{
+		{true, true, true, true},
+		{false, true, true, true},
+		{true, false, true, true},
+		{true, true, false, true},
+		{true, true, true, false},
+	}
+
+	for specIndex, spec := range specs {
+		pteCallCount := 0
+		ptePtrFn = func(entry uintptr) unsafe.Pointer {
+			var pte pageTableEntry
+			pte.SetFrame(expFrame)
+			if specs[specIndex][pteCallCount] {
+				pte.SetFlags(FlagPresent)
+			}
+			pteCallCount++
+
+			return unsafe.Pointer(&pte)
+		}
+
+		// An error is expected if any page level contains a non-present page
+		expError := false
+		for _, hasMapping := range spec {
+			if !hasMapping {
+				expError = true
+				break
+			}
+		}
+
+		physAddr, err := Translate(virtAddr)
+		switch {
+		case expError && err != ErrInvalidMapping:
+			t.Errorf("[spec %d] expected to get ErrInvalidMapping; got %v", specIndex, err)
+		case !expError && err != nil:
+			t.Errorf("[spec %d] unexpected error %v", specIndex, err)
+		case !expError && physAddr != expPhysAddr:
+			t.Errorf("[spec %d] expected phys addr to be 0x%x; got 0x%x", specIndex, expPhysAddr, physAddr)
+		}
+	}
+}
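A quick usage sketch for Translate; the checkMapping helper below is hypothetical and not part of the patch:

```go
package kmain

import "github.com/achilleasa/gopher-os/kernel/mem/vmm"

// checkMapping is a hypothetical helper showing how Translate results are
// consumed: it reports whether virtAddr is currently backed by a physical
// frame and, if so, at which physical address.
func checkMapping(virtAddr uintptr) (uintptr, bool) {
	physAddr, err := vmm.Translate(virtAddr)
	if err != nil {
		// Translate returns ErrInvalidMapping when any level of the
		// page table walk hits a non-present entry.
		return 0, false
	}
	return physAddr, true
}
```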
diff --git a/kernel/mem/vmm/walk.go b/kernel/mem/vmm/walk.go
new file mode 100644
index 0000000..8c38469
--- /dev/null
+++ b/kernel/mem/vmm/walk.go
@@ -0,0 +1,56 @@
+package vmm
+
+import (
+	"unsafe"
+
+	"github.com/achilleasa/gopher-os/kernel/mem"
+)
+
+var (
+	// ptePtrFn returns a pointer to the supplied entry address. It is
+	// used by tests to override the generated page table entry pointers so
+	// walk() can be properly tested. When compiling the kernel this function
+	// will be automatically inlined.
+	ptePtrFn = func(entryAddr uintptr) unsafe.Pointer {
+		return unsafe.Pointer(entryAddr)
+	}
+)
+
+// pageTableWalker is a function that can be passed to the walk method. The
+// function receives the current page level and page table entry as its
+// arguments. If the function returns false, then the page walk is aborted.
+type pageTableWalker func(pteLevel uint8, pte *pageTableEntry) bool
+
+// walk performs a page table walk for the given virtual address. It calls the
+// supplied walkFn with the page table entry that corresponds to each page
+// table level. If walkFn returns false then the walk is aborted and control
+// returns to the caller.
+func walk(virtAddr uintptr, walkFn pageTableWalker) {
+	var (
+		level                            uint8
+		tableAddr, entryAddr, entryIndex uintptr
+		ok                               bool
+	)
+
+	// tableAddr is initially set to the recursively mapped virtual address
+	// of the top-most page table. Dereferencing a pointer to this address
+	// will allow us to access the contents of the P4 table.
+	for level, tableAddr = uint8(0), pdtVirtualAddr; level < pageLevels; level, tableAddr = level+1, entryAddr {
+		// Extract the bits from the virtual address that correspond to
+		// the index in this level's page table
+		entryIndex = (virtAddr >> pageLevelShifts[level]) & ((1 << pageLevelBits[level]) - 1)
+
+		// Add the entry offset (in bytes) to the table address to get
+		// the virtual address of the page table entry for this level
+		entryAddr = tableAddr + (entryIndex << mem.PointerShift)
+
+		if ok = walkFn(level, (*pageTableEntry)(ptePtrFn(entryAddr))); !ok {
+			return
+		}
+
+		// Shift left by the number of bits for this paging level to get
+		// the virtual address of the table pointed to by entryAddr
+		entryAddr <<= pageLevelBits[level]
+	}
+}
diff --git a/kernel/mem/vmm/walk_test.go b/kernel/mem/vmm/walk_test.go
new file mode 100644
index 0000000..ee235eb
--- /dev/null
+++ b/kernel/mem/vmm/walk_test.go
@@ -0,0 +1,76 @@
+package vmm
+
+import (
+	"runtime"
+	"testing"
+	"unsafe"
+
+	"github.com/achilleasa/gopher-os/kernel/mem"
+)
+
+func TestPtePtrFn(t *testing.T) {
+	// Dummy test to keep coverage happy
+	if exp, got := unsafe.Pointer(uintptr(123)), ptePtrFn(uintptr(123)); exp != got {
+		t.Fatalf("expected ptePtrFn to return %v; got %v", exp, got)
+	}
+}
+
+func TestWalkAmd64(t *testing.T) {
+	if runtime.GOARCH != "amd64" {
+		t.Skip("test requires amd64 runtime; skipping")
+	}
+
+	defer func(origPtePtr func(uintptr) unsafe.Pointer) {
+		ptePtrFn = origPtePtr
+	}(ptePtrFn)
+
+	// This address breaks down to:
+	// p4 index: 1
+	// p3 index: 2
+	// p2 index: 3
+	// p1 index: 4
+	// offset  : 1024
+	targetAddr := uintptr(0x8080604400)
+
+	sizeofPteEntry := uintptr(unsafe.Sizeof(pageTableEntry(0)))
+	expEntryAddrBits := [pageLevels][pageLevels + 1]uintptr{
+		{511, 511, 511, 511, 1 * sizeofPteEntry},
+		{511, 511, 511, 1, 2 * sizeofPteEntry},
+		{511, 511, 1, 2, 3 * sizeofPteEntry},
+		{511, 1, 2, 3, 4 * sizeofPteEntry},
+	}
+
+	pteCallCount := 0
+	ptePtrFn = func(entry uintptr) unsafe.Pointer {
+		if pteCallCount >= pageLevels {
+			t.Fatalf("unexpected call to ptePtrFn; already called %d times", pageLevels)
+		}
+
+		for i := 0; i < pageLevels; i++ {
+			pteIndex := (entry >> pageLevelShifts[i]) & ((1 << pageLevelBits[i]) - 1)
+			if pteIndex != expEntryAddrBits[pteCallCount][i] {
+				t.Errorf("[ptePtrFn call %d] expected pte entry for level %d to use offset %d; got %d", pteCallCount, i, expEntryAddrBits[pteCallCount][i], pteIndex)
+			}
+		}
+
+		// Check the page offset
+		pteIndex := entry & ((1 << mem.PageShift) - 1)
+		if pteIndex != expEntryAddrBits[pteCallCount][pageLevels] {
+			t.Errorf("[ptePtrFn call %d] expected pte offset to be %d; got %d", pteCallCount, expEntryAddrBits[pteCallCount][pageLevels], pteIndex)
+		}
+
+		pteCallCount++
+
+		return unsafe.Pointer(uintptr(0xf00))
+	}
+
+	walkFnCallCount := 0
+	walk(targetAddr, func(level uint8, entry *pageTableEntry) bool {
+		walkFnCallCount++
+		return walkFnCallCount != pageLevels
+	})
+
+	if pteCallCount != pageLevels {
+		t.Errorf("expected ptePtrFn to be called %d times; got %d", pageLevels, pteCallCount)
+	}
+}
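The per-level index extraction exercised by TestWalkAmd64 can be sanity-checked outside the kernel. This standalone sketch, not part of the patch, decomposes the test's targetAddr using the same shifts and bit counts as constants_amd64.go:

```go
package main

import "fmt"

func main() {
	// Same layout as constants_amd64.go: 9 index bits per level,
	// shifts of 39/30/21/12, and a 12-bit page offset.
	var (
		shifts = [4]uint{39, 30, 21, 12}
		addr   = uintptr(0x8080604400)
	)

	for level, shift := range shifts {
		index := (addr >> shift) & ((1 << 9) - 1)
		fmt.Printf("level %d index: %d\n", level, index) // prints 1, 2, 3, 4
	}
	fmt.Printf("page offset: %d\n", addr&((1<<12)-1)) // prints 1024
}
```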