From 2a252f904ba40b0d50e75c6140e62f58dd4438d5 Mon Sep 17 00:00:00 2001 From: Arvin Xu Date: Sun, 12 May 2024 17:10:52 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9A=A1=EF=B8=8F=20perf:=20improve=20tool=20c?= =?UTF-8?q?alling=20streaming=20(#2460)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 🐛 fix: fix token count * ⚡️ perf: reduce plugin store fetcher * ⚡️ perf: support tool calling smoothing * ✅ test: fix test * ♻️ refactor: refactor the dalle generation * 🚸 style: improve dalle plugin error handle * ✅ test: fix test * ✅ test: fix test --- package.json | 58 ++-- src/app/api/chat/[provider]/route.test.ts | 4 +- src/app/api/chat/[provider]/route.ts | 2 +- src/app/api/chat/models/[provider]/route.ts | 2 +- src/app/api/config.test.ts | 52 +--- .../{chat => middleware}/auth/index.test.ts | 0 .../api/{chat => middleware}/auth/index.ts | 0 .../api/{chat => middleware}/auth/utils.ts | 0 .../api/openai/createBizOpenAI/auth.test.ts | 52 ++++ .../api/{ => openai/createBizOpenAI}/auth.ts | 0 src/app/api/openai/createBizOpenAI/index.ts | 2 +- .../openai/images/createImageGeneration.ts | 26 -- src/app/api/openai/images/route.ts | 16 -- src/app/api/plugin/gateway/route.ts | 2 +- src/app/api/text-to-image/[provider]/route.ts | 61 ++++ src/components/GalleyGrid/index.tsx | 4 +- src/database/client/schemas/message.ts | 2 + .../Conversation/Actions/Assistant.tsx | 5 +- .../Conversation/Actions/Function.tsx | 17 -- src/features/Conversation/Actions/Tool.tsx | 34 ++- .../Messages/Assistant/ToolCalls/index.tsx | 23 +- .../Conversation/Messages/Assistant/index.tsx | 10 +- .../Messages/Tool/Inspector/index.tsx | 2 +- .../Conversation/Plugins/Render/index.tsx | 13 +- src/hooks/useTokenCount.test.ts | 38 +++ src/hooks/useTokenCount.ts | 3 +- src/libs/agent-runtime/AgentRuntime.ts | 10 +- src/libs/agent-runtime/BaseAI.ts | 3 + src/libs/agent-runtime/types/index.ts | 1 + src/libs/agent-runtime/types/textToImage.ts | 34 +++ src/libs/agent-runtime/utils/createError.ts | 1 + .../utils/openaiCompatibleFactory/index.ts | 51 ++++ src/locales/default/tool.ts | 1 + src/services/_url.ts | 2 +- .../{imageGeneration.ts => textToImage.ts} | 13 +- src/store/chat/initialState.ts | 2 +- src/store/chat/selectors.ts | 2 +- .../{tool => builtinTool}/action.test.ts | 2 +- .../slices/{tool => builtinTool}/action.ts | 20 +- .../{tool => builtinTool}/initialState.ts | 0 .../slices/{tool => builtinTool}/selectors.ts | 0 src/store/chat/slices/enchance/action.ts | 21 +- src/store/chat/slices/message/action.ts | 122 ++------ src/store/chat/slices/message/initialState.ts | 5 + src/store/chat/slices/message/selectors.ts | 8 + src/store/chat/slices/plugin/action.test.ts | 2 +- src/store/chat/slices/plugin/action.ts | 175 ++++++------ src/store/chat/store.ts | 4 +- src/store/tool/slices/store/action.test.ts | 8 +- src/store/tool/slices/store/action.ts | 4 +- .../dalle/Render/{ => Item}/EditMode.tsx | 0 src/tools/dalle/Render/Item/Error.tsx | 50 ++++ src/tools/dalle/Render/Item/Image.tsx | 44 +++ .../dalle/Render/{Item.tsx => Item/index.tsx} | 49 ++-- src/utils/fetch.test.ts | 211 +++++++++++++- src/utils/fetch.ts | 261 ++++++++++++++++-- 56 files changed, 1101 insertions(+), 433 deletions(-) rename src/app/api/{chat => middleware}/auth/index.test.ts (100%) rename src/app/api/{chat => middleware}/auth/index.ts (100%) rename src/app/api/{chat => middleware}/auth/utils.ts (100%) create mode 100644 src/app/api/openai/createBizOpenAI/auth.test.ts rename src/app/api/{ => openai/createBizOpenAI}/auth.ts (100%) delete mode 100644 src/app/api/openai/images/createImageGeneration.ts delete mode 100644 src/app/api/openai/images/route.ts create mode 100644 src/app/api/text-to-image/[provider]/route.ts delete mode 100644 src/features/Conversation/Actions/Function.tsx create mode 100644 src/hooks/useTokenCount.test.ts create mode 100644 src/libs/agent-runtime/types/textToImage.ts rename src/services/{imageGeneration.ts => textToImage.ts} (66%) rename src/store/chat/slices/{tool => builtinTool}/action.test.ts (98%) rename src/store/chat/slices/{tool => builtinTool}/action.ts (84%) rename src/store/chat/slices/{tool => builtinTool}/initialState.ts (100%) rename src/store/chat/slices/{tool => builtinTool}/selectors.ts (100%) rename src/tools/dalle/Render/{ => Item}/EditMode.tsx (100%) create mode 100644 src/tools/dalle/Render/Item/Error.tsx create mode 100644 src/tools/dalle/Render/Item/Image.tsx rename src/tools/dalle/Render/{Item.tsx => Item/index.tsx} (60%) diff --git a/package.json b/package.json index 83480b7e57e6a..8b48fad807fb6 100644 --- a/package.json +++ b/package.json @@ -81,25 +81,25 @@ ] }, "dependencies": { - "@ant-design/icons": "^5.3.6", + "@ant-design/icons": "^5.3.7", "@anthropic-ai/sdk": "^0.20.9", "@auth/core": "0.28.0", - "@aws-sdk/client-bedrock-runtime": "^3.565.0", - "@azure/openai": "^1.0.0-beta.12", + "@aws-sdk/client-bedrock-runtime": "^3.574.0", + "@azure/openai": "1.0.0-beta.12", "@cfworker/json-schema": "^1.12.8", "@clerk/localizations": "2.0.0", - "@clerk/nextjs": "^5.0.6", - "@clerk/themes": "^2.0.0", + "@clerk/nextjs": "^5.0.8", + "@clerk/themes": "^2.1.3", "@google/generative-ai": "^0.10.0", - "@icons-pack/react-simple-icons": "^9.4.1", + "@icons-pack/react-simple-icons": "^9.5.0", "@lobehub/chat-plugin-sdk": "latest", "@lobehub/chat-plugins-gateway": "latest", "@lobehub/icons": "latest", "@lobehub/tts": "latest", - "@lobehub/ui": "^1.138.17", + "@lobehub/ui": "^1.138.23", "@microsoft/fetch-event-source": "^2.0.1", "@next/third-parties": "^14.2.3", - "@sentry/nextjs": "^7.112.2", + "@sentry/nextjs": "^7.114.0", "@t3-oss/env-nextjs": "^0.10.1", "@trpc/client": "next", "@trpc/next": "next", @@ -118,15 +118,15 @@ "diff": "^5.2.0", "fast-deep-equal": "^3.1.3", "gpt-tokenizer": "^2.1.2", - "i18next": "^23.11.3", + "i18next": "^23.11.4", "i18next-browser-languagedetector": "^7.2.1", "i18next-resources-to-backend": "^1.2.1", "idb-keyval": "^6.2.1", "immer": "^10.1.1", "ip": "^2.0.1", - "jose": "^5.2.4", - "langfuse": "^3.8.0", - "langfuse-core": "^3.8.0", + "jose": "^5.3.0", + "langfuse": "^3.10.0", + "langfuse-core": "^3.10.0", "lodash-es": "^4.17.21", "lucide-react": "latest", "modern-screenshot": "^4.4.39", @@ -135,12 +135,12 @@ "next-auth": "5.0.0-beta.15", "next-sitemap": "^4.2.3", "numeral": "^2.0.6", - "nuqs": "^1.17.1", - "ollama": "^0.5.0", - "openai": "^4.39.0", + "nuqs": "^1.17.2", + "ollama": "^0.5.1", + "openai": "^4.45.0", "pino": "^9.0.0", "polished": "^4.3.1", - "posthog-js": "^1.130.1", + "posthog-js": "^1.131.4", "query-string": "^9.0.0", "random-words": "^2.0.1", "react": "^18.3.1", @@ -155,7 +155,7 @@ "remark-gfm": "^3.0.1", "remark-html": "^15.0.2", "rtl-detect": "^1.1.2", - "semver": "^7.6.0", + "semver": "^7.6.2", "sharp": "^0.33.3", "superjson": "^2.2.1", "swr": "^2.2.5", @@ -170,13 +170,13 @@ "y-webrtc": "^10.3.0", "yaml": "^2.4.2", "yjs": "^13.6.15", - "zod": "^3.23.5", + "zod": "^3.23.8", "zustand": "^4.5.2", "zustand-utils": "^1.3.2" }, "devDependencies": { "@commitlint/cli": "^19.3.0", - "@ducanh2912/next-pwa": "^10.2.6", + "@ducanh2912/next-pwa": "^10.2.7", "@edge-runtime/vm": "^3.2.0", "@lobehub/i18n-cli": "^1.18.1", "@lobehub/lint": "^1.23.4", @@ -185,24 +185,24 @@ "@next/eslint-plugin-next": "^14.2.3", "@peculiar/webcrypto": "^1.4.6", "@testing-library/jest-dom": "^6.4.5", - "@testing-library/react": "^15.0.6", + "@testing-library/react": "^15.0.7", "@types/chroma-js": "^2.4.4", "@types/debug": "^4.1.12", - "@types/diff": "^5.2.0", + "@types/diff": "^5.2.1", "@types/ip": "^1.1.3", "@types/json-schema": "^7.0.15", - "@types/lodash": "^4.17.0", + "@types/lodash": "^4.17.1", "@types/lodash-es": "^4.17.12", - "@types/node": "^20.12.7", + "@types/node": "^20.12.11", "@types/numeral": "^2.0.5", - "@types/react": "^18.3.1", + "@types/react": "^18.3.2", "@types/react-dom": "^18.3.0", "@types/rtl-detect": "^1.0.3", "@types/semver": "^7.5.8", "@types/systemjs": "^6.13.5", "@types/ua-parser-js": "^0.7.39", "@types/uuid": "^9.0.8", - "@umijs/lint": "^4.1.10", + "@umijs/lint": "^4.2.2", "@vitest/coverage-v8": "~1.2.2", "ajv-keywords": "^5.1.0", "commitlint": "^19.3.0", @@ -211,9 +211,9 @@ "eslint": "^8.57.0", "eslint-plugin-mdx": "^2.3.4", "fake-indexeddb": "^5.0.2", - "glob": "^10.3.12", + "glob": "^10.3.15", "gray-matter": "^4.0.3", - "happy-dom": "^14.7.1", + "happy-dom": "^14.10.1", "husky": "^9.0.11", "just-diff": "^6.0.2", "lint-staged": "^15.2.2", @@ -227,11 +227,11 @@ "remark-parse": "^10.0.2", "semantic-release": "^21.1.2", "stylelint": "^15.11.0", - "tsx": "^4.7.3", + "tsx": "^4.10.0", "typescript": "^5.4.5", "unified": "^11.0.4", "unist-util-visit": "^5.0.0", - "vite": "^5.2.10", + "vite": "^5.2.11", "vitest": "~1.2.2", "vitest-canvas-mock": "^0.3.3" }, diff --git a/src/app/api/chat/[provider]/route.test.ts b/src/app/api/chat/[provider]/route.test.ts index ebb3fe1fa597c..f433107c5bfd1 100644 --- a/src/app/api/chat/[provider]/route.test.ts +++ b/src/app/api/chat/[provider]/route.test.ts @@ -2,18 +2,18 @@ import { getAuth } from '@clerk/nextjs/server'; import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; +import { checkAuthMethod, getJWTPayload } from '@/app/api/middleware/auth/utils'; import { LOBE_CHAT_AUTH_HEADER, OAUTH_AUTHORIZED } from '@/const/auth'; import { AgentRuntime, LobeRuntimeAI } from '@/libs/agent-runtime'; import { ChatErrorType } from '@/types/fetch'; -import { checkAuthMethod, getJWTPayload } from '../auth/utils'; import { POST } from './route'; vi.mock('@clerk/nextjs/server', () => ({ getAuth: vi.fn(), })); -vi.mock('../auth/utils', () => ({ +vi.mock('../../middleware/auth/utils', () => ({ getJWTPayload: vi.fn(), checkAuthMethod: vi.fn(), })); diff --git a/src/app/api/chat/[provider]/route.ts b/src/app/api/chat/[provider]/route.ts index 974df7cb096e2..75b9374ebc084 100644 --- a/src/app/api/chat/[provider]/route.ts +++ b/src/app/api/chat/[provider]/route.ts @@ -5,8 +5,8 @@ import { ChatErrorType } from '@/types/fetch'; import { ChatStreamPayload } from '@/types/openai/chat'; import { getTracePayload } from '@/utils/trace'; +import { checkAuth } from '../../middleware/auth'; import { createTraceOptions, initAgentRuntimeWithUserPayload } from '../agentRuntime'; -import { checkAuth } from '../auth'; export const runtime = 'edge'; diff --git a/src/app/api/chat/models/[provider]/route.ts b/src/app/api/chat/models/[provider]/route.ts index ee3ef4f32e0f1..9b292238d4b1c 100644 --- a/src/app/api/chat/models/[provider]/route.ts +++ b/src/app/api/chat/models/[provider]/route.ts @@ -5,8 +5,8 @@ import { createErrorResponse } from '@/app/api/errorResponse'; import { ChatCompletionErrorPayload, ModelProvider } from '@/libs/agent-runtime'; import { ChatErrorType } from '@/types/fetch'; +import { checkAuth } from '../../../middleware/auth'; import { initAgentRuntimeWithUserPayload } from '../../agentRuntime'; -import { checkAuth } from '../../auth'; export const runtime = 'edge'; diff --git a/src/app/api/config.test.ts b/src/app/api/config.test.ts index 8a7fdfbb3b92d..9cbdda9802c01 100644 --- a/src/app/api/config.test.ts +++ b/src/app/api/config.test.ts @@ -1,7 +1,7 @@ import { describe, expect, it, vi } from 'vitest'; -import { checkAuth } from './auth'; import { getPreferredRegion } from './config'; +import { checkAuth } from './openai/createBizOpenAI/auth'; // Stub the global process object to safely mock environment variables vi.stubGlobal('process', { @@ -41,53 +41,3 @@ describe('getPreferredRegion', () => { expect(preferredRegion).toStrictEqual(['ida1', 'sfo1']); }); }); - -describe('ACCESS_CODE', () => { - let auth = false; - - beforeEach(() => { - auth = false; - process.env.ACCESS_CODE = undefined; - // Reset environment variables before each test case - vi.restoreAllMocks(); - }); - - it('set multiple access codes', () => { - process.env.ACCESS_CODE = ',code1,code2,code3'; - ({ auth } = checkAuth({ accessCode: 'code1' })); - expect(auth).toBe(true); - ({ auth } = checkAuth({ accessCode: 'code2' })); - expect(auth).toBe(true); - ({ auth } = checkAuth({ accessCode: 'code1,code2' })); - expect(auth).toBe(false); - }); - - it('set individual access code', () => { - process.env.ACCESS_CODE = 'code1'; - ({ auth } = checkAuth({ accessCode: 'code1' })); - expect(auth).toBe(true); - ({ auth } = checkAuth({ accessCode: 'code2' })); - expect(auth).toBe(false); - }); - - it('no access code', () => { - ({ auth } = checkAuth({ accessCode: 'code1' })); - expect(auth).toBe(true); - ({ auth } = checkAuth({})); - expect(auth).toBe(true); - }); - - it('empty access code', () => { - process.env.ACCESS_CODE = ''; - ({ auth } = checkAuth({ accessCode: 'code1' })); - expect(auth).toBe(true); - ({ auth } = checkAuth({})); - expect(auth).toBe(true); - - process.env.ACCESS_CODE = ',,'; - ({ auth } = checkAuth({ accessCode: 'code1' })); - expect(auth).toBe(true); - ({ auth } = checkAuth({})); - expect(auth).toBe(true); - }); -}); diff --git a/src/app/api/chat/auth/index.test.ts b/src/app/api/middleware/auth/index.test.ts similarity index 100% rename from src/app/api/chat/auth/index.test.ts rename to src/app/api/middleware/auth/index.test.ts diff --git a/src/app/api/chat/auth/index.ts b/src/app/api/middleware/auth/index.ts similarity index 100% rename from src/app/api/chat/auth/index.ts rename to src/app/api/middleware/auth/index.ts diff --git a/src/app/api/chat/auth/utils.ts b/src/app/api/middleware/auth/utils.ts similarity index 100% rename from src/app/api/chat/auth/utils.ts rename to src/app/api/middleware/auth/utils.ts diff --git a/src/app/api/openai/createBizOpenAI/auth.test.ts b/src/app/api/openai/createBizOpenAI/auth.test.ts new file mode 100644 index 0000000000000..6584dc0487a00 --- /dev/null +++ b/src/app/api/openai/createBizOpenAI/auth.test.ts @@ -0,0 +1,52 @@ +import { checkAuth } from './auth'; + +describe('ACCESS_CODE', () => { + let auth = false; + + beforeEach(() => { + auth = false; + process.env.ACCESS_CODE = undefined; + // Reset environment variables before each test case + vi.restoreAllMocks(); + }); + + it('set multiple access codes', () => { + process.env.ACCESS_CODE = ',code1,code2,code3'; + ({ auth } = checkAuth({ accessCode: 'code1' })); + expect(auth).toBe(true); + ({ auth } = checkAuth({ accessCode: 'code2' })); + expect(auth).toBe(true); + ({ auth } = checkAuth({ accessCode: 'code1,code2' })); + expect(auth).toBe(false); + }); + + it('set individual access code', () => { + process.env.ACCESS_CODE = 'code1'; + ({ auth } = checkAuth({ accessCode: 'code1' })); + expect(auth).toBe(true); + ({ auth } = checkAuth({ accessCode: 'code2' })); + expect(auth).toBe(false); + }); + + it('no access code', () => { + delete process.env.ACCESS_CODE; + ({ auth } = checkAuth({ accessCode: 'code1' })); + expect(auth).toBe(true); + ({ auth } = checkAuth({})); + expect(auth).toBe(true); + }); + + it('empty access code', () => { + process.env.ACCESS_CODE = ''; + ({ auth } = checkAuth({ accessCode: 'code1' })); + expect(auth).toBe(true); + ({ auth } = checkAuth({})); + expect(auth).toBe(true); + + process.env.ACCESS_CODE = ',,'; + ({ auth } = checkAuth({ accessCode: 'code1' })); + expect(auth).toBe(true); + ({ auth } = checkAuth({})); + expect(auth).toBe(true); + }); +}); diff --git a/src/app/api/auth.ts b/src/app/api/openai/createBizOpenAI/auth.ts similarity index 100% rename from src/app/api/auth.ts rename to src/app/api/openai/createBizOpenAI/auth.ts diff --git a/src/app/api/openai/createBizOpenAI/index.ts b/src/app/api/openai/createBizOpenAI/index.ts index 89d3d18e00169..0742ca512d146 100644 --- a/src/app/api/openai/createBizOpenAI/index.ts +++ b/src/app/api/openai/createBizOpenAI/index.ts @@ -1,10 +1,10 @@ import OpenAI from 'openai'; -import { checkAuth } from '@/app/api/auth'; import { getOpenAIAuthFromRequest } from '@/const/fetch'; import { ChatErrorType, ErrorType } from '@/types/fetch'; import { createErrorResponse } from '../../errorResponse'; +import { checkAuth } from './auth'; import { createOpenai } from './createOpenai'; /** diff --git a/src/app/api/openai/images/createImageGeneration.ts b/src/app/api/openai/images/createImageGeneration.ts deleted file mode 100644 index 77674cf12c828..0000000000000 --- a/src/app/api/openai/images/createImageGeneration.ts +++ /dev/null @@ -1,26 +0,0 @@ -import OpenAI from 'openai'; - -import { OpenAIImagePayload } from '@/types/openai/image'; - -export const createImageGeneration = async ({ - openai, - payload, -}: { - openai: OpenAI; - payload: OpenAIImagePayload; -}) => { - const res = await openai.images.generate({ ...payload, response_format: 'url' }); - - const urls = res.data.map((o) => o.url) as string[]; - - return new Response(JSON.stringify(urls)); -}; - -// const mockImages = [ -// 'https://github-production-user-asset-6210df.s3.amazonaws.com/28616219/292159272-032d5c8b-20be-48d9-8dbb-f2491f231bac.png', -// 'https://github-production-user-asset-6210df.s3.amazonaws.com/28616219/292159798-cad89421-20c5-44b0-a337-fcbb857f1f70.png', -// 'https://github-production-user-asset-6210df.s3.amazonaws.com/28616219/292160015-2263156f-d41f-48ae-9c2c-d96799b9d2b8.png', -// 'https://github-production-user-asset-6210df.s3.amazonaws.com/28616219/292160229-592d112f-5dfd-47d7-98d3-44bc09dc91f7.png', -// ]; -// export const createImageGeneration = async () => -// new Response(JSON.stringify([mockImages[Math.round(Math.random() * 3)]])); diff --git a/src/app/api/openai/images/route.ts b/src/app/api/openai/images/route.ts deleted file mode 100644 index 483387dccfd86..0000000000000 --- a/src/app/api/openai/images/route.ts +++ /dev/null @@ -1,16 +0,0 @@ -import { OpenAIImagePayload } from '@/types/openai/image'; - -import { createBizOpenAI } from '../createBizOpenAI'; -import { createImageGeneration } from './createImageGeneration'; - -export const runtime = 'edge'; - -export const POST = async (req: Request) => { - const payload = (await req.json()) as OpenAIImagePayload; - - const openaiOrErrResponse = createBizOpenAI(req); - // if resOrOpenAI is a Response, it means there is an error,just return it - if (openaiOrErrResponse instanceof Response) return openaiOrErrResponse; - - return createImageGeneration({ openai: openaiOrErrResponse, payload }); -}; diff --git a/src/app/api/plugin/gateway/route.ts b/src/app/api/plugin/gateway/route.ts index 832b712735c46..dcddc8d429975 100644 --- a/src/app/api/plugin/gateway/route.ts +++ b/src/app/api/plugin/gateway/route.ts @@ -1,8 +1,8 @@ import { PluginRequestPayload } from '@lobehub/chat-plugin-sdk'; import { createGatewayOnEdgeRuntime } from '@lobehub/chat-plugins-gateway'; -import { getJWTPayload } from '@/app/api/chat/auth/utils'; import { createErrorResponse } from '@/app/api/errorResponse'; +import { getJWTPayload } from '@/app/api/middleware/auth/utils'; import { getServerConfig } from '@/config/server'; import { LOBE_CHAT_AUTH_HEADER, OAUTH_AUTHORIZED, enableNextAuth } from '@/const/auth'; import { LOBE_CHAT_TRACE_ID, TraceNameMap } from '@/const/trace'; diff --git a/src/app/api/text-to-image/[provider]/route.ts b/src/app/api/text-to-image/[provider]/route.ts new file mode 100644 index 0000000000000..12749c9de182a --- /dev/null +++ b/src/app/api/text-to-image/[provider]/route.ts @@ -0,0 +1,61 @@ +import { NextResponse } from 'next/server'; + +import { getPreferredRegion } from '@/app/api/config'; +import { createErrorResponse } from '@/app/api/errorResponse'; +import { ChatCompletionErrorPayload } from '@/libs/agent-runtime'; +import { TextToImagePayload } from '@/libs/agent-runtime/types'; +import { ChatErrorType } from '@/types/fetch'; + +import { initAgentRuntimeWithUserPayload } from '../../chat/agentRuntime'; +import { checkAuth } from '../../middleware/auth'; + +export const runtime = 'edge'; + +export const preferredRegion = getPreferredRegion(); + +// return NextResponse.json( +// { +// body: { +// endpoint: 'https://ai****ix.com/v1', +// error: { +// code: 'content_policy_violation', +// message: +// 'Your request was rejected as a result of our safety system. Image descriptions generated from your prompt may contain text that is not allowed by our safety system. If you believe this was done in error, your request may succeed if retried, or by adjusting your prompt.', +// param: null, +// type: 'invalid_request_error', +// }, +// provider: 'openai', +// }, +// errorType: 'OpenAIBizError', +// }, +// { status: 400 }, +// ); + +export const POST = checkAuth(async (req: Request, { params, jwtPayload }) => { + const { provider } = params; + + try { + // ============ 1. init chat model ============ // + const agentRuntime = await initAgentRuntimeWithUserPayload(provider, jwtPayload); + + // ============ 2. create chat completion ============ // + + const data = (await req.json()) as TextToImagePayload; + + const images = await agentRuntime.textToImage(data); + + return NextResponse.json(images); + } catch (e) { + const { + errorType = ChatErrorType.InternalServerError, + error: errorContent, + ...res + } = e as ChatCompletionErrorPayload; + + const error = errorContent || e; + // track the error at server side + console.error(`Route: [${provider}] ${errorType}:`, error); + + return createErrorResponse(errorType, { error, ...res, provider }); + } +}); diff --git a/src/components/GalleyGrid/index.tsx b/src/components/GalleyGrid/index.tsx index f05997a200ac6..d01cae4e49e1b 100644 --- a/src/components/GalleyGrid/index.tsx +++ b/src/components/GalleyGrid/index.tsx @@ -41,13 +41,13 @@ const GalleyGrid = memo(({ items, renderItem: Render }) => { {firstRow.map((i, index) => ( - + ))} {lastRow.length > 0 && ( 2 ? 3 : lastRow.length} gap={gap} max={max}> {lastRow.map((i, index) => ( - + ))} )} diff --git a/src/database/client/schemas/message.ts b/src/database/client/schemas/message.ts index feddb3960eb57..5b98a09444cdb 100644 --- a/src/database/client/schemas/message.ts +++ b/src/database/client/schemas/message.ts @@ -30,6 +30,8 @@ export const DB_MessageSchema = z.object({ plugin: PluginSchema.optional(), pluginState: z.any().optional(), + pluginError: z.any().optional(), + fromModel: z.string().optional(), fromProvider: z.string().optional(), translate: TranslateSchema.optional().or(z.literal(false)), diff --git a/src/features/Conversation/Actions/Assistant.tsx b/src/features/Conversation/Actions/Assistant.tsx index d46737c4d1e3e..a197235095506 100644 --- a/src/features/Conversation/Actions/Assistant.tsx +++ b/src/features/Conversation/Actions/Assistant.tsx @@ -6,9 +6,10 @@ import { RenderAction } from '../types'; import { ErrorActionsBar } from './Error'; import { useCustomActions } from './customAction'; -export const AssistantActionsBar: RenderAction = memo(({ id, onActionClick, error }) => { +export const AssistantActionsBar: RenderAction = memo(({ id, onActionClick, error, tools }) => { const { regenerate, edit, delAndRegenerate, copy, divider, del } = useChatListActionsBar(); const { translate, tts } = useCustomActions(); + const hasTools = !!tools; if (id === 'default') return; @@ -27,7 +28,7 @@ export const AssistantActionsBar: RenderAction = memo(({ id, onActionClick, erro delAndRegenerate, del, ]} - items={[edit, copy]} + items={[hasTools ? delAndRegenerate : edit, copy]} onActionClick={onActionClick} type="ghost" /> diff --git a/src/features/Conversation/Actions/Function.tsx b/src/features/Conversation/Actions/Function.tsx deleted file mode 100644 index c33e5446f6e3c..0000000000000 --- a/src/features/Conversation/Actions/Function.tsx +++ /dev/null @@ -1,17 +0,0 @@ -import { ActionIconGroup } from '@lobehub/ui'; -import { memo } from 'react'; - -import { useChatListActionsBar } from '../hooks/useChatListActionsBar'; -import { RenderAction } from '../types'; - -export const FunctionActionsBar: RenderAction = memo(({ onActionClick }) => { - const { regenerate, delAndRegenerate, del } = useChatListActionsBar(); - return ( - - ); -}); diff --git a/src/features/Conversation/Actions/Tool.tsx b/src/features/Conversation/Actions/Tool.tsx index d1f022c76b1ab..ca7ad80a8585e 100644 --- a/src/features/Conversation/Actions/Tool.tsx +++ b/src/features/Conversation/Actions/Tool.tsx @@ -1,16 +1,28 @@ +import { ActionIconGroup } from '@lobehub/ui'; import { memo } from 'react'; +import { useChatStore } from '@/store/chat'; + +import { useChatListActionsBar } from '../hooks/useChatListActionsBar'; import { RenderAction } from '../types'; -export const ToolActionsBar: RenderAction = memo(() => { - return undefined; - // const { regenerate } = useChatListActionsBar(); - // return ( - // - // ); +export const ToolActionsBar: RenderAction = memo(({ id }) => { + const { regenerate } = useChatListActionsBar(); + const [reInvokeToolMessage] = useChatStore((s) => [s.reInvokeToolMessage]); + + return ( + { + switch (event.key) { + case 'regenerate': { + reInvokeToolMessage(id); + break; + } + } + }} + type="ghost" + /> + ); }); diff --git a/src/features/Conversation/Messages/Assistant/ToolCalls/index.tsx b/src/features/Conversation/Messages/Assistant/ToolCalls/index.tsx index 980982f2ac998..2403369f94b5f 100644 --- a/src/features/Conversation/Messages/Assistant/ToolCalls/index.tsx +++ b/src/features/Conversation/Messages/Assistant/ToolCalls/index.tsx @@ -1,12 +1,12 @@ import { Avatar, Highlighter, Icon } from '@lobehub/ui'; import isEqual from 'fast-deep-equal'; import { Loader2, LucideChevronDown, LucideChevronRight, LucideToyBrick } from 'lucide-react'; -import { memo, useState } from 'react'; +import { CSSProperties, memo, useState } from 'react'; import { useTranslation } from 'react-i18next'; import { Center, Flexbox } from 'react-layout-kit'; import { useChatStore } from '@/store/chat'; -import { chatSelectors } from '@/store/chat/slices/message/selectors'; +import { chatSelectors } from '@/store/chat/selectors'; import { pluginHelpers, useToolStore } from '@/store/tool'; import { toolSelectors } from '@/store/tool/selectors'; @@ -15,15 +15,17 @@ import { useStyles } from './style'; export interface InspectorProps { arguments?: string; identifier: string; + index: number; messageId: string; + style: CSSProperties; } const CallItem = memo( - ({ arguments: requestArgs = '{}', messageId, identifier }) => { + ({ arguments: requestArgs = '{}', messageId, index, identifier, style }) => { const { t } = useTranslation('plugin'); const { styles } = useStyles(); const [open, setOpen] = useState(false); - const loading = useChatStore(chatSelectors.isMessageGenerating(messageId)); + const loading = useChatStore(chatSelectors.isToolCallStreaming(messageId, index)); const pluginMeta = useToolStore(toolSelectors.getMetaById(identifier), isEqual); @@ -32,20 +34,13 @@ const CallItem = memo( const pluginTitle = pluginHelpers.getPluginTitle(pluginMeta) ?? t('unknownPlugin'); const avatar = pluginAvatar ? ( - + ) : ( ); - let params; - try { - params = JSON.stringify(JSON.parse(requestArgs), null, 2); - } catch { - params = requestArgs; - } - return ( - + ( - {(open || loading) && {params}} + {(open || loading) && {requestArgs}} ); }, diff --git a/src/features/Conversation/Messages/Assistant/index.tsx b/src/features/Conversation/Messages/Assistant/index.tsx index 8f6afa3f45599..1e4fdc2c07ecd 100644 --- a/src/features/Conversation/Messages/Assistant/index.tsx +++ b/src/features/Conversation/Messages/Assistant/index.tsx @@ -7,7 +7,7 @@ import { chatSelectors } from '@/store/chat/selectors'; import { ChatMessage } from '@/types/message'; import { DefaultMessage } from '../Default'; -import ToolCalls from './ToolCalls'; +import ToolCall from './ToolCalls'; export const AssistantMessage = memo< ChatMessage & { @@ -32,12 +32,16 @@ export const AssistantMessage = memo< )} {!editing && tools && ( - {tools.map((toolCall) => ( - ( + ))} diff --git a/src/features/Conversation/Messages/Tool/Inspector/index.tsx b/src/features/Conversation/Messages/Tool/Inspector/index.tsx index 9cf85dfcfbd5c..07a90a5721c35 100644 --- a/src/features/Conversation/Messages/Tool/Inspector/index.tsx +++ b/src/features/Conversation/Messages/Tool/Inspector/index.tsx @@ -53,7 +53,7 @@ const Inspector = memo( const pluginTitle = pluginHelpers.getPluginTitle(pluginMeta) ?? t('unknownPlugin'); const avatar = pluginAvatar ? ( - + ) : ( ); diff --git a/src/features/Conversation/Plugins/Render/index.tsx b/src/features/Conversation/Plugins/Render/index.tsx index 6bd8ebdade578..3b837525cf1aa 100644 --- a/src/features/Conversation/Plugins/Render/index.tsx +++ b/src/features/Conversation/Plugins/Render/index.tsx @@ -1,12 +1,21 @@ import { PluginRequestPayload } from '@lobehub/chat-plugin-sdk'; +import { Skeleton } from 'antd'; +import dynamic from 'next/dynamic'; import { memo } from 'react'; import { LobeToolRenderType } from '@/types/tool'; -import BuiltinType from '././BuiltinType'; import DefaultType from './DefaultType'; import Markdown from './MarkdownType'; -import Standalone from './StandaloneType'; + +const loading = () => ( + + {' '} + +); + +const Standalone = dynamic(() => import('./StandaloneType'), { loading }); +const BuiltinType = dynamic(() => import('./BuiltinType'), { loading }); export interface PluginRenderProps { content: string; diff --git a/src/hooks/useTokenCount.test.ts b/src/hooks/useTokenCount.test.ts new file mode 100644 index 0000000000000..08c06348c4eeb --- /dev/null +++ b/src/hooks/useTokenCount.test.ts @@ -0,0 +1,38 @@ +import { act, renderHook, waitFor } from '@testing-library/react'; +import { describe, expect, it, vi } from 'vitest'; + +import * as tokenizers from '@/utils/tokenizer'; + +import { useTokenCount } from './useTokenCount'; + +describe('useTokenCount', () => { + it('should return token count for given input', async () => { + const { result } = renderHook(() => useTokenCount('test input')); + + expect(result.current).toBe(0); + await waitFor(() => expect(result.current).toBe(2)); + }); + + it('should fall back to input length if encodeAsync throws', async () => { + const mockEncodeAsync = vi.spyOn(tokenizers, 'encodeAsync'); + mockEncodeAsync.mockRejectedValueOnce(new Error('encode error')); + + const { result } = renderHook(() => useTokenCount('test input')); + + expect(result.current).toBe(0); + await waitFor(() => expect(result.current).toBe(0)); + }); + + it('should handle empty input', async () => { + const { result } = renderHook(() => useTokenCount('')); + + expect(result.current).toBe(0); + await waitFor(() => expect(result.current).toBe(0)); + }); + it('should handle null input', async () => { + const { result } = renderHook(() => useTokenCount(null as any)); + + expect(result.current).toBe(0); + await waitFor(() => expect(result.current).toBe(0)); + }); +}); diff --git a/src/hooks/useTokenCount.ts b/src/hooks/useTokenCount.ts index eb2d114b61150..52df580a5d91d 100644 --- a/src/hooks/useTokenCount.ts +++ b/src/hooks/useTokenCount.ts @@ -6,9 +6,8 @@ export const useTokenCount = (input: string = '') => { const [value, setNum] = useState(0); useEffect(() => { - if (!input) return; startTransition(() => { - encodeAsync(input) + encodeAsync(input || '') .then(setNum) .catch(() => { // 兜底采用字符数 diff --git a/src/libs/agent-runtime/AgentRuntime.ts b/src/libs/agent-runtime/AgentRuntime.ts index bd457920211eb..599e188a28a66 100644 --- a/src/libs/agent-runtime/AgentRuntime.ts +++ b/src/libs/agent-runtime/AgentRuntime.ts @@ -16,7 +16,12 @@ import { LobeOpenAI } from './openai'; import { LobeOpenRouterAI } from './openrouter'; import { LobePerplexityAI } from './perplexity'; import { LobeTogetherAI } from './togetherai'; -import { ChatCompetitionOptions, ChatStreamPayload, ModelProvider } from './types'; +import { + ChatCompetitionOptions, + ChatStreamPayload, + ModelProvider, + TextToImagePayload, +} from './types'; import { LobeZeroOneAI } from './zeroone'; import { LobeZhipuAI } from './zhipu'; @@ -65,6 +70,9 @@ class AgentRuntime { async chat(payload: ChatStreamPayload, options?: ChatCompetitionOptions) { return this._runtime.chat(payload, options); } + async textToImage(payload: TextToImagePayload) { + return this._runtime.textToImage?.(payload); + } async models() { return this._runtime.models?.(); diff --git a/src/libs/agent-runtime/BaseAI.ts b/src/libs/agent-runtime/BaseAI.ts index 301dc6523ba81..b36faabb86698 100644 --- a/src/libs/agent-runtime/BaseAI.ts +++ b/src/libs/agent-runtime/BaseAI.ts @@ -1,5 +1,6 @@ import OpenAI from 'openai'; +import { TextToImagePayload } from '@/libs/agent-runtime/types/textToImage'; import { ChatModelCard } from '@/types/llm'; import { ChatCompetitionOptions, ChatStreamPayload } from './types'; @@ -9,6 +10,8 @@ export interface LobeRuntimeAI { chat(payload: ChatStreamPayload, options?: ChatCompetitionOptions): Promise; models?(): Promise; + + textToImage?: (payload: TextToImagePayload) => Promise; } export abstract class LobeOpenAICompatibleRuntime { diff --git a/src/libs/agent-runtime/types/index.ts b/src/libs/agent-runtime/types/index.ts index b720567fe1a85..c4eaaea80397a 100644 --- a/src/libs/agent-runtime/types/index.ts +++ b/src/libs/agent-runtime/types/index.ts @@ -1,2 +1,3 @@ export * from './chat'; +export * from './textToImage'; export * from './type'; diff --git a/src/libs/agent-runtime/types/textToImage.ts b/src/libs/agent-runtime/types/textToImage.ts new file mode 100644 index 0000000000000..170c8eff6c23b --- /dev/null +++ b/src/libs/agent-runtime/types/textToImage.ts @@ -0,0 +1,34 @@ +import { DallEImageQuality, DallEImageSize, DallEImageStyle } from '@/types/tool/dalle'; + +export interface TextToImagePayload { + model: string; + /** + * The number of images to generate. Must be between 1 and 10. + */ + n?: number; + /** + * A text description of the desired image(s). + * The maximum length is 1000 characters. + */ + prompt: string; + /** + * The quality of the image that will be generated. + * hd creates images with finer details and greater consistency across the image. + * This param is only supported for dall-e-3. + */ + quality?: DallEImageQuality; + /** + * The size of the generated images. + * Must be one of '1792x1024' , '1024x1024' , '1024x1792' + */ + size?: DallEImageSize; + + /** + * The style of the generated images. Must be one of vivid or natural. + * Vivid causes the model to lean towards generating hyper-real and dramatic images. + * Natural causes the model to produce more natural, less hyper-real looking images. + * This param is only supported for dall-e-3. + * @default vivid + */ + style?: DallEImageStyle; +} diff --git a/src/libs/agent-runtime/utils/createError.ts b/src/libs/agent-runtime/utils/createError.ts index 180f8744d2967..410ea6f1607d8 100644 --- a/src/libs/agent-runtime/utils/createError.ts +++ b/src/libs/agent-runtime/utils/createError.ts @@ -7,4 +7,5 @@ export const AgentRuntimeError = { errorType: ILobeAgentRuntimeErrorType | string | number, error?: any, ): AgentInitErrorPayload => ({ error, errorType }), + textToImage: (error: any): any => error, }; diff --git a/src/libs/agent-runtime/utils/openaiCompatibleFactory/index.ts b/src/libs/agent-runtime/utils/openaiCompatibleFactory/index.ts index ece97dbf98340..97a193c9275ac 100644 --- a/src/libs/agent-runtime/utils/openaiCompatibleFactory/index.ts +++ b/src/libs/agent-runtime/utils/openaiCompatibleFactory/index.ts @@ -1,6 +1,7 @@ import OpenAI, { ClientOptions } from 'openai'; import { LOBE_DEFAULT_MODEL_LIST } from '@/config/modelProviders'; +import { TextToImagePayload } from '@/libs/agent-runtime/types/textToImage'; import { ChatModelCard } from '@/types/llm'; import { LobeRuntimeAI } from '../../BaseAI'; @@ -178,6 +179,56 @@ export const LobeOpenAICompatibleFactory = ({ .filter(Boolean) as ChatModelCard[]; } + async textToImage(payload: TextToImagePayload) { + try { + const res = await this.client.images.generate(payload); + return res.data.map((o) => o.url) as string[]; + } catch (error) { + let desensitizedEndpoint = this.baseURL; + + // refs: https://github.com/lobehub/lobe-chat/issues/842 + if (this.baseURL !== DEFAULT_BASE_URL) { + desensitizedEndpoint = desensitizeUrl(this.baseURL); + } + + if ('status' in (error as any)) { + switch ((error as Response).status) { + case 401: { + throw AgentRuntimeError.chat({ + endpoint: desensitizedEndpoint, + error: error as any, + errorType: ErrorType.invalidAPIKey, + provider: provider as any, + }); + } + + default: { + break; + } + } + } + + if (chatCompletion?.handleError) { + const errorResult = chatCompletion.handleError(error); + + if (errorResult) + throw AgentRuntimeError.chat({ + ...errorResult, + provider, + } as ChatCompletionErrorPayload); + } + + const { errorResult, RuntimeError } = handleOpenAIError(error); + + throw AgentRuntimeError.chat({ + endpoint: desensitizedEndpoint, + error: errorResult, + errorType: RuntimeError || ErrorType.bizError, + provider: provider as any, + }); + } + } + /** * make the OpenAI response data as a stream * @private diff --git a/src/locales/default/tool.ts b/src/locales/default/tool.ts index f4708cd48ac9f..5553abf2764fa 100644 --- a/src/locales/default/tool.ts +++ b/src/locales/default/tool.ts @@ -5,5 +5,6 @@ export default { generate: '生成', generating: '生成中...', images: '图片:', + prompt: '提示词', }, }; diff --git a/src/services/_url.ts b/src/services/_url.ts index d76249f4adccf..1a2cbad130c0b 100644 --- a/src/services/_url.ts +++ b/src/services/_url.ts @@ -34,7 +34,7 @@ export const API_ENDPOINTS = mapWithBasePath({ trace: '/api/trace', // image - images: '/api/openai/images', + images: '/api/text-to-image/openai', // TTS & STT stt: '/api/openai/stt', diff --git a/src/services/imageGeneration.ts b/src/services/textToImage.ts similarity index 66% rename from src/services/imageGeneration.ts rename to src/services/textToImage.ts index c33f0fb0fd4dd..f60eb9afc4be0 100644 --- a/src/services/imageGeneration.ts +++ b/src/services/textToImage.ts @@ -1,4 +1,5 @@ -import { createHeaderWithOpenAI } from '@/services/_header'; +import { ModelProvider } from '@/libs/agent-runtime'; +import { createHeaderWithAuth } from '@/services/_auth'; import { OpenAIImagePayload } from '@/types/openai/image'; import { API_ENDPOINTS } from './_url'; @@ -11,12 +12,20 @@ class ImageGenerationService { async generateImage(params: Omit, options?: FetchOptions) { const payload: OpenAIImagePayload = { ...params, model: 'dall-e-3', n: 1 }; + const headers = await createHeaderWithAuth({ + headers: { 'Content-Type': 'application/json' }, + provider: ModelProvider.OpenAI, + }); + const res = await fetch(API_ENDPOINTS.images, { body: JSON.stringify(payload), - headers: createHeaderWithOpenAI({ 'Content-Type': 'application/json' }), + headers: headers, method: 'POST', signal: options?.signal, }); + if (!res.ok) { + throw await res.json(); + } const urls = await res.json(); diff --git a/src/store/chat/initialState.ts b/src/store/chat/initialState.ts index 8546b45af2480..a6ac4b1c5a005 100644 --- a/src/store/chat/initialState.ts +++ b/src/store/chat/initialState.ts @@ -1,6 +1,6 @@ +import { ChatToolState, initialToolState } from './slices/builtinTool/initialState'; import { ChatMessageState, initialMessageState } from './slices/message/initialState'; import { ChatShareState, initialShareState } from './slices/share/initialState'; -import { ChatToolState, initialToolState } from './slices/tool/initialState'; import { ChatTopicState, initialTopicState } from './slices/topic/initialState'; export type ChatStoreState = ChatTopicState & ChatMessageState & ChatToolState & ChatShareState; diff --git a/src/store/chat/selectors.ts b/src/store/chat/selectors.ts index 59344d71e9aa4..ce6864e7a349e 100644 --- a/src/store/chat/selectors.ts +++ b/src/store/chat/selectors.ts @@ -1,3 +1,3 @@ +export { chatToolSelectors } from './slices/builtinTool/selectors'; export { chatSelectors } from './slices/message/selectors'; -export { chatToolSelectors } from './slices/tool/selectors'; export { topicSelectors } from './slices/topic/selectors'; diff --git a/src/store/chat/slices/tool/action.test.ts b/src/store/chat/slices/builtinTool/action.test.ts similarity index 98% rename from src/store/chat/slices/tool/action.test.ts rename to src/store/chat/slices/builtinTool/action.test.ts index a145e76a83ebd..530662d75e37f 100644 --- a/src/store/chat/slices/tool/action.test.ts +++ b/src/store/chat/slices/builtinTool/action.test.ts @@ -2,7 +2,7 @@ import { act, renderHook } from '@testing-library/react'; import { describe, expect, it, vi } from 'vitest'; import { fileService } from '@/services/file'; -import { imageGenerationService } from '@/services/imageGeneration'; +import { imageGenerationService } from '@/services/textToImage'; import { chatSelectors } from '@/store/chat/selectors'; import { ChatMessage } from '@/types/message'; import { DallEImageItem } from '@/types/tool/dalle'; diff --git a/src/store/chat/slices/tool/action.ts b/src/store/chat/slices/builtinTool/action.ts similarity index 84% rename from src/store/chat/slices/tool/action.ts rename to src/store/chat/slices/builtinTool/action.ts index 5ea2547a225bf..9db4f6d1c37ef 100644 --- a/src/store/chat/slices/tool/action.ts +++ b/src/store/chat/slices/builtinTool/action.ts @@ -3,7 +3,7 @@ import pMap from 'p-map'; import { StateCreator } from 'zustand/vanilla'; import { fileService } from '@/services/file'; -import { imageGenerationService } from '@/services/imageGeneration'; +import { imageGenerationService } from '@/services/textToImage'; import { chatSelectors } from '@/store/chat/selectors'; import { ChatStore } from '@/store/chat/store'; import { DallEImageItem } from '@/types/tool/dalle'; @@ -14,7 +14,7 @@ const n = setNamespace('tool'); /** * builtin tool action */ -export interface ChatToolAction { +export interface ChatBuiltinToolAction { generateImageFromPrompts: (items: DallEImageItem[], id: string) => Promise; text2image: (id: string, data: DallEImageItem[]) => Promise; toggleDallEImageLoading: (key: string, value: boolean) => void; @@ -25,7 +25,7 @@ export const chatToolSlice: StateCreator< ChatStore, [['zustand/devtools', never]], [], - ChatToolAction + ChatBuiltinToolAction > = (set, get) => ({ generateImageFromPrompts: async (items, messageId) => { const { toggleDallEImageLoading, updateImageItem } = get(); @@ -37,10 +37,22 @@ export const chatToolSlice: StateCreator< const parent = getMessageById(message!.parentId!); const originPrompt = parent?.content; + let errorArray: any[] = []; await pMap(items, async (params, index) => { toggleDallEImageLoading(messageId + params.prompt, true); - const url = await imageGenerationService.generateImage(params); + + let url = ''; + try { + url = await imageGenerationService.generateImage(params); + } catch (e) { + toggleDallEImageLoading(messageId + params.prompt, false); + errorArray[index] = e; + + await get().updatePluginState(messageId, `error`, errorArray); + } + + if (!url) return; await updateImageItem(messageId, (draft) => { draft[index].previewUrl = url; diff --git a/src/store/chat/slices/tool/initialState.ts b/src/store/chat/slices/builtinTool/initialState.ts similarity index 100% rename from src/store/chat/slices/tool/initialState.ts rename to src/store/chat/slices/builtinTool/initialState.ts diff --git a/src/store/chat/slices/tool/selectors.ts b/src/store/chat/slices/builtinTool/selectors.ts similarity index 100% rename from src/store/chat/slices/tool/selectors.ts rename to src/store/chat/slices/builtinTool/selectors.ts diff --git a/src/store/chat/slices/enchance/action.ts b/src/store/chat/slices/enchance/action.ts index b28f0a69654e5..54942a0822f5d 100644 --- a/src/store/chat/slices/enchance/action.ts +++ b/src/store/chat/slices/enchance/action.ts @@ -63,19 +63,22 @@ export const chatEnhance: StateCreator< let from = ''; // detect from language - chatService - .fetchPresetTaskResult({ - params: chainLangDetect(message.content), - trace: get().getCurrentTracePayload({ traceName: TraceNameMap.LanguageDetect }), - }) - .then(async (data) => { + chatService.fetchPresetTaskResult({ + onFinish: async (data) => { if (data && supportLocales.includes(data)) from = data; await updateMessageTranslate(id, { content, from, to: targetLang }); - }); + }, + params: chainLangDetect(message.content), + trace: get().getCurrentTracePayload({ traceName: TraceNameMap.LanguageDetect }), + }); // translate to target language await chatService.fetchPresetTaskResult({ + onFinish: async (content) => { + await updateMessageTranslate(id, { content, from, to: targetLang }); + internal_toggleChatLoading(false, id); + }, onMessageHandle: (chunk) => { switch (chunk.type) { case 'text': { @@ -95,10 +98,6 @@ export const chatEnhance: StateCreator< params: chainTranslate(message.content, targetLang), trace: get().getCurrentTracePayload({ traceName: TraceNameMap.Translator }), }); - - await updateMessageTranslate(id, { content, from, to: targetLang }); - - internal_toggleChatLoading(false); }, ttsMessage: async (id, state = {}) => { diff --git a/src/store/chat/slices/message/action.ts b/src/store/chat/slices/message/action.ts index 54c5719557f5e..8f499d72819af 100644 --- a/src/store/chat/slices/message/action.ts +++ b/src/store/chat/slices/message/action.ts @@ -73,6 +73,7 @@ export interface ChatMessageAction { copyMessage: (id: string, content: string) => Promise; refreshMessages: () => Promise; toggleMessageEditing: (id: string, editing: boolean) => void; + // ========= ↓ Internal Method ↓ ========== // // ========================================== // // ========================================== // @@ -81,6 +82,7 @@ export interface ChatMessageAction { id?: string, action?: string, ) => AbortController | undefined; + internal_toggleToolCallingStreaming: (id: string, streaming: boolean[] | undefined) => void; internal_toggleMessageLoading: (loading: boolean, id: string) => void; /** * update message at the frontend point @@ -108,13 +110,7 @@ export interface ChatMessageAction { isFunctionCall: boolean; traceId?: string; }>; - // TODO: 后续 smoothMessage 实现考虑落到 sse 这一层 - internal_createSmoothMessage: (id: string) => { - startAnimation: (speed?: number) => Promise; - stopAnimation: () => void; - outputQueue: string[]; - isAnimationActive: boolean; - }; + /** * a method used by other action * @param id @@ -374,7 +370,7 @@ export const chatMessage: StateCreator< refreshMessages, internal_updateMessageContent, internal_dispatchMessage, - internal_createSmoothMessage, + internal_toggleToolCallingStreaming, } = get(); const abortController = internal_toggleChatLoading( @@ -431,9 +427,7 @@ export const chatMessage: StateCreator< let isFunctionCall = false; let msgTraceId: string | undefined; - - const { startAnimation, stopAnimation, outputQueue, isAnimationActive } = - internal_createSmoothMessage(assistantId); + let output = ''; await chatService.createAssistantMessageStream({ abortController, @@ -455,11 +449,7 @@ export const chatMessage: StateCreator< await messageService.updateMessageError(assistantId, error); await refreshMessages(); }, - onAbort: async () => { - stopAnimation(); - }, onFinish: async (content, { traceId, observationId, toolCalls }) => { - stopAnimation(); // if there is traceId, update it if (traceId) { msgTraceId = traceId; @@ -469,11 +459,8 @@ export const chatMessage: StateCreator< }); } - // if there is still content not displayed, - // and the message is not a function call - // then continue the animation - if (outputQueue.length > 0 && !isFunctionCall) { - await startAnimation(15); + if (toolCalls && toolCalls.length > 0) { + internal_toggleToolCallingStreaming(assistantId, undefined); } // update the content after fetch result @@ -482,12 +469,18 @@ export const chatMessage: StateCreator< onMessageHandle: async (chunk) => { switch (chunk.type) { case 'text': { - outputQueue.push(...chunk.text.split('')); + output += chunk.text; + internal_dispatchMessage({ + id: assistantId, + type: 'updateMessages', + value: { content: output }, + }); break; } // is this message is just a tool call case 'tool_calls': { + internal_toggleToolCallingStreaming(assistantId, chunk.isAnimationActives); internal_dispatchMessage({ id: assistantId, type: 'updateMessages', @@ -496,15 +489,10 @@ export const chatMessage: StateCreator< isFunctionCall = true; } } - - // if it's the first time to receive the message, - // and the message is not a function call - // then start the animation - if (!isAnimationActive && !isFunctionCall) startAnimation(); }, }); - internal_toggleChatLoading(false, undefined, n('generateMessage(end)') as string); + internal_toggleChatLoading(false, assistantId, n('generateMessage(end)') as string); return { isFunctionCall, @@ -551,7 +539,22 @@ export const chatMessage: StateCreator< 'internal_toggleMessageLoading', ); }, + internal_toggleToolCallingStreaming: (id, streaming) => { + set( + { + toolCallingStreamIds: produce(get().toolCallingStreamIds, (draft) => { + if (!!streaming) { + draft[id] = streaming; + } else { + delete draft[id]; + } + }), + }, + false, + 'toggleToolCallingStreaming', + ); + }, internal_resendMessage: async (messageId, traceId) => { // 1. 构造所有相关的历史记录 const chats = chatSelectors.currentChats(get()); @@ -629,71 +632,6 @@ export const chatMessage: StateCreator< return id; }, - internal_createSmoothMessage: (id) => { - const { internal_dispatchMessage } = get(); - - let buffer = ''; - // why use queue: https://shareg.pt/GLBrjpK - let outputQueue: string[] = []; - - // eslint-disable-next-line no-undef - let animationTimeoutId: NodeJS.Timeout | null = null; - let isAnimationActive = false; - - // when you need to stop the animation, call this function - const stopAnimation = () => { - isAnimationActive = false; - if (animationTimeoutId !== null) { - clearTimeout(animationTimeoutId); - animationTimeoutId = null; - } - }; - - // define startAnimation function to display the text in buffer smooth - // when you need to start the animation, call this function - const startAnimation = (speed = 2) => - new Promise((resolve) => { - if (isAnimationActive) { - resolve(); - return; - } - - isAnimationActive = true; - - const updateText = () => { - // 如果动画已经不再激活,则停止更新文本 - if (!isAnimationActive) { - clearTimeout(animationTimeoutId!); - animationTimeoutId = null; - resolve(); - } - - // 如果还有文本没有显示 - // 检查队列中是否有字符待显示 - if (outputQueue.length > 0) { - // 从队列中获取前两个字符(如果存在) - const charsToAdd = outputQueue.splice(0, speed).join(''); - buffer += charsToAdd; - - // 更新消息内容,这里可能需要结合实际情况调整 - internal_dispatchMessage({ id, type: 'updateMessages', value: { content: buffer } }); - - // 设置下一个字符的延迟 - animationTimeoutId = setTimeout(updateText, 16); // 16 毫秒的延迟模拟打字机效果 - } else { - // 当所有字符都显示完毕时,清除动画状态 - isAnimationActive = false; - animationTimeoutId = null; - resolve(); - } - }; - - updateText(); - }); - - return { startAnimation, stopAnimation, outputQueue, isAnimationActive }; - }, - internal_traceMessage: async (id, payload) => { // tracing the diff of update const message = chatSelectors.getMessageById(id)(get()); diff --git a/src/store/chat/slices/message/initialState.ts b/src/store/chat/slices/message/initialState.ts index bc28874b08c01..3a4688ceeac64 100644 --- a/src/store/chat/slices/message/initialState.ts +++ b/src/store/chat/slices/message/initialState.ts @@ -25,6 +25,10 @@ export interface ChatMessageState { * whether messages have fetched */ messagesInit: boolean; + /** + * the tool calling stream ids + */ + toolCallingStreamIds: Record; } export const initialMessageState: ChatMessageState = { @@ -35,4 +39,5 @@ export const initialMessageState: ChatMessageState = { messageLoadingIds: [], messages: [], messagesInit: false, + toolCallingStreamIds: {}, }; diff --git a/src/store/chat/slices/message/selectors.ts b/src/store/chat/slices/message/selectors.ts index 44b829efedddb..ccb0efff47d23 100644 --- a/src/store/chat/slices/message/selectors.ts +++ b/src/store/chat/slices/message/selectors.ts @@ -117,6 +117,13 @@ const currentChatLoadingState = (s: ChatStore) => !s.messagesInit; const isMessageEditing = (id: string) => (s: ChatStore) => s.messageEditingIds.includes(id); const isMessageLoading = (id: string) => (s: ChatStore) => s.messageLoadingIds.includes(id); const isMessageGenerating = (id: string) => (s: ChatStore) => s.chatLoadingIds.includes(id); +const isToolCallStreaming = (id: string, index: number) => (s: ChatStore) => { + const isLoading = s.toolCallingStreamIds[id]; + + if (!isLoading) return false; + + return isLoading[index]; +}; const isAIGenerating = (s: ChatStore) => s.chatLoadingIds.length > 0; export const chatSelectors = { @@ -133,6 +140,7 @@ export const chatSelectors = { isMessageEditing, isMessageGenerating, isMessageLoading, + isToolCallStreaming, latestMessage, showInboxWelcome, }; diff --git a/src/store/chat/slices/plugin/action.test.ts b/src/store/chat/slices/plugin/action.test.ts index 199c809091275..768185a7a79e0 100644 --- a/src/store/chat/slices/plugin/action.test.ts +++ b/src/store/chat/slices/plugin/action.test.ts @@ -602,7 +602,7 @@ describe('ChatPluginAction', () => { const runPluginApiMock = vi.fn(); act(() => { - useChatStore.setState({ runPluginApi: runPluginApiMock }); + useChatStore.setState({ internal_callPluginApi: runPluginApiMock }); }); const { result } = renderHook(() => useChatStore()); diff --git a/src/store/chat/slices/plugin/action.ts b/src/store/chat/slices/plugin/action.ts index 6b5a57a879f2c..1c8608ad1a454 100644 --- a/src/store/chat/slices/plugin/action.ts +++ b/src/store/chat/slices/plugin/action.ts @@ -24,15 +24,22 @@ export interface ChatPluginAction { content: string, triggerAiMessage?: boolean, ) => Promise; + + internal_callPluginApi: (id: string, payload: ChatToolPayload) => Promise; + internal_invokeDifferentTypePlugin: (id: string, payload: ChatToolPayload) => Promise; internal_transformToolCalls: (toolCalls: MessageToolCall[]) => ChatToolPayload[]; + internal_updatePluginError: (id: string, error: any) => Promise; + invokeBuiltinTool: (id: string, payload: ChatToolPayload) => Promise; invokeDefaultTypePlugin: (id: string, payload: any) => Promise; invokeMarkdownTypePlugin: (id: string, payload: ChatToolPayload) => Promise; + invokeStandaloneTypePlugin: (id: string, payload: ChatToolPayload) => Promise; - runPluginApi: (id: string, payload: ChatToolPayload) => Promise; + + reInvokeToolMessage: (id: string) => Promise; triggerAIMessage: (params: { parentId?: string; traceId?: string }) => Promise; - triggerToolCalls: (id: string) => Promise; + triggerToolCalls: (id: string) => Promise; updatePluginState: (id: string, key: string, value: any) => Promise; } @@ -62,6 +69,70 @@ export const chatPlugin: StateCreator< if (triggerAiMessage) await triggerAIMessage({ parentId: id }); }, + internal_callPluginApi: async (id, payload) => { + const { internal_updateMessageContent, refreshMessages, internal_toggleChatLoading } = get(); + let data: string; + + try { + const abortController = internal_toggleChatLoading( + true, + id, + n('fetchPlugin/start') as string, + ); + + const message = chatSelectors.getMessageById(id)(get()); + + const res = await chatService.runPluginApi(payload, { + signal: abortController?.signal, + trace: { observationId: message?.observationId, traceId: message?.traceId }, + }); + data = res.text; + + // save traceId + if (res.traceId) { + await messageService.updateMessage(id, { traceId: res.traceId }); + } + } catch (error) { + console.log(error); + const err = error as Error; + + // ignore the aborted request error + if (!err.message.includes('The user aborted a request.')) { + await messageService.updateMessageError(id, error as any); + await refreshMessages(); + } + + data = ''; + } + + internal_toggleChatLoading(false, id, n('fetchPlugin/end') as string); + // 如果报错则结束了 + if (!data) return; + + await internal_updateMessageContent(id, data); + + return data; + }, + + internal_invokeDifferentTypePlugin: async (id, payload) => { + switch (payload.type) { + case 'standalone': { + return await get().invokeStandaloneTypePlugin(id, payload); + } + + case 'markdown': { + return await get().invokeMarkdownTypePlugin(id, payload); + } + + case 'builtin': { + return await get().invokeBuiltinTool(id, payload); + } + + default: { + return await get().invokeDefaultTypePlugin(id, payload); + } + } + }, internal_transformToolCalls: (toolCalls) => { return toolCalls @@ -98,6 +169,13 @@ export const chatPlugin: StateCreator< .filter(Boolean) as ChatToolPayload[]; }, + internal_updatePluginError: async (id, error) => { + const { refreshMessages } = get(); + + await messageService.updateMessage(id, { pluginError: error }); + await refreshMessages(); + }, + invokeBuiltinTool: async (id, payload) => { const { internal_toggleChatLoading, internal_updateMessageContent } = get(); const params = JSON.parse(payload.arguments); @@ -131,9 +209,9 @@ export const chatPlugin: StateCreator< }, invokeDefaultTypePlugin: async (id, payload) => { - const { runPluginApi } = get(); + const { internal_callPluginApi } = get(); - const data = await runPluginApi(id, payload); + const data = await internal_callPluginApi(id, payload); if (!data) return; @@ -141,9 +219,9 @@ export const chatPlugin: StateCreator< }, invokeMarkdownTypePlugin: async (id, payload) => { - const { runPluginApi } = get(); + const { internal_callPluginApi } = get(); - await runPluginApi(id, payload); + await internal_callPluginApi(id, payload); }, invokeStandaloneTypePlugin: async (id, payload) => { @@ -166,49 +244,13 @@ export const chatPlugin: StateCreator< } }, - runPluginApi: async (id, payload) => { - const { internal_updateMessageContent, refreshMessages, internal_toggleChatLoading } = get(); - let data: string; + reInvokeToolMessage: async (id) => { + const message = chatSelectors.getMessageById(id)(get()); + if (!message || message.role !== 'tool' || !message.plugin) return; - try { - const abortController = internal_toggleChatLoading( - true, - id, - n('fetchPlugin/start') as string, - ); + const payload: ChatToolPayload = { ...message.plugin, id: message.tool_call_id! }; - const message = chatSelectors.getMessageById(id)(get()); - - const res = await chatService.runPluginApi(payload, { - signal: abortController?.signal, - trace: { observationId: message?.observationId, traceId: message?.traceId }, - }); - data = res.text; - - // save traceId - if (res.traceId) { - await messageService.updateMessage(id, { traceId: res.traceId }); - } - } catch (error) { - console.log(error); - const err = error as Error; - - // ignore the aborted request error - if (!err.message.includes('The user aborted a request.')) { - await messageService.updateMessageError(id, error as any); - await refreshMessages(); - } - - data = ''; - } - - internal_toggleChatLoading(false, id, n('fetchPlugin/end') as string); - // 如果报错则结束了 - if (!data) return; - - await internal_updateMessageContent(id, data); - - return data; + await get().internal_invokeDifferentTypePlugin(id, payload); }, triggerAIMessage: async ({ parentId, traceId }) => { @@ -216,19 +258,10 @@ export const chatPlugin: StateCreator< const chats = chatSelectors.currentChats(get()); await internal_coreProcessMessage(chats, parentId ?? chats.at(-1)!.id, { traceId }); }, - triggerToolCalls: async (assistantId) => { const message = chatSelectors.getMessageById(assistantId)(get()); if (!message || !message.tools) return; - const { - invokeDefaultTypePlugin, - invokeMarkdownTypePlugin, - invokeStandaloneTypePlugin, - invokeBuiltinTool, - triggerAIMessage, - } = get(); - let shouldCreateMessage = false; let latestToolId = ''; const messagePools = message.tools.map(async (payload) => { @@ -244,29 +277,12 @@ export const chatPlugin: StateCreator< const id = await get().internal_createMessage(toolMessage); - switch (payload.type) { - case 'standalone': { - await invokeStandaloneTypePlugin(id, payload); - break; - } - - case 'markdown': { - await invokeMarkdownTypePlugin(id, payload); - break; - } - - case 'builtin': { - await invokeBuiltinTool(id, payload); - break; - } + // trigger the plugin call + const data = await get().internal_invokeDifferentTypePlugin(id, payload); - default: { - const data = await invokeDefaultTypePlugin(id, payload); - if (data) { - shouldCreateMessage = true; - latestToolId = id; - } - } + if (payload.type === 'default' && data) { + shouldCreateMessage = true; + latestToolId = id; } }); @@ -277,9 +293,8 @@ export const chatPlugin: StateCreator< const traceId = chatSelectors.getTraceIdByMessageId(latestToolId)(get()); - await triggerAIMessage({ traceId }); + await get().triggerAIMessage({ traceId }); }, - updatePluginState: async (id, key, value) => { const { refreshMessages } = get(); diff --git a/src/store/chat/store.ts b/src/store/chat/store.ts index 7f3fb1a42c346..552be7546b2b0 100644 --- a/src/store/chat/store.ts +++ b/src/store/chat/store.ts @@ -6,11 +6,11 @@ import { StateCreator } from 'zustand/vanilla'; import { isDev } from '@/utils/env'; import { ChatStoreState, initialState } from './initialState'; +import { ChatBuiltinToolAction, chatToolSlice } from './slices/builtinTool/action'; import { ChatEnhanceAction, chatEnhance } from './slices/enchance/action'; import { ChatMessageAction, chatMessage } from './slices/message/action'; import { ChatPluginAction, chatPlugin } from './slices/plugin/action'; import { ShareAction, chatShare } from './slices/share/action'; -import { ChatToolAction, chatToolSlice } from './slices/tool/action'; import { ChatTopicAction, chatTopic } from './slices/topic/action'; export interface ChatStoreAction @@ -19,7 +19,7 @@ export interface ChatStoreAction ShareAction, ChatEnhanceAction, ChatPluginAction, - ChatToolAction {} + ChatBuiltinToolAction {} export type ChatStore = ChatStoreAction & ChatStoreState; diff --git a/src/store/tool/slices/store/action.test.ts b/src/store/tool/slices/store/action.test.ts index c2f0d5fe3c4c0..421133992ca83 100644 --- a/src/store/tool/slices/store/action.test.ts +++ b/src/store/tool/slices/store/action.test.ts @@ -156,7 +156,9 @@ describe('useToolStore:pluginStore', () => { const { result } = renderHook(() => useToolStore.getState().useFetchPluginStore()); // Then - expect(useSWR).toHaveBeenCalledWith('loadPluginStore', expect.any(Function)); + expect(useSWR).toHaveBeenCalledWith('loadPluginStore', expect.any(Function), { + revalidateOnFocus: false, + }); expect(result.current.data).toEqual(pluginListMock); expect(result.current.error).toBeNull(); expect(result.current.isValidating).toBe(false); @@ -175,7 +177,9 @@ describe('useToolStore:pluginStore', () => { const { result } = renderHook(() => useToolStore.getState().useFetchPluginStore()); // Then - expect(useSWR).toHaveBeenCalledWith('loadPluginStore', expect.any(Function)); + expect(useSWR).toHaveBeenCalledWith('loadPluginStore', expect.any(Function), { + revalidateOnFocus: false, + }); expect(result.current.data).toBeNull(); expect(result.current.error).toEqual(error); expect(result.current.isValidating).toBe(false); diff --git a/src/store/tool/slices/store/action.ts b/src/store/tool/slices/store/action.ts index fe771664deb4b..da2df5bf40e59 100644 --- a/src/store/tool/slices/store/action.ts +++ b/src/store/tool/slices/store/action.ts @@ -101,5 +101,7 @@ export const createPluginStoreSlice: StateCreator< revalidateOnFocus: false, }), useFetchPluginStore: () => - useSWR('loadPluginStore', get().loadPluginStore), + useSWR('loadPluginStore', get().loadPluginStore, { + revalidateOnFocus: false, + }), }); diff --git a/src/tools/dalle/Render/EditMode.tsx b/src/tools/dalle/Render/Item/EditMode.tsx similarity index 100% rename from src/tools/dalle/Render/EditMode.tsx rename to src/tools/dalle/Render/Item/EditMode.tsx diff --git a/src/tools/dalle/Render/Item/Error.tsx b/src/tools/dalle/Render/Item/Error.tsx new file mode 100644 index 0000000000000..6186062dcefcb --- /dev/null +++ b/src/tools/dalle/Render/Item/Error.tsx @@ -0,0 +1,50 @@ +import { Alert, Highlighter, Icon } from '@lobehub/ui'; +import { Button } from 'antd'; +import { LucideRefreshCw } from 'lucide-react'; +import { memo } from 'react'; +import { useTranslation } from 'react-i18next'; +import { Flexbox } from 'react-layout-kit'; + +import { useChatStore } from '@/store/chat'; +import { chatSelectors } from '@/store/chat/selectors'; + +interface ErrorProps { + index: number; + messageId: string; +} + +const Error = memo(({ messageId, index }) => { + const { t } = useTranslation('error'); + const { t: ct } = useTranslation('common'); + + const error = useChatStore( + (s) => chatSelectors.getMessageById(messageId)(s)?.pluginState['error']?.[index], + ); + const [reInvokeToolMessage] = useChatStore((s) => [s.reInvokeToolMessage]); + + return ( + error && ( + + + {JSON.stringify(error.body, null, 2)} + + } + extraDefaultExpand + message={t(`response.${error.errorType}` as any)} + type={'error'} + /> + + + ) + ); +}); + +export default Error; diff --git a/src/tools/dalle/Render/Item/Image.tsx b/src/tools/dalle/Render/Item/Image.tsx new file mode 100644 index 0000000000000..6eda89fb40730 --- /dev/null +++ b/src/tools/dalle/Render/Item/Image.tsx @@ -0,0 +1,44 @@ +import { Icon, Image, Tooltip } from '@lobehub/ui'; +import { Loader2 } from 'lucide-react'; +import { memo } from 'react'; +import { useTranslation } from 'react-i18next'; +import { Flexbox } from 'react-layout-kit'; + +import ImageFileItem from '@/components/FileList/ImageFileItem'; + +interface ImagePreviewProps { + imageId?: string; + previewUrl?: string; + prompt: string; +} + +const ImagePreview = memo(({ imageId, previewUrl, prompt }) => { + const { t } = useTranslation('tool'); + + return imageId ? ( + // + // { + // if (e.key === 'edit') { + // setEdit(true); + // } + // }} + // /> + // + + ) : ( + previewUrl && ( + +
+ + + +
+ {prompt} +
+ ) + ); +}); + +export default ImagePreview; diff --git a/src/tools/dalle/Render/Item.tsx b/src/tools/dalle/Render/Item/index.tsx similarity index 60% rename from src/tools/dalle/Render/Item.tsx rename to src/tools/dalle/Render/Item/index.tsx index 5bae1f88b6478..0e472e651329b 100644 --- a/src/tools/dalle/Render/Item.tsx +++ b/src/tools/dalle/Render/Item/index.tsx @@ -1,4 +1,4 @@ -import { Icon, Image, Tooltip } from '@lobehub/ui'; +import { Highlighter, Icon } from '@lobehub/ui'; import { Spin } from 'antd'; import { createStyles } from 'antd-style'; import { Loader2 } from 'lucide-react'; @@ -6,12 +6,13 @@ import { memo, useState } from 'react'; import { useTranslation } from 'react-i18next'; import { Flexbox } from 'react-layout-kit'; -import ImageFileItem from '@/components/FileList/ImageFileItem'; import { useChatStore } from '@/store/chat'; import { chatToolSelectors } from '@/store/chat/selectors'; import { DallEImageItem } from '@/types/tool/dalle'; import EditMode from './EditMode'; +import Error from './Error'; +import ImagePreview from './Image'; const useStyles = createStyles(({ css, token, prefixCls }) => ({ action: css` @@ -32,8 +33,8 @@ const useStyles = createStyles(({ css, token, prefixCls }) => ({ `, })); -const ImageItem = memo( - ({ prompt, messageId, imageId, previewUrl, style, size, quality }) => { +const ImageItem = memo( + ({ prompt, messageId, imageId, previewUrl, index, style, size, quality }) => { const { t } = useTranslation('tool'); const { styles } = useStyles(); @@ -55,30 +56,7 @@ const ImageItem = memo( ); if (imageId || previewUrl) - return imageId ? ( - // - // { - // if (e.key === 'edit') { - // setEdit(true); - // } - // }} - // /> - // - - ) : ( - previewUrl && ( - -
- - - -
- {prompt} -
- ) - ); + return ; return ( @@ -87,7 +65,20 @@ const ImageItem = memo( {prompt} ) : ( - prompt + + + + {prompt} + + + + )} ); diff --git a/src/utils/fetch.test.ts b/src/utils/fetch.test.ts index e70151340d324..e65059e68b1db 100644 --- a/src/utils/fetch.test.ts +++ b/src/utils/fetch.test.ts @@ -1,6 +1,7 @@ import { fetchEventSource } from '@microsoft/fetch-event-source'; import { FetchEventSourceInit } from '@microsoft/fetch-event-source'; import { afterEach, describe, expect, it, vi } from 'vitest'; +import { ZodError } from 'zod'; import { ErrorResponse } from '@/types/fetch'; @@ -82,6 +83,16 @@ describe('getMessageError', () => { }); expect(mockResponse.json).toHaveBeenCalled(); }); + + it('should handle timeout error correctly', async () => { + const mockResponse = createMockResponse(undefined, false, 504); + const error = await getMessageError(mockResponse as any); + + expect(error).toEqual({ + message: 'translated_response.504', + type: 504, + }); + }); }); describe('parseToolCalls', () => { @@ -173,6 +184,34 @@ describe('parseToolCalls', () => { }, ]); }); + + it('should throw error if incomplete tool calls data', () => { + const origin = [ + { + id: '1', + type: 'function', + function: { name: 'func', arguments: '{"location\\": \\"Hangzhou\\"}' }, + }, + ]; + + const chunk = [{ index: 1, id: '2', type: 'function' }]; + + try { + parseToolCalls(origin, chunk as any); + } catch (e) { + expect(e).toEqual( + new ZodError([ + { + code: 'invalid_type', + expected: 'object', + received: 'undefined', + path: ['function'], + message: 'Required', + }, + ]), + ); + } + }); }); describe('fetchSSE', () => { @@ -188,7 +227,11 @@ describe('fetchSSE', () => { }, ); - await fetchSSE('/', { onMessageHandle: mockOnMessageHandle, onFinish: mockOnFinish }); + await fetchSSE('/', { + onMessageHandle: mockOnMessageHandle, + onFinish: mockOnFinish, + smoothing: false, + }); expect(mockOnMessageHandle).toHaveBeenNthCalledWith(1, { text: 'Hello', type: 'text' }); expect(mockOnMessageHandle).toHaveBeenNthCalledWith(2, { text: ' World', type: 'text' }); @@ -222,7 +265,11 @@ describe('fetchSSE', () => { }, ); - await fetchSSE('/', { onMessageHandle: mockOnMessageHandle, onFinish: mockOnFinish }); + await fetchSSE('/', { + onMessageHandle: mockOnMessageHandle, + onFinish: mockOnFinish, + smoothing: false, + }); expect(mockOnMessageHandle).toHaveBeenNthCalledWith(1, { tool_calls: [{ id: '1', type: 'function', function: { name: 'func1', arguments: 'arg1' } }], @@ -256,7 +303,7 @@ describe('fetchSSE', () => { }, ); - await fetchSSE('/', { onAbort: mockOnAbort }); + await fetchSSE('/', { onAbort: mockOnAbort, smoothing: false }); expect(mockOnAbort).toHaveBeenCalledWith('Hello'); }); @@ -320,4 +367,162 @@ describe('fetchSSE', () => { type: 'done', }); }); + + it('should handle text event with smoothing correctly', async () => { + const mockOnMessageHandle = vi.fn(); + const mockOnFinish = vi.fn(); + + (fetchEventSource as any).mockImplementationOnce( + (url: string, options: FetchEventSourceInit) => { + options.onopen!({ clone: () => ({ ok: true, headers: new Headers() }) } as any); + options.onmessage!({ event: 'text', data: JSON.stringify('Hello') } as any); + options.onmessage!({ event: 'text', data: JSON.stringify(' World') } as any); + }, + ); + + await fetchSSE('/', { + onMessageHandle: mockOnMessageHandle, + onFinish: mockOnFinish, + }); + + expect(mockOnMessageHandle).toHaveBeenNthCalledWith(1, { text: 'He', type: 'text' }); + expect(mockOnMessageHandle).toHaveBeenNthCalledWith(2, { text: 'llo World', type: 'text' }); + // more assertions for each character... + expect(mockOnFinish).toHaveBeenCalledWith('Hello World', { + observationId: null, + toolCalls: undefined, + traceId: null, + type: 'done', + }); + }); + + it('should handle tool_calls event with smoothing correctly', async () => { + const mockOnMessageHandle = vi.fn(); + const mockOnFinish = vi.fn(); + + (fetchEventSource as any).mockImplementationOnce( + (url: string, options: FetchEventSourceInit) => { + options.onopen!({ clone: () => ({ ok: true, headers: new Headers() }) } as any); + options.onmessage!({ + event: 'tool_calls', + data: JSON.stringify([ + { index: 0, id: '1', type: 'function', function: { name: 'func1', arguments: 'a' } }, + ]), + } as any); + options.onmessage!({ + event: 'tool_calls', + data: JSON.stringify([ + { index: 0, function: { arguments: 'rg1' } }, + { index: 1, id: '2', type: 'function', function: { name: 'func2', arguments: 'a' } }, + ]), + } as any); + options.onmessage!({ + event: 'tool_calls', + data: JSON.stringify([{ index: 1, function: { arguments: 'rg2' } }]), + } as any); + }, + ); + + await fetchSSE('/', { + onMessageHandle: mockOnMessageHandle, + onFinish: mockOnFinish, + }); + + // TODO: need to check whether the `aarg1` is correct + expect(mockOnMessageHandle).toHaveBeenNthCalledWith(1, { + isAnimationActives: [true], + tool_calls: [ + { id: '1', type: 'function', function: { name: 'func1', arguments: 'aarg1' } }, + { function: { arguments: 'aarg2', name: 'func2' }, id: '2', type: 'function' }, + ], + type: 'tool_calls', + }); + expect(mockOnMessageHandle).toHaveBeenNthCalledWith(2, { + isAnimationActives: [true, true], + tool_calls: [ + { id: '1', type: 'function', function: { name: 'func1', arguments: 'aarg1' } }, + { id: '2', type: 'function', function: { name: 'func2', arguments: 'aarg2' } }, + ], + type: 'tool_calls', + }); + + // more assertions for each character... + expect(mockOnFinish).toHaveBeenCalledWith('', { + observationId: null, + toolCalls: [ + { id: '1', type: 'function', function: { name: 'func1', arguments: 'arg1' } }, + { id: '2', type: 'function', function: { name: 'func2', arguments: 'arg2' } }, + ], + traceId: null, + type: 'done', + }); + }); + + it('should handle request interruption and resumption correctly', async () => { + const mockOnMessageHandle = vi.fn(); + const mockOnFinish = vi.fn(); + const abortController = new AbortController(); + + (fetchEventSource as any).mockImplementationOnce( + (url: string, options: FetchEventSourceInit) => { + options.onopen!({ clone: () => ({ ok: true, headers: new Headers() }) } as any); + options.onmessage!({ event: 'text', data: JSON.stringify('Hello') } as any); + abortController.abort(); + options.onmessage!({ event: 'text', data: JSON.stringify(' World') } as any); + }, + ); + + await fetchSSE('/', { + onMessageHandle: mockOnMessageHandle, + onFinish: mockOnFinish, + signal: abortController.signal, + }); + + expect(mockOnMessageHandle).toHaveBeenNthCalledWith(1, { text: 'He', type: 'text' }); + expect(mockOnMessageHandle).toHaveBeenNthCalledWith(2, { text: 'llo World', type: 'text' }); + + expect(mockOnFinish).toHaveBeenCalledWith('Hello World', { + type: 'done', + observationId: null, + traceId: null, + }); + }); + + it('should call onFinish with correct parameters for different finish types', async () => { + const mockOnFinish = vi.fn(); + + (fetchEventSource as any).mockImplementationOnce( + (url: string, options: FetchEventSourceInit) => { + options.onopen!({ clone: () => ({ ok: true, headers: new Headers() }) } as any); + options.onmessage!({ event: 'text', data: JSON.stringify('Hello') } as any); + options.onerror!({ name: 'AbortError' }); + }, + ); + + await fetchSSE('/', { onFinish: mockOnFinish, smoothing: false }); + + expect(mockOnFinish).toHaveBeenCalledWith('Hello', { + observationId: null, + toolCalls: undefined, + traceId: null, + type: 'abort', + }); + + (fetchEventSource as any).mockImplementationOnce( + (url: string, options: FetchEventSourceInit) => { + options.onopen!({ clone: () => ({ ok: true, headers: new Headers() }) } as any); + options.onmessage!({ event: 'text', data: JSON.stringify('Hello') } as any); + options.onerror!(new Error('Unknown error')); + }, + ); + + await fetchSSE('/', { onFinish: mockOnFinish, smoothing: false }); + + expect(mockOnFinish).toHaveBeenCalledWith('Hello', { + observationId: null, + toolCalls: undefined, + traceId: null, + type: 'error', + }); + }); }); diff --git a/src/utils/fetch.ts b/src/utils/fetch.ts index 748cf044fe864..92cc9f15d2afc 100644 --- a/src/utils/fetch.ts +++ b/src/utils/fetch.ts @@ -51,6 +51,7 @@ export interface MessageTextChunk { } interface MessageToolCallsChunk { + isAnimationActives?: boolean[]; tool_calls: MessageToolCall[]; type: 'tool_calls'; } @@ -61,24 +62,201 @@ export interface FetchSSEOptions { onErrorHandle?: (error: ChatMessageError) => void; onFinish?: OnFinishHandler; onMessageHandle?: (chunk: MessageTextChunk | MessageToolCallsChunk) => void; + smoothing?: boolean; } export const parseToolCalls = (origin: MessageToolCall[], value: MessageToolCallChunk[]) => produce(origin, (draft) => { + // if there is no origin, we should parse all the value and set it to draft if (draft.length === 0) { draft.push(...value.map((item) => MessageToolCallSchema.parse(item))); - } else { - value.forEach(({ index, ...item }) => { - if (!draft?.[index]) { - draft?.splice(index, 0, MessageToolCallSchema.parse(item)); + return; + } + + // if there is origin, we should merge the value to the origin + value.forEach(({ index, ...item }) => { + if (!draft?.[index]) { + // if not, we should insert it to the draft + draft?.splice(index, 0, MessageToolCallSchema.parse(item)); + } else { + // if it is already in the draft, we should merge the arguments to the draft + if (item.function?.arguments) { + draft[index].function.arguments += item.function.arguments; + } + } + }); + }); + +const createSmoothMessage = (params: { onTextUpdate: (delta: string, text: string) => void }) => { + let buffer = ''; + // why use queue: https://shareg.pt/GLBrjpK + let outputQueue: string[] = []; + + // eslint-disable-next-line no-undef + let animationTimeoutId: NodeJS.Timeout | null = null; + let isAnimationActive = false; + + // when you need to stop the animation, call this function + const stopAnimation = () => { + isAnimationActive = false; + if (animationTimeoutId !== null) { + clearTimeout(animationTimeoutId); + animationTimeoutId = null; + } + }; + + // define startAnimation function to display the text in buffer smooth + // when you need to start the animation, call this function + const startAnimation = (speed = 2) => + new Promise((resolve) => { + if (isAnimationActive) { + resolve(); + return; + } + + isAnimationActive = true; + + const updateText = () => { + // 如果动画已经不再激活,则停止更新文本 + if (!isAnimationActive) { + clearTimeout(animationTimeoutId!); + animationTimeoutId = null; + resolve(); + } + + // 如果还有文本没有显示 + // 检查队列中是否有字符待显示 + if (outputQueue.length > 0) { + // 从队列中获取前两个字符(如果存在) + const charsToAdd = outputQueue.splice(0, speed).join(''); + buffer += charsToAdd; + + // 更新消息内容,这里可能需要结合实际情况调整 + params.onTextUpdate(charsToAdd, buffer); + + // 设置下一个字符的延迟 + animationTimeoutId = setTimeout(updateText, 16); // 16 毫秒的延迟模拟打字机效果 } else { - if (item.function?.arguments) { - draft[index].function.arguments += item.function.arguments; - } + // 当所有字符都显示完毕时,清除动画状态 + isAnimationActive = false; + animationTimeoutId = null; + resolve(); } - }); + }; + + updateText(); + }); + + const pushToQueue = (text: string) => { + outputQueue.push(...text.split('')); + }; + + return { + isAnimationActive, + isTokenRemain: () => outputQueue.length > 0, + pushToQueue, + startAnimation, + stopAnimation, + }; +}; + +const createSmoothToolCalls = (params: { + onToolCallsUpdate: (toolCalls: MessageToolCall[], isAnimationActives: boolean[]) => void; +}) => { + let toolCallsBuffer: MessageToolCall[] = []; + + // 为每个 tool_call 维护一个输出队列和动画控制器 + + // eslint-disable-next-line no-undef + const animationTimeoutIds: (NodeJS.Timeout | null)[] = []; + const outputQueues: string[][] = []; + const isAnimationActives: boolean[] = []; + + const stopAnimation = (index: number) => { + isAnimationActives[index] = false; + if (animationTimeoutIds[index] !== null) { + clearTimeout(animationTimeoutIds[index]!); + animationTimeoutIds[index] = null; } - }); + }; + + const startAnimation = (index: number, speed = 2) => + new Promise((resolve) => { + if (isAnimationActives[index]) { + resolve(); + return; + } + + isAnimationActives[index] = true; + + const updateToolCall = () => { + if (!isAnimationActives[index]) { + resolve(); + } + + if (outputQueues[index].length > 0) { + const charsToAdd = outputQueues[index].splice(0, speed).join(''); + + const toolCallToUpdate = toolCallsBuffer[index]; + + if (toolCallToUpdate) { + toolCallToUpdate.function.arguments += charsToAdd; + + // 触发 ui 更新 + params.onToolCallsUpdate(toolCallsBuffer, [...isAnimationActives]); + } + + animationTimeoutIds[index] = setTimeout(updateToolCall, 16); + } else { + isAnimationActives[index] = false; + animationTimeoutIds[index] = null; + resolve(); + } + }; + + updateToolCall(); + }); + + const pushToQueue = (toolCallChunks: MessageToolCallChunk[]) => { + toolCallChunks.forEach((chunk) => { + // init the tool call buffer and output queue + if (!toolCallsBuffer[chunk.index]) { + toolCallsBuffer[chunk.index] = MessageToolCallSchema.parse(chunk); + } + + if (!outputQueues[chunk.index]) { + outputQueues[chunk.index] = []; + isAnimationActives[chunk.index] = false; + animationTimeoutIds[chunk.index] = null; + } + + outputQueues[chunk.index].push(...(chunk.function?.arguments || '').split('')); + }); + }; + + const startAnimations = async (speed = 2) => { + const pools = toolCallsBuffer.map(async (_, index) => { + if (outputQueues[index].length > 0 && !isAnimationActives[index]) { + await startAnimation(index, speed); + } + }); + + await Promise.all(pools); + }; + const stopAnimations = () => { + toolCallsBuffer.forEach((_, index) => { + stopAnimation(index); + }); + }; + + return { + isAnimationActives, + isTokenRemain: () => outputQueues.some((token) => token.length > 0), + pushToQueue, + startAnimations, + stopAnimations, + }; +}; /** * Fetch data using stream method @@ -92,6 +270,21 @@ export const fetchSSE = async (url: string, options: RequestInit & FetchSSEOptio let finishedType: SSEFinishType = 'done'; let response!: Response; + const { smoothing = true } = options; + + const textController = createSmoothMessage({ + onTextUpdate: (delta, text) => { + output = text; + options.onMessageHandle?.({ text: delta, type: 'text' }); + }, + }); + + const toolCallsController = createSmoothToolCalls({ + onToolCallsUpdate: (toolCalls, isAnimationActives) => { + options.onMessageHandle?.({ isAnimationActives, tool_calls: toolCalls, type: 'tool_calls' }); + }, + }); + try { await fetchEventSource(url, { body: options.body, @@ -102,6 +295,7 @@ export const fetchSSE = async (url: string, options: RequestInit & FetchSSEOptio if ((error as TypeError).name === 'AbortError') { finishedType = 'abort'; options?.onAbort?.(output); + textController.stopAnimation(); } else { finishedType = 'error'; console.error(error); @@ -122,22 +316,39 @@ export const fetchSSE = async (url: string, options: RequestInit & FetchSSEOptio switch (ev.event) { case 'text': { - output += data; - options.onMessageHandle?.({ text: data, type: 'text' }); + if (smoothing) { + textController.pushToQueue(data); + + if (!textController.isAnimationActive) textController.startAnimation(); + } else { + output += data; + options.onMessageHandle?.({ text: data, type: 'text' }); + } + break; } case 'tool_calls': { - if (!toolCalls) { - toolCalls = []; - } - + // get finial + // if there is no tool calls, we should initialize the tool calls + if (!toolCalls) toolCalls = []; toolCalls = parseToolCalls(toolCalls, data); - options.onMessageHandle?.({ - tool_calls: toolCalls, - type: 'tool_calls', - }); + if (smoothing) { + // make the tool calls smooth + + // push the tool calls to the smooth queue + toolCallsController.pushToQueue(data); + // if there is no animation active, we should start the animation + if (toolCallsController.isAnimationActives.some((value) => !value)) { + toolCallsController.startAnimations(); + } + } else { + options.onMessageHandle?.({ + tool_calls: toolCalls, + type: 'tool_calls', + }); + } } } }, @@ -160,6 +371,9 @@ export const fetchSSE = async (url: string, options: RequestInit & FetchSSEOptio // only call onFinish when response is available // so like abort, we don't need to call onFinish if (response) { + textController.stopAnimation(); + toolCallsController.stopAnimations(); + // if there is no onMessageHandler, we should call onHandleMessage first if (!triggerOnMessageHandler) { output = await response.clone().text(); @@ -168,6 +382,15 @@ export const fetchSSE = async (url: string, options: RequestInit & FetchSSEOptio const traceId = response.headers.get(LOBE_CHAT_TRACE_ID); const observationId = response.headers.get(LOBE_CHAT_OBSERVATION_ID); + + if (textController.isTokenRemain()) { + await textController.startAnimation(15); + } + + if (toolCallsController.isTokenRemain()) { + await toolCallsController.startAnimations(15); + } + await options?.onFinish?.(output, { observationId, toolCalls, traceId, type: finishedType }); }