diff --git a/kernel/mem/vmm/pte.go b/kernel/mem/vmm/pte.go index e86919d..d261a76 100644 --- a/kernel/mem/vmm/pte.go +++ b/kernel/mem/vmm/pte.go @@ -1,10 +1,16 @@ package vmm import ( + "github.com/achilleasa/gopher-os/kernel" "github.com/achilleasa/gopher-os/kernel/mem" "github.com/achilleasa/gopher-os/kernel/mem/pmm" ) +var ( + // ErrInvalidMapping is returned when trying to lookup a virtual memory address that is not yet mapped. + ErrInvalidMapping = &kernel.Error{Module: "vmm", Message: "virtual address does not point to a mapped physical page"} +) + // PageTableEntryFlag describes a flag that can be applied to a page table entry. type PageTableEntryFlag uintptr @@ -42,3 +48,27 @@ func (pte pageTableEntry) Frame() pmm.Frame { func (pte *pageTableEntry) SetFrame(frame pmm.Frame) { *pte = (pageTableEntry)((uintptr(*pte) &^ ptePhysPageMask) | frame.Address()) } + +// pteForAddress returns the final page table entry that correspond to a +// particular virtual address. The function performs a page table walk till it +// reaches the final page table entry returning ErrInvalidMapping if the page +// is not present. +func pteForAddress(virtAddr uintptr) (*pageTableEntry, *kernel.Error) { + var ( + err *kernel.Error + entry *pageTableEntry + ) + + walk(virtAddr, func(pteLevel uint8, pte *pageTableEntry) bool { + if !pte.HasFlags(FlagPresent) { + entry = nil + err = ErrInvalidMapping + return false + } + + entry = pte + return true + }) + + return entry, err +} diff --git a/kernel/mem/vmm/translate.go b/kernel/mem/vmm/translate.go new file mode 100644 index 0000000..a5d9f51 --- /dev/null +++ b/kernel/mem/vmm/translate.go @@ -0,0 +1,19 @@ +package vmm + +import "github.com/achilleasa/gopher-os/kernel" + +// Translate returns the physical address that corresponds to the supplied +// virtual address or ErrInvalidMapping if the virtual address does not +// correspond to a mapped physical address. +func Translate(virtAddr uintptr) (uintptr, *kernel.Error) { + pte, err := pteForAddress(virtAddr) + if err != nil { + return 0, err + } + + // Calculate the physical address by taking the physical frame address and + // appending the offset from the virtual address + physAddr := pte.Frame().Address() + (virtAddr & ((1 << pageLevelShifts[pageLevels-1]) - 1)) + + return physAddr, nil +} diff --git a/kernel/mem/vmm/translate_test.go b/kernel/mem/vmm/translate_test.go new file mode 100644 index 0000000..2f9ad78 --- /dev/null +++ b/kernel/mem/vmm/translate_test.go @@ -0,0 +1,73 @@ +package vmm + +import ( + "runtime" + "testing" + "unsafe" + + "github.com/achilleasa/gopher-os/kernel/mem/pmm" +) + +func TestTranslateAmd64(t *testing.T) { + if runtime.GOARCH != "amd64" { + t.Skip("test requires amd64 runtime; skipping") + } + + defer func(origPtePtr func(uintptr) unsafe.Pointer) { + ptePtrFn = origPtePtr + }(ptePtrFn) + + // the virtual address just contains the page offset + virtAddr := uintptr(1234) + expFrame := pmm.Frame(42) + expPhysAddr := expFrame.Address() + virtAddr + specs := [][pageLevels]bool{ + {true, true, true, true}, + {false, true, true, true}, + {true, false, true, true}, + {true, true, false, true}, + {true, true, true, false}, + } + + for specIndex, spec := range specs { + pteCallCount := 0 + ptePtrFn = func(entry uintptr) unsafe.Pointer { + var pte pageTableEntry + pte.SetFrame(expFrame) + if specs[specIndex][pteCallCount] { + pte.SetFlags(FlagPresent) + } + pteCallCount++ + + return unsafe.Pointer(&pte) + } + + // An error is expected if any page level contains a non-present page + expError := false + for _, hasMapping := range spec { + if !hasMapping { + expError = true + break + } + } + + physAddr, err := Translate(virtAddr) + switch { + case expError && err != ErrInvalidMapping: + t.Errorf("[spec %d] expected to get ErrInvalidMapping; got %v", specIndex, err) + case !expError && err != nil: + t.Errorf("[spec %d] unexpected error %v", specIndex, err) + case !expError && physAddr != expPhysAddr: + t.Errorf("[spec %d] expected phys addr to be 0x%x; got 0x%x", specIndex, expPhysAddr, physAddr) + } + } +} + +/* + phys, err := vmm.Translate(uintptr(100 * mem.Mb)) + if err != nil { + early.Printf("err: %s\n", err.Error()) + } else { + early.Printf("phys: 0x%x\n", phys) + } +*/