diff --git a/kernel/mem/vmm/pdt.go b/kernel/mem/vmm/pdt.go new file mode 100644 index 0000000..9762573 --- /dev/null +++ b/kernel/mem/vmm/pdt.go @@ -0,0 +1,135 @@ +package vmm + +import ( + "unsafe" + + "github.com/achilleasa/gopher-os/kernel" + "github.com/achilleasa/gopher-os/kernel/mem" + "github.com/achilleasa/gopher-os/kernel/mem/pmm" +) + +var ( + // activePDTFn is used by tests to override calls to activePDT which + // will cause a fault if called in user-mode. + activePDTFn = activePDT + + // switchPDTFn is used by tests to override calls to switchPDT which + // will cause a fault if called in user-mode. + switchPDTFn = switchPDT + + // mapFn is used by tests and is automatically inlined by the compiler. + mapFn = Map + + // mapTemporaryFn is used by tests and is automatically inlined by the compiler. + mapTemporaryFn = MapTemporary + + // unmapmFn is used by tests and is automatically inlined by the compiler. + unmapFn = Unmap +) + +// PageDirectoryTable describes the top-most table in a multi-level paging scheme. +type PageDirectoryTable struct { + pdtFrame pmm.Frame +} + +// Init sets up the page table directory starting at the supplied physical +// address. If the supplied frame does not match the currently active PDT, then +// Init assumes that this is a new page table directory that needs +// bootstapping. In such a case, a temporary mapping is established so that +// Init can: +// - call mem.Memset to clear the frame contents +// - setup a recursive mapping for the last table entry to the page itself. +func (pdt *PageDirectoryTable) Init(pdtFrame pmm.Frame, allocFn FrameAllocator) *kernel.Error { + pdt.pdtFrame = pdtFrame + + // Check active PDT physical address. If it matches the input pdt then + // nothing more needs to be done + activePdtAddr := activePDTFn() + if pdtFrame.Address() == activePdtAddr { + return nil + } + + // Create a temporary mapping for the pdt frame so we can work on it + pdtPage, err := mapTemporaryFn(pdtFrame, allocFn) + if err != nil { + return err + } + + // Clear the page contents and setup recursive mapping for the last PDT entry + mem.Memset(pdtPage.Address(), 0, mem.PageSize) + lastPdtEntry := (*pageTableEntry)(unsafe.Pointer(pdtPage.Address() + (((1 << pageLevelBits[0]) - 1) << mem.PointerShift))) + *lastPdtEntry = 0 + lastPdtEntry.SetFlags(FlagPresent | FlagRW) + lastPdtEntry.SetFrame(pdtFrame) + + // Remove temporary mapping + unmapFn(pdtPage) + + return nil +} + +// Map establishes a mapping between a virtual page and a physical memory frame +// using this PDT. This method behaves in a similar fashion to the global Map() +// function with the difference that it also supports inactive page PDTs by +// establishing a temporary mapping so that Map() can access the inactive PDT +// entries. +func (pdt PageDirectoryTable) Map(page Page, frame pmm.Frame, flags PageTableEntryFlag, allocFn FrameAllocator) *kernel.Error { + var ( + activePdtFrame = pmm.Frame(activePDTFn() >> mem.PageShift) + lastPdtEntryAddr uintptr + lastPdtEntry *pageTableEntry + ) + // If this table is not active we need to temporarily map it to the + // last entry in the active PDT so we can access it using the recursive + // virtual address scheme. + if activePdtFrame != pdt.pdtFrame { + lastPdtEntryAddr = activePdtFrame.Address() + (((1 << pageLevelBits[0]) - 1) << mem.PointerShift) + lastPdtEntry = (*pageTableEntry)(unsafe.Pointer(lastPdtEntryAddr)) + lastPdtEntry.SetFrame(pdt.pdtFrame) + flushTLBEntryFn(lastPdtEntryAddr) + } + + err := mapFn(page, frame, flags, allocFn) + + if activePdtFrame != pdt.pdtFrame { + lastPdtEntry.SetFrame(activePdtFrame) + flushTLBEntryFn(lastPdtEntryAddr) + } + + return err +} + +// Unmap removes a mapping previousle installed by a call to Map() on this PDT. +// This method behaves in a similar fashion to the global Unmap() function with +// the difference that it also supports inactive page PDTs by establishing a +// temporary mapping so that Unmap() can access the inactive PDT entries. +func (pdt PageDirectoryTable) Unmap(page Page) *kernel.Error { + var ( + activePdtFrame = pmm.Frame(activePDTFn() >> mem.PageShift) + lastPdtEntryAddr uintptr + lastPdtEntry *pageTableEntry + ) + // If this table is not active we need to temporarily map it to the + // last entry in the active PDT so we can access it using the recursive + // virtual address scheme. + if activePdtFrame != pdt.pdtFrame { + lastPdtEntryAddr = activePdtFrame.Address() + (((1 << pageLevelBits[0]) - 1) << mem.PointerShift) + lastPdtEntry = (*pageTableEntry)(unsafe.Pointer(lastPdtEntryAddr)) + lastPdtEntry.SetFrame(pdt.pdtFrame) + flushTLBEntryFn(lastPdtEntryAddr) + } + + err := unmapFn(page) + + if activePdtFrame != pdt.pdtFrame { + lastPdtEntry.SetFrame(activePdtFrame) + flushTLBEntryFn(lastPdtEntryAddr) + } + + return err +} + +// Activate enables this page directory table and flushes the TLB +func (pdt PageDirectoryTable) Activate() { + switchPDTFn(pdt.pdtFrame.Address()) +} diff --git a/kernel/mem/vmm/pdt_test.go b/kernel/mem/vmm/pdt_test.go new file mode 100644 index 0000000..64172b7 --- /dev/null +++ b/kernel/mem/vmm/pdt_test.go @@ -0,0 +1,331 @@ +package vmm + +import ( + "runtime" + "testing" + "unsafe" + + "github.com/achilleasa/gopher-os/kernel" + "github.com/achilleasa/gopher-os/kernel/mem" + "github.com/achilleasa/gopher-os/kernel/mem/pmm" +) + +func TestPageDirectoryTableInitAmd64(t *testing.T) { + if runtime.GOARCH != "amd64" { + t.Skip("test requires amd64 runtime; skipping") + } + + defer func(origFlushTLBEntry func(uintptr), origActivePDT func() uintptr, origMapTemporary func(pmm.Frame, FrameAllocator) (Page, *kernel.Error), origUnmap func(Page) *kernel.Error) { + flushTLBEntryFn = origFlushTLBEntry + activePDTFn = origActivePDT + mapTemporaryFn = origMapTemporary + unmapFn = origUnmap + }(flushTLBEntryFn, activePDTFn, mapTemporaryFn, unmapFn) + + t.Run("already mapped PDT", func(t *testing.T) { + var ( + pdt PageDirectoryTable + pdtFrame = pmm.Frame(123) + ) + + activePDTFn = func() uintptr { + return pdtFrame.Address() + } + + mapTemporaryFn = func(_ pmm.Frame, _ FrameAllocator) (Page, *kernel.Error) { + t.Fatal("unexpected call to MapTemporary") + return 0, nil + } + + unmapFn = func(_ Page) *kernel.Error { + t.Fatal("unexpected call to Unmap") + return nil + } + + if err := pdt.Init(pdtFrame, nil); err != nil { + t.Fatal(err) + } + }) + + t.Run("not mapped PDT", func(t *testing.T) { + var ( + pdt PageDirectoryTable + pdtFrame = pmm.Frame(123) + physPage [mem.PageSize >> mem.PointerShift]pageTableEntry + ) + + // Fill phys page with random junk + mem.Memset(uintptr(unsafe.Pointer(&physPage[0])), 0xf0, mem.PageSize) + + activePDTFn = func() uintptr { + return 0 + } + + mapTemporaryFn = func(_ pmm.Frame, _ FrameAllocator) (Page, *kernel.Error) { + return PageFromAddress(uintptr(unsafe.Pointer(&physPage[0]))), nil + } + + flushTLBEntryFn = func(_ uintptr) {} + + unmapCallCount := 0 + unmapFn = func(_ Page) *kernel.Error { + unmapCallCount++ + return nil + } + + if err := pdt.Init(pdtFrame, nil); err != nil { + t.Fatal(err) + } + + if unmapCallCount != 1 { + t.Fatalf("expected Unmap to be called 1 time; called %d", unmapCallCount) + } + + for i := 0; i < len(physPage)-1; i++ { + if physPage[i] != 0 { + t.Errorf("expected PDT entry %d to be cleared; got %x", i, physPage[i]) + } + } + + // The last page should be recursively mapped to the PDT + lastPdtEntry := physPage[len(physPage)-1] + if !lastPdtEntry.HasFlags(FlagPresent | FlagRW) { + t.Fatal("expected last PDT entry to have FlagPresent and FlagRW set") + } + + if lastPdtEntry.Frame() != pdtFrame { + t.Fatalf("expected last PDT entry to be recursively mapped to physical frame %x; got %x", pdtFrame, lastPdtEntry.Frame()) + } + }) + + t.Run("temporary mapping failure", func(t *testing.T) { + var ( + pdt PageDirectoryTable + pdtFrame = pmm.Frame(123) + ) + + activePDTFn = func() uintptr { + return 0 + } + + expErr := &kernel.Error{Module: "test", Message: "error mapping page"} + + mapTemporaryFn = func(_ pmm.Frame, _ FrameAllocator) (Page, *kernel.Error) { + return 0, expErr + } + + unmapFn = func(_ Page) *kernel.Error { + t.Fatal("unexpected call to Unmap") + return nil + } + + if err := pdt.Init(pdtFrame, nil); err != expErr { + t.Fatalf("expected to get error: %v; got %v", *expErr, err) + } + }) +} + +func TestPageDirectoryTableMapAmd64(t *testing.T) { + if runtime.GOARCH != "amd64" { + t.Skip("test requires amd64 runtime; skipping") + } + + defer func(origFlushTLBEntry func(uintptr), origActivePDT func() uintptr, origMap func(Page, pmm.Frame, PageTableEntryFlag, FrameAllocator) *kernel.Error) { + flushTLBEntryFn = origFlushTLBEntry + activePDTFn = origActivePDT + mapFn = origMap + }(flushTLBEntryFn, activePDTFn, mapFn) + + t.Run("already mapped PDT", func(t *testing.T) { + var ( + pdtFrame = pmm.Frame(123) + pdt = PageDirectoryTable{pdtFrame: pdtFrame} + page = PageFromAddress(uintptr(100 * mem.Mb)) + ) + + activePDTFn = func() uintptr { + return pdtFrame.Address() + } + + mapFn = func(_ Page, _ pmm.Frame, _ PageTableEntryFlag, _ FrameAllocator) *kernel.Error { + return nil + } + + flushCallCount := 0 + flushTLBEntryFn = func(_ uintptr) { + flushCallCount++ + } + + if err := pdt.Map(page, pmm.Frame(321), FlagRW, nil); err != nil { + t.Fatal(err) + } + + if exp := 0; flushCallCount != exp { + t.Fatalf("expected flushTLBEntry to be called %d times; called %d", exp, flushCallCount) + } + }) + + t.Run("not mapped PDT", func(t *testing.T) { + var ( + pdtFrame = pmm.Frame(123) + pdt = PageDirectoryTable{pdtFrame: pdtFrame} + page = PageFromAddress(uintptr(100 * mem.Mb)) + activePhysPage [mem.PageSize >> mem.PointerShift]pageTableEntry + activePdtFrame = pmm.Frame(uintptr(unsafe.Pointer(&activePhysPage[0])) >> mem.PageShift) + ) + + // Initially, activePhysPage is recursively mapped to itself + activePhysPage[len(activePhysPage)-1].SetFlags(FlagPresent | FlagRW) + activePhysPage[len(activePhysPage)-1].SetFrame(activePdtFrame) + + activePDTFn = func() uintptr { + return activePdtFrame.Address() + } + + mapFn = func(_ Page, _ pmm.Frame, _ PageTableEntryFlag, _ FrameAllocator) *kernel.Error { + return nil + } + + flushCallCount := 0 + flushTLBEntryFn = func(_ uintptr) { + switch flushCallCount { + case 0: + // the first time we flush the tlb entry, the last entry of + // the active pdt should be pointing to pdtFrame + if got := activePhysPage[len(activePhysPage)-1].Frame(); got != pdtFrame { + t.Fatalf("expected last PDT entry of active PDT to be re-mapped to frame %x; got %x", pdtFrame, got) + } + case 1: + // the second time we flush the tlb entry, the last entry of + // the active pdt should be pointing back to activePdtFrame + if got := activePhysPage[len(activePhysPage)-1].Frame(); got != activePdtFrame { + t.Fatalf("expected last PDT entry of active PDT to be mapped back frame %x; got %x", activePdtFrame, got) + } + } + flushCallCount++ + } + + if err := pdt.Map(page, pmm.Frame(321), FlagRW, nil); err != nil { + t.Fatal(err) + } + + if exp := 2; flushCallCount != exp { + t.Fatalf("expected flushTLBEntry to be called %d times; called %d", exp, flushCallCount) + } + }) +} + +func TestPageDirectoryTableUnmapAmd64(t *testing.T) { + if runtime.GOARCH != "amd64" { + t.Skip("test requires amd64 runtime; skipping") + } + + defer func(origFlushTLBEntry func(uintptr), origActivePDT func() uintptr, origUnmap func(Page) *kernel.Error) { + flushTLBEntryFn = origFlushTLBEntry + activePDTFn = origActivePDT + unmapFn = origUnmap + }(flushTLBEntryFn, activePDTFn, unmapFn) + + t.Run("already mapped PDT", func(t *testing.T) { + var ( + pdtFrame = pmm.Frame(123) + pdt = PageDirectoryTable{pdtFrame: pdtFrame} + page = PageFromAddress(uintptr(100 * mem.Mb)) + ) + + activePDTFn = func() uintptr { + return pdtFrame.Address() + } + + unmapFn = func(_ Page) *kernel.Error { + return nil + } + + flushCallCount := 0 + flushTLBEntryFn = func(_ uintptr) { + flushCallCount++ + } + + if err := pdt.Unmap(page); err != nil { + t.Fatal(err) + } + + if exp := 0; flushCallCount != exp { + t.Fatalf("expected flushTLBEntry to be called %d times; called %d", exp, flushCallCount) + } + }) + + t.Run("not mapped PDT", func(t *testing.T) { + var ( + pdtFrame = pmm.Frame(123) + pdt = PageDirectoryTable{pdtFrame: pdtFrame} + page = PageFromAddress(uintptr(100 * mem.Mb)) + activePhysPage [mem.PageSize >> mem.PointerShift]pageTableEntry + activePdtFrame = pmm.Frame(uintptr(unsafe.Pointer(&activePhysPage[0])) >> mem.PageShift) + ) + + // Initially, activePhysPage is recursively mapped to itself + activePhysPage[len(activePhysPage)-1].SetFlags(FlagPresent | FlagRW) + activePhysPage[len(activePhysPage)-1].SetFrame(activePdtFrame) + + activePDTFn = func() uintptr { + return activePdtFrame.Address() + } + + unmapFn = func(_ Page) *kernel.Error { + return nil + } + + flushCallCount := 0 + flushTLBEntryFn = func(_ uintptr) { + switch flushCallCount { + case 0: + // the first time we flush the tlb entry, the last entry of + // the active pdt should be pointing to pdtFrame + if got := activePhysPage[len(activePhysPage)-1].Frame(); got != pdtFrame { + t.Fatalf("expected last PDT entry of active PDT to be re-mapped to frame %x; got %x", pdtFrame, got) + } + case 1: + // the second time we flush the tlb entry, the last entry of + // the active pdt should be pointing back to activePdtFrame + if got := activePhysPage[len(activePhysPage)-1].Frame(); got != activePdtFrame { + t.Fatalf("expected last PDT entry of active PDT to be mapped back frame %x; got %x", activePdtFrame, got) + } + } + flushCallCount++ + } + + if err := pdt.Unmap(page); err != nil { + t.Fatal(err) + } + + if exp := 2; flushCallCount != exp { + t.Fatalf("expected flushTLBEntry to be called %d times; called %d", exp, flushCallCount) + } + }) +} + +func TestPageDirectoryTableActivateAmd64(t *testing.T) { + if runtime.GOARCH != "amd64" { + t.Skip("test requires amd64 runtime; skipping") + } + + defer func(origSwitchPDT func(uintptr)) { + switchPDTFn = origSwitchPDT + }(switchPDTFn) + + var ( + pdtFrame = pmm.Frame(123) + pdt = PageDirectoryTable{pdtFrame: pdtFrame} + ) + + switchPDTCallCount := 0 + switchPDTFn = func(_ uintptr) { + switchPDTCallCount++ + } + + pdt.Activate() + if exp := 1; switchPDTCallCount != exp { + t.Fatalf("expected switchPDT to be called %d times; called %d", exp, switchPDTCallCount) + } +}