diff --git a/kernel/mem/vmm/constants_amd64.go b/kernel/mem/vmm/constants_amd64.go index 5e3b95d..c785244 100644 --- a/kernel/mem/vmm/constants_amd64.go +++ b/kernel/mem/vmm/constants_amd64.go @@ -80,6 +80,10 @@ const ( // for this page when the swapping page tables by updating the CR3 register. FlagGlobal + // FlagCopyOnWrite is used to implement copy-on-write functionality. This + // flag and FlagRW are mutually exclusive. + FlagCopyOnWrite = 1 << 9 + // FlagNoExecute if set, indicates that a page contains non-executable code. FlagNoExecute = 1 << 63 ) diff --git a/kernel/mem/vmm/vmm.go b/kernel/mem/vmm/vmm.go index ba7cf37..be04f6f 100644 --- a/kernel/mem/vmm/vmm.go +++ b/kernel/mem/vmm/vmm.go @@ -5,6 +5,7 @@ import ( "github.com/achilleasa/gopher-os/kernel/cpu" "github.com/achilleasa/gopher-os/kernel/irq" "github.com/achilleasa/gopher-os/kernel/kfmt/early" + "github.com/achilleasa/gopher-os/kernel/mem" "github.com/achilleasa/gopher-os/kernel/mem/pmm" ) @@ -30,7 +31,58 @@ func SetFrameAllocator(allocFn FrameAllocatorFn) { } func pageFaultHandler(errorCode uint64, frame *irq.Frame, regs *irq.Regs) { - early.Printf("\nPage fault while accessing address: 0x%16x\nReason: ", readCR2Fn()) + var ( + faultAddress = uintptr(readCR2Fn()) + faultPage = PageFromAddress(faultAddress) + pageEntry *pageTableEntry + ) + + // Lookup entry for the page where the fault occurred + walk(faultPage.Address(), func(pteLevel uint8, pte *pageTableEntry) bool { + nextIsPresent := pte.HasFlags(FlagPresent) + + if pteLevel == pageLevels-1 && nextIsPresent { + pageEntry = pte + } + + // Abort walk if the next page table entry is missing + return nextIsPresent + }) + + // CoW is supported for RO pages with the CoW flag set + if pageEntry != nil && !pageEntry.HasFlags(FlagRW) && pageEntry.HasFlags(FlagCopyOnWrite) { + var ( + copy pmm.Frame + tmpPage Page + err *kernel.Error + ) + + if copy, err = frameAllocator(); err != nil { + nonRecoverablePageFault(faultAddress, errorCode, frame, regs, err) + } else if tmpPage, err = mapTemporaryFn(copy); err != nil { + nonRecoverablePageFault(faultAddress, errorCode, frame, regs, err) + } else { + // Copy page contents, mark as RW and remove CoW flag + mem.Memcopy(faultPage.Address(), tmpPage.Address(), mem.PageSize) + unmapFn(tmpPage) + + // Update mapping to point to the new frame, flag it as RW and + // remove the CoW flag + pageEntry.ClearFlags(FlagCopyOnWrite) + pageEntry.SetFlags(FlagPresent | FlagRW) + pageEntry.SetFrame(copy) + flushTLBEntryFn(faultPage.Address()) + + // Fault recovered; retry the instruction that caused the fault + return + } + } + + nonRecoverablePageFault(faultAddress, errorCode, frame, regs, nil) +} + +func nonRecoverablePageFault(faultAddress uintptr, errorCode uint64, frame *irq.Frame, regs *irq.Regs, err *kernel.Error) { + early.Printf("\nPage fault while accessing address: 0x%16x\nReason: ", faultAddress) switch { case errorCode == 0: early.Printf("read from non-present page") @@ -55,7 +107,7 @@ func pageFaultHandler(errorCode uint64, frame *irq.Frame, regs *irq.Regs) { frame.Print() // TODO: Revisit this when user-mode tasks are implemented - panicFn(nil) + panicFn(err) } func generalProtectionFaultHandler(_ uint64, frame *irq.Frame, regs *irq.Regs) { diff --git a/kernel/mem/vmm/vmm_test.go b/kernel/mem/vmm/vmm_test.go index d1aedd3..f7e732b 100644 --- a/kernel/mem/vmm/vmm_test.go +++ b/kernel/mem/vmm/vmm_test.go @@ -11,12 +11,98 @@ import ( "github.com/achilleasa/gopher-os/kernel/driver/video/console" "github.com/achilleasa/gopher-os/kernel/hal" "github.com/achilleasa/gopher-os/kernel/irq" + "github.com/achilleasa/gopher-os/kernel/mem" + "github.com/achilleasa/gopher-os/kernel/mem/pmm" ) -func TestPageFaultHandler(t *testing.T) { - defer func() { +func TestRecoverablePageFault(t *testing.T) { + var ( + frame irq.Frame + regs irq.Regs + panicCalled bool + pageEntry pageTableEntry + origPage = make([]byte, mem.PageSize) + clonedPage = make([]byte, mem.PageSize) + err = &kernel.Error{Module: "test", Message: "something went wrong"} + ) + + defer func(origPtePtr func(uintptr) unsafe.Pointer) { + ptePtrFn = origPtePtr panicFn = kernel.Panic readCR2Fn = cpu.ReadCR2 + frameAllocator = nil + mapTemporaryFn = MapTemporary + unmapFn = Unmap + flushTLBEntryFn = cpu.FlushTLBEntry + }(ptePtrFn) + + specs := []struct { + pteFlags PageTableEntryFlag + allocError *kernel.Error + mapError *kernel.Error + expPanic bool + }{ + // Missing pge + {0, nil, nil, true}, + // Page is present but CoW flag not set + {FlagPresent, nil, nil, true}, + // Page is present but both CoW and RW flags set + {FlagPresent | FlagRW | FlagCopyOnWrite, nil, nil, true}, + // Page is present with CoW flag set but allocating a page copy fails + {FlagPresent | FlagCopyOnWrite, err, nil, true}, + // Page is present with CoW flag set but mapping the page copy fails + {FlagPresent | FlagCopyOnWrite, nil, err, true}, + // Page is present with CoW flag set + {FlagPresent | FlagCopyOnWrite, nil, nil, false}, + } + + mockTTY() + + panicFn = func(_ *kernel.Error) { + panicCalled = true + } + + ptePtrFn = func(entry uintptr) unsafe.Pointer { return unsafe.Pointer(&pageEntry) } + readCR2Fn = func() uint64 { return uint64(uintptr(unsafe.Pointer(&origPage[0]))) } + unmapFn = func(_ Page) *kernel.Error { return nil } + flushTLBEntryFn = func(_ uintptr) {} + + for specIndex, spec := range specs { + mapTemporaryFn = func(f pmm.Frame) (Page, *kernel.Error) { return Page(f), spec.mapError } + SetFrameAllocator(func() (pmm.Frame, *kernel.Error) { + addr := uintptr(unsafe.Pointer(&clonedPage[0])) + return pmm.Frame(addr >> mem.PageShift), spec.allocError + }) + + for i := 0; i < len(origPage); i++ { + origPage[i] = byte(i % 256) + clonedPage[i] = 0 + } + + panicCalled = false + pageEntry = 0 + pageEntry.SetFlags(spec.pteFlags) + + pageFaultHandler(2, &frame, ®s) + + if spec.expPanic != panicCalled { + t.Errorf("[spec %d] expected panic %t; got %t", specIndex, spec.expPanic, panicCalled) + } + + if !spec.expPanic { + for i := 0; i < len(origPage); i++ { + if origPage[i] != clonedPage[i] { + t.Errorf("[spec %d] expected clone page to be a copy of the original page; mismatch at index %d", specIndex, i) + } + } + } + } + +} + +func TestNonRecoverablePageFault(t *testing.T) { + defer func() { + panicFn = kernel.Panic }() specs := []struct { @@ -71,10 +157,6 @@ func TestPageFaultHandler(t *testing.T) { frame irq.Frame ) - readCR2Fn = func() uint64 { - return 0xbadf00d000 - } - panicCalled := false panicFn = func(_ *kernel.Error) { panicCalled = true @@ -84,7 +166,7 @@ func TestPageFaultHandler(t *testing.T) { fb := mockTTY() panicCalled = false - pageFaultHandler(spec.errCode, &frame, ®s) + nonRecoverablePageFault(0xbadf00d000, spec.errCode, &frame, ®s, nil) if got := readTTY(fb); !strings.Contains(got, spec.expReason) { t.Errorf("[spec %d] expected reason %q; got output:\n%q", specIndex, spec.expReason, got) continue