diff --git a/src/components/BodyPortal.spec.tsx b/src/components/BodyPortal.spec.tsx index 1cec6f0cc..860cf85d0 100644 --- a/src/components/BodyPortal.spec.tsx +++ b/src/components/BodyPortal.spec.tsx @@ -1,3 +1,4 @@ +import React from 'react'; import { render } from '@testing-library/react'; import { BodyPortal } from './BodyPortal'; import { BodyPortalSlotsContext } from './BodyPortalSlotsContext'; @@ -176,6 +177,61 @@ describe('BodyPortal', () => { Modal +`); + }); + + it('accepts an optional ref parameter that will be set', () => { + const TestPortal = ({ children }: React.PropsWithChildren<{}>) => { + const ref = React.useRef(null); + expect(ref.current).toBeNull(); + + React.useEffect(() => { + expect(ref.current).toBeInstanceOf(HTMLElement); + }, []); + + return {children}; + }; + render(<>Footer stuff

Title

, { container: root }); + expect(document.body).toMatchInlineSnapshot(` + +
+

+ Title +

+
+ + +`); + }); + + it('accepts a ref callback', () => { + const setRef = jest.fn().mockImplementation((element) => { + expect(element).toBeInstanceOf(HTMLElement); + }); + render(<>Footer stuff +

Title

, { container: root }); + expect(setRef).toHaveBeenCalled(); + expect(document.body).toMatchInlineSnapshot(` + +
+

+ Title +

+
+ + `); }); }); diff --git a/src/components/BodyPortal.tsx b/src/components/BodyPortal.tsx index c421ac29e..b3a42cb90 100644 --- a/src/components/BodyPortal.tsx +++ b/src/components/BodyPortal.tsx @@ -21,21 +21,24 @@ const getInsertBeforeTarget = (bodyPortalSlots: string[], slot?: string) => { return null; } -export const BodyPortal = ({ - children, className, ref, role, slot, tagName -}: React.PropsWithChildren<{ +export const BodyPortal = React.forwardRef; role?: string; slot?: string; - tagName?: string -}>) => { + tagName?: string; +}>>(({ children, className, role, slot, tagName }, ref) => { const tag = tagName?.toUpperCase() ?? 'DIV'; const internalRef = React.useRef(document.createElement(tag)); if (internalRef.current.tagName !== tag) { internalRef.current = document.createElement(tag); } - if (ref) { ref.current = internalRef.current; } + if (ref) { + if (typeof ref === 'function') { + ref(internalRef.current); + } else { + ref.current = internalRef.current; + } + } const bodyPortalOrderedRefs = React.useContext(BodyPortalSlotsContext); @@ -64,4 +67,4 @@ export const BodyPortal = ({ }, [bodyPortalOrderedRefs, className, role, slot, tag]); return createPortal(children, internalRef.current); -}; +});