From b1a91bdcd760e7fc37a92915f8811103085dd01f Mon Sep 17 00:00:00 2001 From: Riya Sinha Date: Fri, 21 Feb 2025 00:17:47 -0800 Subject: [PATCH] support agent cells, registering dfs to globals --- README_ext.md | 25 ++++ .../editor/cell/code/language-toggle.tsx | 22 ++- .../components/editor/renderers/CellArray.tsx | 28 ++++ .../codemirror/language/LanguageAdapters.ts | 2 + frontend/src/core/codemirror/language/ai.ts | 128 ++++++++++++++++++ .../src/core/codemirror/language/extension.ts | 3 + .../src/core/codemirror/language/types.ts | 2 +- .../config/__tests__/config-schema.test.ts | 4 +- marimo/__init__.py | 5 +- marimo/_ai/__init__.py | 2 + marimo/_ai/agents.py | 32 +++++ marimo/_config/config.py | 2 +- marimo/_runtime/agents.py | 24 ++++ marimo/_runtime/context/kernel_context.py | 2 + marimo/_runtime/context/script_context.py | 2 + marimo/_runtime/context/types.py | 2 + marimo/_runtime/editor/__init__.py | 7 + marimo/_runtime/editor/_editor.py | 48 +++++++ pyproject.toml | 4 +- tests/_runtime/editor/test_editor.py | 57 ++++++++ 20 files changed, 392 insertions(+), 9 deletions(-) create mode 100644 README_ext.md create mode 100644 frontend/src/core/codemirror/language/ai.ts create mode 100644 marimo/_ai/agents.py create mode 100644 marimo/_runtime/agents.py create mode 100644 marimo/_runtime/editor/__init__.py create mode 100644 marimo/_runtime/editor/_editor.py create mode 100644 tests/_runtime/editor/test_editor.py diff --git a/README_ext.md b/README_ext.md new file mode 100644 index 00000000000..9b6c984bfee --- /dev/null +++ b/README_ext.md @@ -0,0 +1,25 @@ +# Marimo LLM Agent Extension + +This fork is a modified version of Marimo with support for executing LLM agents from cells. Some features may interfere with Marimo's original functionality, so use with caution. + +## Feature overview + +### Agent Registry + +The agent registry is a new feature that allows users to register LLM agents (e.g. LangChain, LangGraph) with the Marimo UI. 
The cell input can be set as a plain-text input to the agent, with the cell output representing the agent's response. + +Usage: + + +### Background Datasource Variable Registration + +> [!CAUTION] +> This function may cause unintended bugs in Marimo's reactivity, since +> defined variables cannot be statically analyzed. Also, this can be +> confusing for users if used inappropriately to flood the global scope. +> Please be mindful of this function. + +This feature allows LLM agents designed to work with this version of Marimo +to emit variables to the global scope. This is useful for agents that +make tool calls and want to implicitly assign intermediate fetched data to +variables. diff --git a/frontend/src/components/editor/cell/code/language-toggle.tsx b/frontend/src/components/editor/cell/code/language-toggle.tsx index b44f39b7e19..328486ac7b0 100644 --- a/frontend/src/components/editor/cell/code/language-toggle.tsx +++ b/frontend/src/components/editor/cell/code/language-toggle.tsx @@ -6,11 +6,12 @@ import { MarkdownIcon, PythonIcon } from "./icons"; import { Button } from "@/components/ui/button"; import { Tooltip } from "@/components/ui/tooltip"; import type { LanguageAdapter } from "@/core/codemirror/language/types"; -import { DatabaseIcon } from "lucide-react"; +import { BotIcon, DatabaseIcon } from "lucide-react"; import { useMemo } from "react"; import { MarkdownLanguageAdapter } from "@/core/codemirror/language/markdown"; import { SQLLanguageAdapter } from "@/core/codemirror/language/sql"; import { Functions } from "@/utils/functions"; +import { AIAgentLanguageAdapter } from "@/core/codemirror/language/ai"; interface LanguageTogglesProps { editorView: EditorView | null; @@ -33,6 +34,10 @@ export const LanguageToggles: React.FC = ({ () => new SQLLanguageAdapter().isSupported(code) || code.trim() === "", [code], ); + const canUseAgent = useMemo( + () => new AIAgentLanguageAdapter().isSupported(code) || code.trim() === "", + [code], + ); return (
@@ -83,6 +88,21 @@ export const LanguageToggles: React.FC = ({ displayName="Python" onAfterToggle={Functions.NOOP} /> + + } + toType="agent" + displayName="Agent" + onAfterToggle={onAfterToggle} + />
); }; diff --git a/frontend/src/components/editor/renderers/CellArray.tsx b/frontend/src/components/editor/renderers/CellArray.tsx index 075d4efacd6..1339d1e5596 100644 --- a/frontend/src/components/editor/renderers/CellArray.tsx +++ b/frontend/src/components/editor/renderers/CellArray.tsx @@ -27,6 +27,7 @@ import { useDeleteCellCallback } from "../cell/useDeleteCell"; import { cn } from "@/utils/cn"; import { Button } from "@/components/ui/button"; import { + BotIcon, DatabaseIcon, SparklesIcon, SquareCodeIcon, @@ -37,6 +38,7 @@ import { aiEnabledAtom, autoInstantiateAtom } from "@/core/config/config"; import { useAtomValue } from "jotai"; import { useBoolean } from "@/hooks/useBoolean"; import { AddCellWithAI } from "../ai/add-cell-with-ai"; +import { AIAgentLanguageAdapter } from "@/core/codemirror/language/ai"; import type { Milliseconds } from "@/utils/time"; import { SQLLanguageAdapter } from "@/core/codemirror/language/sql"; import { MarkdownLanguageAdapter } from "@/core/codemirror/language/markdown"; @@ -343,6 +345,32 @@ const AddCellButtons: React.FC<{ SQL + Enable via settings under AI Assist + null + } + delayDuration={100} + asChild={false} + > + + Enable via settings under AI Assist diff --git a/frontend/src/core/codemirror/language/LanguageAdapters.ts b/frontend/src/core/codemirror/language/LanguageAdapters.ts index 8b86518cfd3..1075c54e367 100644 --- a/frontend/src/core/codemirror/language/LanguageAdapters.ts +++ b/frontend/src/core/codemirror/language/LanguageAdapters.ts @@ -3,6 +3,7 @@ import type { LanguageAdapter, LanguageAdapterType } from "./types"; import { PythonLanguageAdapter } from "./python"; import { MarkdownLanguageAdapter } from "./markdown"; import { SQLLanguageAdapter } from "./sql"; +import { AIAgentLanguageAdapter } from "./ai"; export const LanguageAdapters: Record< LanguageAdapterType, @@ -11,6 +12,7 @@ export const LanguageAdapters: Record< python: () => new PythonLanguageAdapter(), markdown: () => new 
MarkdownLanguageAdapter(), sql: () => new SQLLanguageAdapter(), + agent: () => new AIAgentLanguageAdapter(), }; export function getLanguageAdapters() { diff --git a/frontend/src/core/codemirror/language/ai.ts b/frontend/src/core/codemirror/language/ai.ts new file mode 100644 index 00000000000..268de471129 --- /dev/null +++ b/frontend/src/core/codemirror/language/ai.ts @@ -0,0 +1,128 @@ +/* Copyright 2024 Marimo. All rights reserved. */ +import type { Extension } from "@codemirror/state"; +import type { LanguageAdapter } from "./types"; +import dedent from "string-dedent"; +import type { CompletionConfig } from "@/core/config/config-schema"; +import type { HotkeyProvider } from "@/core/hotkeys/hotkeys"; +import { indentOneTab } from "./utils/indentOneTab"; +import { type QuotePrefixKind, splitQuotePrefix } from "./utils/quotes"; +import type { MovementCallbacks } from "../cells/extensions"; +import type { PlaceholderType } from "../config/extension"; + +const quoteKinds = [ + ['"""', '"""'], + ["'''", "'''"], + ['"', '"'], + ["'", "'"], +]; + +// explode into all combinations +// +// A note on f-strings: +// +// f-strings are not yet supported due to bad interactions with +// string escaping, LaTeX, and loss of Python syntax highlighting +const pairs = ["", "r"].flatMap((prefix) => + quoteKinds.map(([start, end]) => [prefix + start, end]), +); + +const regexes = pairs.map( + ([start, end]) => + // await mo.ai.agents.run_agent( + any number of spaces + start + capture + any number of spaces + end) + [ + start, + new RegExp( + `^await\\smo\\.ai\\.agents\\.run_agent\\(\\s*${start}(.*)${end}\\s*\\)$`, + "s", + ), + ] as const, +); + +/** + * Language adapter for AI agent cells. 
+ */ +export class AIAgentLanguageAdapter implements LanguageAdapter { + readonly type = "agent"; + readonly defaultCode = 'await mo.ai.agents.run_agent(r""" """)'; + + lastQuotePrefix: QuotePrefixKind = ""; + + transformIn(pythonCode: string): [string, number] { + pythonCode = pythonCode.trim(); + + // empty string + if (pythonCode === "") { + this.lastQuotePrefix = "r"; + return ["", 0]; + } + + for (const [start, regex] of regexes) { + const match = pythonCode.match(regex); + if (match) { + const innerCode = match[1]; + + const [quotePrefix, quoteType] = splitQuotePrefix(start); + // store the quote prefix for later when we transform out + this.lastQuotePrefix = quotePrefix; + const unescapedCode = innerCode.replaceAll(`\\${quoteType}`, quoteType); + + const offset = pythonCode.indexOf(innerCode); + // string-dedent expects the first and last line to be empty / contain only whitespace, so we pad with \n + return [dedent(`\n${unescapedCode}\n`).trim(), offset]; + } + } + + // no match + return [pythonCode, 0]; + } + + transformOut(code: string): [string, number] { + // Get the quote type from the last transformIn + const prefix = this.lastQuotePrefix; + + // Empty string + if (code === "") { + // Need at least a space, otherwise the output will be 6 quotes + code = " "; + } + + // We always transform back with triple quotes, as to avoid needing to + // escape single quotes. 
+ const escapedCode = code.replaceAll('"""', String.raw`\"""`); + + // If it's one line and not bounded by quotes, write it as a single line + const isOneLine = !code.includes("\n"); + const boundedByQuote = code.startsWith('"') || code.endsWith('"'); + if (isOneLine && !boundedByQuote) { + const start = `await mo.ai.agents.run_agent(${prefix}"""`; + const end = `""")`; + return [start + escapedCode + end, start.length]; + } + + // Multiline code + const start = `await mo.ai.agents.run_agent(\n ${prefix}"""\n`; + const end = `\n """\n)`; + return [start + indentOneTab(escapedCode) + end, start.length + 1]; + } + + isSupported(pythonCode: string): boolean { + if (pythonCode.startsWith("await mo.ai.agents.run_agent(")) { + return true; + } + + if (pythonCode.trim() === "") { + return true; + } + + return regexes.some(([, regex]) => regex.test(pythonCode)); + } + + getExtension( + _completionConfig: CompletionConfig, + _hotkeys: HotkeyProvider, + _: PlaceholderType, + _movementCallbacks: MovementCallbacks, + ): Extension[] { + return []; + } +} diff --git a/frontend/src/core/codemirror/language/extension.ts b/frontend/src/core/codemirror/language/extension.ts index 5a9ee5defe5..4d0ea619bb7 100644 --- a/frontend/src/core/codemirror/language/extension.ts +++ b/frontend/src/core/codemirror/language/extension.ts @@ -235,6 +235,9 @@ export function getInitialLanguageAdapter(state: EditorView["state"]) { if (LanguageAdapters.sql().isSupported(doc)) { return LanguageAdapters.sql(); } + if (LanguageAdapters.agent().isSupported(doc)) { + return LanguageAdapters.agent(); + } return LanguageAdapters.python(); } diff --git a/frontend/src/core/codemirror/language/types.ts b/frontend/src/core/codemirror/language/types.ts index 0488d026671..85b091781e6 100644 --- a/frontend/src/core/codemirror/language/types.ts +++ b/frontend/src/core/codemirror/language/types.ts @@ -24,4 +24,4 @@ export interface LanguageAdapter { ): Extension[]; } -export type LanguageAdapterType = "python" | 
"markdown" | "sql"; +export type LanguageAdapterType = "python" | "markdown" | "sql" | "agent"; diff --git a/frontend/src/core/config/__tests__/config-schema.test.ts b/frontend/src/core/config/__tests__/config-schema.test.ts index c5950f1ed17..1481be4393e 100644 --- a/frontend/src/core/config/__tests__/config-schema.test.ts +++ b/frontend/src/core/config/__tests__/config-schema.test.ts @@ -46,7 +46,7 @@ test("default UserConfig - empty", () => { "copilot": false, }, "display": { - "cell_output": "above", + "cell_output": "below", "code_editor_font_size": 14, "dataframes": "rich", "default_width": "medium", @@ -99,7 +99,7 @@ test("default UserConfig - one level", () => { "copilot": false, }, "display": { - "cell_output": "above", + "cell_output": "below", "code_editor_font_size": 14, "dataframes": "rich", "default_width": "medium", diff --git a/marimo/__init__.py b/marimo/__init__.py index 039bb66213e..9c7d7dd536f 100644 --- a/marimo/__init__.py +++ b/marimo/__init__.py @@ -24,6 +24,7 @@ "Thread", # Other namespaces "ai", + "editor", "ui", "islands", # Application elements @@ -82,7 +83,7 @@ "video", "vstack", ] -__version__ = "0.11.7" +__version__ = "0.0.1" import marimo._ai as ai import marimo._islands as islands @@ -117,7 +118,7 @@ from marimo._plugins.stateless.tabs import tabs from marimo._plugins.stateless.tree import tree from marimo._plugins.stateless.video import video -from marimo._runtime import output +from marimo._runtime import editor, output from marimo._runtime.capture import ( capture_stderr, capture_stdout, diff --git a/marimo/_ai/__init__.py b/marimo/_ai/__init__.py index 88844f85f06..b683da8e987 100644 --- a/marimo/_ai/__init__.py +++ b/marimo/_ai/__init__.py @@ -6,9 +6,11 @@ "ChatModelConfig", "ChatAttachment", "llm", + "agents", ] import marimo._ai.llm as llm +from marimo._ai import agents from marimo._ai._types import ( ChatAttachment, ChatMessage, diff --git a/marimo/_ai/agents.py b/marimo/_ai/agents.py new file mode 100644 index 
00000000000..6bb3c20819d --- /dev/null +++ b/marimo/_ai/agents.py @@ -0,0 +1,32 @@ +# Copyright 2024 Marimo. All rights reserved. +from __future__ import annotations + +from typing import Any, Callable + +from marimo._output.rich_help import mddoc +from marimo._runtime.context import ContextNotInitializedError, get_context + + +@mddoc +def register_agent(run_fn: Callable[..., Any], name: str = "default") -> None: + """Register an LLM agent.""" + try: + _registry = get_context().agent_registry + _registry.register(run_fn, name) + except ContextNotInitializedError: + # Registration may be picked up later, but there is nothing to do + # at this point. + pass + + +@mddoc +async def run_agent(prompt: str, name: str = "default") -> Any: + """ + Run an LLM agent. + """ + try: + _registry = get_context().agent_registry + agent_fn = _registry.get_agent(name) + return agent_fn(prompt) + except ContextNotInitializedError: + pass diff --git a/marimo/_config/config.py b/marimo/_config/config.py index 879f39f2fb3..9c7f598a3cd 100644 --- a/marimo/_config/config.py +++ b/marimo/_config/config.py @@ -300,7 +300,7 @@ class PartialMarimoConfig(TypedDict, total=False): "display": { "theme": "light", "code_editor_font_size": 14, - "cell_output": "above", + "cell_output": "below", "default_width": "medium", "dataframes": "rich", }, diff --git a/marimo/_runtime/agents.py b/marimo/_runtime/agents.py new file mode 100644 index 00000000000..e41d85dbd71 --- /dev/null +++ b/marimo/_runtime/agents.py @@ -0,0 +1,24 @@ +# Copyright 2024 Marimo. All rights reserved. 
+from __future__ import annotations + +from typing import Any, Callable + + +class AgentRegistry: + def __init__(self) -> None: + self._agents: dict[str, Callable[..., Any]] = {} + + def register( + self, + agent_fn: Callable[..., Any], + name: str = "default", + ) -> None: + self._agents[name] = agent_fn + + def get_agent( + self, + name: str = "default", + ) -> Callable[..., Any]: + if name not in self._agents: + raise ValueError(f"Agent name '{name}' is not registered.") + return self._agents[name] diff --git a/marimo/_runtime/context/kernel_context.py b/marimo/_runtime/context/kernel_context.py index c7125c1e6db..711264eaa59 100644 --- a/marimo/_runtime/context/kernel_context.py +++ b/marimo/_runtime/context/kernel_context.py @@ -142,6 +142,7 @@ def create_kernel_context( parent: KernelRuntimeContext | None = None, ) -> KernelRuntimeContext: from marimo._plugins.ui._core.registry import UIElementRegistry + from marimo._runtime.agents import AgentRegistry from marimo._runtime.state import StateRegistry from marimo._runtime.virtual_file import VirtualFileRegistry @@ -151,6 +152,7 @@ def create_kernel_context( _app=app, ui_element_registry=UIElementRegistry(), state_registry=StateRegistry(), + agent_registry=AgentRegistry(), function_registry=FunctionRegistry(), cell_lifecycle_registry=CellLifecycleRegistry(), app_kernel_runner_registry=AppKernelRunnerRegistry(), diff --git a/marimo/_runtime/context/script_context.py b/marimo/_runtime/context/script_context.py index cb0417492ad..9d7f204c8e8 100644 --- a/marimo/_runtime/context/script_context.py +++ b/marimo/_runtime/context/script_context.py @@ -12,6 +12,7 @@ from marimo._config.manager import get_default_config_manager from marimo._plugins.ui._core.ids import NoIDProviderException from marimo._plugins.ui._core.registry import UIElementRegistry +from marimo._runtime.agents import AgentRegistry from marimo._runtime.cell_lifecycle_registry import CellLifecycleRegistry from marimo._runtime.context.types import ( 
ExecutionContext, @@ -140,6 +141,7 @@ def initialize_script_context( ui_element_registry=UIElementRegistry(), state_registry=StateRegistry(), function_registry=FunctionRegistry(), + agent_registry=AgentRegistry(), cell_lifecycle_registry=CellLifecycleRegistry(), app_kernel_runner_registry=AppKernelRunnerRegistry(), virtual_file_registry=VirtualFileRegistry(), diff --git a/marimo/_runtime/context/types.py b/marimo/_runtime/context/types.py index a9f2de10837..72b517e4dda 100644 --- a/marimo/_runtime/context/types.py +++ b/marimo/_runtime/context/types.py @@ -25,6 +25,7 @@ from marimo._messaging.types import Stream from marimo._output.hypertext import Html from marimo._plugins.ui._core.registry import UIElementRegistry + from marimo._runtime.agents import AgentRegistry from marimo._runtime.params import CLIArgs, QueryParams from marimo._runtime.state import State, StateRegistry from marimo._runtime.virtual_file import VirtualFileRegistry @@ -67,6 +68,7 @@ class ExecutionContext: class RuntimeContext(abc.ABC): ui_element_registry: UIElementRegistry state_registry: StateRegistry + agent_registry: AgentRegistry function_registry: FunctionRegistry cell_lifecycle_registry: CellLifecycleRegistry virtual_file_registry: VirtualFileRegistry diff --git a/marimo/_runtime/editor/__init__.py b/marimo/_runtime/editor/__init__.py new file mode 100644 index 00000000000..a2816093d8c --- /dev/null +++ b/marimo/_runtime/editor/__init__.py @@ -0,0 +1,7 @@ +# Copyright 2024 Marimo. All rights reserved. + +__all__ = [ + "register_datasource", +] + +from marimo._runtime.editor._editor import register_datasource diff --git a/marimo/_runtime/editor/_editor.py b/marimo/_runtime/editor/_editor.py new file mode 100644 index 00000000000..740f15d9f3a --- /dev/null +++ b/marimo/_runtime/editor/_editor.py @@ -0,0 +1,48 @@ +# Copyright 2024 Marimo. All rights reserved. 
+from marimo._ast.visitor import VariableData +from marimo._output.rich_help import mddoc +from marimo._plugins.ui._impl.tables.utils import get_table_manager_or_none +from marimo._runtime.context.types import ( + ContextNotInitializedError, + get_context, +) + + +@mddoc +def register_datasource(obj: object, name: str) -> None: + """Register a datasource. + + This registered object will be available in the global scope of the + notebook, including as a variable in the graph. + + WARNING: This function may cause unintended bugs in reactivity, since + defined variables cannot be statically analyzed. Also, this can be + confusing for users if used inappropriately to flood the global scope. + Please be mindful of this function. + + **Args:** + + - `obj`: The datasource object to register. + - `name`: The name to register the datasource under. + """ + try: + ctx = get_context() + except ContextNotInitializedError: + return + + if ctx.execution_context is None: + return + + if get_table_manager_or_none(obj) is None: + raise ValueError(f"Failed to get table data for variable {name}") + + ctx.globals[name] = obj + + cell_id = ctx.execution_context.cell_id + cell = ctx.graph.cells[cell_id] + cell.defs.add(name) + cell.variable_data[name] = [VariableData("variable")] + if name in ctx.graph.definitions: + ctx.graph.definitions[name].add(cell_id) + else: + ctx.graph.definitions.update({name: {cell_id}}) diff --git a/pyproject.toml b/pyproject.toml index 874fbe7861a..9bb88478761 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,8 +3,8 @@ requires = ["hatchling"] build-backend = "hatchling.build" [project] -name = "marimo" -description = "A library for making reactive notebooks and apps" +name = "marimo-agents" +description = "A Marimo concept fork with some support for LLM execution as cells" dynamic = ["version"] # We try to keep dependencies to a minimum, to avoid conflicts with # user environments;we need a very compelling reason for each dependency added. 
diff --git a/tests/_runtime/editor/test_editor.py b/tests/_runtime/editor/test_editor.py new file mode 100644 index 00000000000..a3c3ef480e9 --- /dev/null +++ b/tests/_runtime/editor/test_editor.py @@ -0,0 +1,57 @@ +import sys +import types + +import pytest + +from marimo._dependencies.dependencies import DependencyManager +from marimo._runtime.runtime import Kernel +from tests.conftest import ExecReqProvider + + +class TestCellRun: + @staticmethod + @pytest.mark.skipif( + condition=not DependencyManager.pandas.has(), + reason="requires pandas", + ) + async def test_register_datasource( + execution_kernel: Kernel, exec_req: ExecReqProvider + ) -> None: + registering_fn_module = types.ModuleType("registering_fn_module") + exec( + """ + import marimo as mo + import pandas as pd + + def registering_fn() -> None: + print('hi') + df = pd.DataFrame({'a': [1], 'b': [2]}) + name_fn = lambda x: x + mo.editor.register_datasource(df, name_fn('test_var_name')) + """, + registering_fn_module.__dict__, + ) + + # Add this module to `sys.modules` + sys.modules["registering_fn_module"] = registering_fn_module + k = execution_kernel + await k.run( + [ + exec_req.get( + """ + import registering_fn_module + registering_fn_module.registering_fn() + """ + ) + ] + ) + + assert k.globals["registering_fn_module"] + assert "test_var_name" in k.globals, ( + "test_var_name not found in globals." + ) + import pandas as pd + + assert isinstance(k.globals["test_var_name"], pd.DataFrame), ( + "test_var_name is not a pandas DataFrame." + )