UNPKG

ai

Version:

AI SDK by Vercel - The AI Toolkit for TypeScript and JavaScript

1,029 lines (993 loc) • 30.4 kB
// Tests for extractReasoningMiddleware: the middleware pulls reasoning that a
// model emits inside XML-style tags (here <think>...</think>) out of the text
// content and re-exposes it as separate `reasoning` content (generate) or
// reasoning-start/-delta/-end stream parts (stream).
//
// NOTE(review): this file was recovered from a whitespace-mangled extraction.
// The code has been reformatted conventionally and the inline-snapshot
// template literals re-indented to vitest's standard serializer layout;
// embedded '\n' characters inside snapshotted strings are rendered as real
// line breaks, as pretty-format does. If a snapshot mismatch is purely
// whitespace, re-run with `vitest --update` to re-normalize — TODO confirm
// against the published package before relying on exact snapshot bytes.
import { LanguageModelV3Usage } from '@ai-sdk/provider';
import {
  convertArrayToReadableStream,
  convertAsyncIterableToArray,
} from '@ai-sdk/provider-utils/test';
import { describe, expect, it } from 'vitest';
import { generateText, streamText } from '../generate-text';
import { wrapLanguageModel } from '../middleware/wrap-language-model';
import { MockLanguageModelV3 } from '../test/mock-language-model-v3';
import { extractReasoningMiddleware } from './extract-reasoning-middleware';

// Shared usage fixture returned by every mock model below; the snapshots at
// the end of each stream test assert its mapped form (inputTokens: 5,
// outputTokens: 10, reasoningTokens: 3, totalTokens: 15).
const testUsage: LanguageModelV3Usage = {
  inputTokens: {
    total: 5,
    noCache: 5,
    cacheRead: 0,
    cacheWrite: 0,
  },
  outputTokens: {
    total: 10,
    text: 10,
    reasoning: 3,
  },
};

describe('extractReasoningMiddleware', () => {
  describe('wrapGenerate', () => {
    // Basic case: one <think> block followed by answer text splits into one
    // reasoning part and one text part.
    it('should extract reasoning from <think> tags', async () => {
      const mockModel = new MockLanguageModelV3({
        async doGenerate() {
          return {
            content: [
              {
                type: 'text',
                text: '<think>analyzing the request</think>Here is the response',
              },
            ],
            finishReason: { unified: 'stop', raw: 'stop' },
            usage: testUsage,
            warnings: [],
          };
        },
      });

      const result = await generateText({
        model: wrapLanguageModel({
          model: mockModel,
          middleware: extractReasoningMiddleware({ tagName: 'think' }),
        }),
        prompt: 'Hello, how can I help?',
      });

      expect(result.content).toMatchInlineSnapshot(`
        [
          {
            "text": "analyzing the request",
            "type": "reasoning",
          },
          {
            "text": "Here is the response",
            "type": "text",
          },
        ]
      `);
    });

    // When the whole output is a <think> block, the text part is kept but
    // empty (trailing newline stays inside the reasoning text).
    it('should extract reasoning from <think> tags when there is no text', async () => {
      const mockModel = new MockLanguageModelV3({
        async doGenerate() {
          return {
            content: [
              {
                type: 'text',
                text: '<think>analyzing the request\n</think>',
              },
            ],
            finishReason: { unified: 'stop', raw: 'stop' },
            usage: testUsage,
            warnings: [],
          };
        },
      });

      const result = await generateText({
        model: wrapLanguageModel({
          model: mockModel,
          middleware: extractReasoningMiddleware({ tagName: 'think' }),
        }),
        prompt: 'Hello, how can I help?',
      });

      expect(result.content).toMatchInlineSnapshot(`
        [
          {
            "text": "analyzing the request
        ",
            "type": "reasoning",
          },
          {
            "text": "",
            "type": "text",
          },
        ]
      `);
    });

    // Multiple <think> blocks are merged into a single reasoning part and a
    // single text part, joined with the middleware's separator (a newline in
    // the snapshot below).
    it('should extract reasoning from multiple <think> tags', async () => {
      const mockModel = new MockLanguageModelV3({
        async doGenerate() {
          return {
            content: [
              {
                type: 'text',
                text: '<think>analyzing the request</think>Here is the response<think>thinking about the response</think>more',
              },
            ],
            finishReason: { unified: 'stop', raw: 'stop' },
            usage: testUsage,
            warnings: [],
          };
        },
      });

      const result = await generateText({
        model: wrapLanguageModel({
          model: mockModel,
          middleware: extractReasoningMiddleware({ tagName: 'think' }),
        }),
        prompt: 'Hello, how can I help?',
      });

      expect(result.content).toMatchInlineSnapshot(`
        [
          {
            "text": "analyzing the request
        thinking about the response",
            "type": "reasoning",
          },
          {
            "text": "Here is the response
        more",
            "type": "text",
          },
        ]
      `);
    });

    // startWithReasoning: true treats the output as if it began with an
    // opening tag; without it, an un-opened closing tag is left in the text.
    it('should prepend <think> tag IFF startWithReasoning is true', async () => {
      const mockModel = new MockLanguageModelV3({
        async doGenerate() {
          return {
            content: [
              {
                type: 'text',
                text: 'analyzing the request</think>Here is the response',
              },
            ],
            finishReason: { unified: 'stop', raw: 'stop' },
            usage: testUsage,
            warnings: [],
          };
        },
      });

      const resultTrue = await generateText({
        model: wrapLanguageModel({
          model: mockModel,
          middleware: extractReasoningMiddleware({
            tagName: 'think',
            startWithReasoning: true,
          }),
        }),
        prompt: 'Hello, how can I help?',
      });

      const resultFalse = await generateText({
        model: wrapLanguageModel({
          model: mockModel,
          middleware: extractReasoningMiddleware({
            tagName: 'think',
          }),
        }),
        prompt: 'Hello, how can I help?',
      });

      expect(resultTrue.content).toMatchInlineSnapshot(`
        [
          {
            "text": "analyzing the request",
            "type": "reasoning",
          },
          {
            "text": "Here is the response",
            "type": "text",
          },
        ]
      `);

      expect(resultFalse.content).toMatchInlineSnapshot(`
        [
          {
            "text": "analyzing the request</think>Here is the response",
            "type": "text",
          },
        ]
      `);
    });

    // Regression-style check: extra properties on the provider result (here an
    // explicit `reasoning: undefined`) must not break the extraction.
    it('should preserve reasoning property even when rest contains other properties', async () => {
      const mockModel = new MockLanguageModelV3({
        async doGenerate() {
          return {
            content: [
              {
                type: 'text',
                text: '<think>analyzing the request</think>Here is the response',
              },
            ],
            finishReason: { unified: 'stop', raw: 'stop' },
            usage: testUsage,
            reasoning: undefined,
            warnings: [],
          };
        },
      });

      const result = await generateText({
        model: wrapLanguageModel({
          model: mockModel,
          middleware: extractReasoningMiddleware({ tagName: 'think' }),
        }),
        prompt: 'Hello, how can I help?',
      });

      expect(result.content).toMatchInlineSnapshot(`
        [
          {
            "text": "analyzing the request",
            "type": "reasoning",
          },
          {
            "text": "Here is the response",
            "type": "text",
          },
        ]
      `);
    });
  });

  describe('wrapStream', () => {
    // Tag markers arrive as their own deltas; the chunks between them are
    // re-emitted as reasoning-deltas and the tag text itself is dropped.
    it('should extract reasoning from split <think> tags', async () => {
      const mockModel = new MockLanguageModelV3({
        async doStream() {
          return {
            stream: convertArrayToReadableStream([
              {
                type: 'response-metadata',
                id: 'id-0',
                modelId: 'mock-model-id',
                timestamp: new Date(0),
              },
              { type: 'text-start', id: '1' },
              { type: 'text-delta', id: '1', delta: '<think>' },
              { type: 'text-delta', id: '1', delta: 'ana' },
              { type: 'text-delta', id: '1', delta: 'lyzing the request' },
              { type: 'text-delta', id: '1', delta: '</think>' },
              { type: 'text-delta', id: '1', delta: 'Here' },
              { type: 'text-delta', id: '1', delta: ' is the response' },
              { type: 'text-end', id: '1' },
              {
                type: 'finish',
                finishReason: { unified: 'stop', raw: 'stop' },
                usage: testUsage,
              },
            ]),
          };
        },
      });

      const result = streamText({
        model: wrapLanguageModel({
          model: mockModel,
          middleware: extractReasoningMiddleware({ tagName: 'think' }),
        }),
        prompt: 'Hello, how can I help?',
      });

      expect(await convertAsyncIterableToArray(result.fullStream))
        .toMatchInlineSnapshot(`
          [
            {
              "type": "start",
            },
            {
              "request": {},
              "type": "start-step",
              "warnings": [],
            },
            {
              "id": "reasoning-0",
              "type": "reasoning-start",
            },
            {
              "id": "reasoning-0",
              "providerMetadata": undefined,
              "text": "ana",
              "type": "reasoning-delta",
            },
            {
              "id": "reasoning-0",
              "providerMetadata": undefined,
              "text": "lyzing the request",
              "type": "reasoning-delta",
            },
            {
              "id": "reasoning-0",
              "type": "reasoning-end",
            },
            {
              "id": "1",
              "type": "text-start",
            },
            {
              "id": "1",
              "providerMetadata": undefined,
              "text": "Here",
              "type": "text-delta",
            },
            {
              "id": "1",
              "providerMetadata": undefined,
              "text": " is the response",
              "type": "text-delta",
            },
            {
              "id": "1",
              "type": "text-end",
            },
            {
              "finishReason": "stop",
              "providerMetadata": undefined,
              "rawFinishReason": "stop",
              "response": {
                "headers": undefined,
                "id": "id-0",
                "modelId": "mock-model-id",
                "timestamp": 1970-01-01T00:00:00.000Z,
              },
              "type": "finish-step",
              "usage": {
                "cachedInputTokens": 0,
                "inputTokenDetails": {
                  "cacheReadTokens": 0,
                  "cacheWriteTokens": 0,
                  "noCacheTokens": 5,
                },
                "inputTokens": 5,
                "outputTokenDetails": {
                  "reasoningTokens": 3,
                  "textTokens": 10,
                },
                "outputTokens": 10,
                "raw": undefined,
                "reasoningTokens": 3,
                "totalTokens": 15,
              },
            },
            {
              "finishReason": "stop",
              "rawFinishReason": "stop",
              "totalUsage": {
                "cachedInputTokens": 0,
                "inputTokenDetails": {
                  "cacheReadTokens": 0,
                  "cacheWriteTokens": 0,
                  "noCacheTokens": 5,
                },
                "inputTokens": 5,
                "outputTokenDetails": {
                  "reasoningTokens": 3,
                  "textTokens": 10,
                },
                "outputTokens": 10,
                "reasoningTokens": 3,
                "totalTokens": 15,
              },
              "type": "finish",
            },
          ]
        `);
    });

    // A single delta containing two <think> sections yields two separate
    // reasoning-start/-end pairs (reasoning-0, reasoning-1) interleaved with
    // the text deltas; the separator newline is prepended to continuations.
    it('should extract reasoning from single chunk with multiple <think> tags', async () => {
      const mockModel = new MockLanguageModelV3({
        async doStream() {
          return {
            stream: convertArrayToReadableStream([
              {
                type: 'response-metadata',
                id: 'id-0',
                modelId: 'mock-model-id',
                timestamp: new Date(0),
              },
              { type: 'text-start', id: '1' },
              {
                type: 'text-delta',
                id: '1',
                delta:
                  '<think>analyzing the request</think>Here is the response<think>thinking about the response</think>more',
              },
              { type: 'text-end', id: '1' },
              {
                type: 'finish',
                finishReason: { unified: 'stop', raw: 'stop' },
                usage: testUsage,
              },
            ]),
          };
        },
      });

      const result = streamText({
        model: wrapLanguageModel({
          model: mockModel,
          middleware: extractReasoningMiddleware({ tagName: 'think' }),
        }),
        prompt: 'Hello, how can I help?',
      });

      expect(await convertAsyncIterableToArray(result.fullStream))
        .toMatchInlineSnapshot(`
          [
            {
              "type": "start",
            },
            {
              "request": {},
              "type": "start-step",
              "warnings": [],
            },
            {
              "id": "reasoning-0",
              "type": "reasoning-start",
            },
            {
              "id": "reasoning-0",
              "providerMetadata": undefined,
              "text": "analyzing the request",
              "type": "reasoning-delta",
            },
            {
              "id": "reasoning-0",
              "type": "reasoning-end",
            },
            {
              "id": "1",
              "type": "text-start",
            },
            {
              "id": "1",
              "providerMetadata": undefined,
              "text": "Here is the response",
              "type": "text-delta",
            },
            {
              "id": "reasoning-1",
              "type": "reasoning-start",
            },
            {
              "id": "reasoning-1",
              "providerMetadata": undefined,
              "text": "
          thinking about the response",
              "type": "reasoning-delta",
            },
            {
              "id": "reasoning-1",
              "type": "reasoning-end",
            },
            {
              "id": "1",
              "providerMetadata": undefined,
              "text": "
          more",
              "type": "text-delta",
            },
            {
              "id": "1",
              "type": "text-end",
            },
            {
              "finishReason": "stop",
              "providerMetadata": undefined,
              "rawFinishReason": "stop",
              "response": {
                "headers": undefined,
                "id": "id-0",
                "modelId": "mock-model-id",
                "timestamp": 1970-01-01T00:00:00.000Z,
              },
              "type": "finish-step",
              "usage": {
                "cachedInputTokens": 0,
                "inputTokenDetails": {
                  "cacheReadTokens": 0,
                  "cacheWriteTokens": 0,
                  "noCacheTokens": 5,
                },
                "inputTokens": 5,
                "outputTokenDetails": {
                  "reasoningTokens": 3,
                  "textTokens": 10,
                },
                "outputTokens": 10,
                "raw": undefined,
                "reasoningTokens": 3,
                "totalTokens": 15,
              },
            },
            {
              "finishReason": "stop",
              "rawFinishReason": "stop",
              "totalUsage": {
                "cachedInputTokens": 0,
                "inputTokenDetails": {
                  "cacheReadTokens": 0,
                  "cacheWriteTokens": 0,
                  "noCacheTokens": 5,
                },
                "inputTokens": 5,
                "outputTokenDetails": {
                  "reasoningTokens": 3,
                  "textTokens": 10,
                },
                "outputTokens": 10,
                "reasoningTokens": 3,
                "totalTokens": 15,
              },
              "type": "finish",
            },
          ]
        `);
    });

    // Reasoning-only stream: text-start/text-end are still emitted, but no
    // text-delta parts appear between them.
    it('should extract reasoning from <think> when there is no text', async () => {
      const mockModel = new MockLanguageModelV3({
        async doStream() {
          return {
            stream: convertArrayToReadableStream([
              {
                type: 'response-metadata',
                id: 'id-0',
                modelId: 'mock-model-id',
                timestamp: new Date(0),
              },
              { type: 'text-start', id: '1' },
              { type: 'text-delta', id: '1', delta: '<think>' },
              { type: 'text-delta', id: '1', delta: 'ana' },
              { type: 'text-delta', id: '1', delta: 'lyzing the request\n' },
              { type: 'text-delta', id: '1', delta: '</think>' },
              { type: 'text-end', id: '1' },
              {
                type: 'finish',
                finishReason: { unified: 'stop', raw: 'stop' },
                usage: testUsage,
              },
            ]),
          };
        },
      });

      const result = streamText({
        model: wrapLanguageModel({
          model: mockModel,
          middleware: extractReasoningMiddleware({ tagName: 'think' }),
        }),
        prompt: 'Hello, how can I help?',
      });

      expect(await convertAsyncIterableToArray(result.fullStream))
        .toMatchInlineSnapshot(`
          [
            {
              "type": "start",
            },
            {
              "request": {},
              "type": "start-step",
              "warnings": [],
            },
            {
              "id": "reasoning-0",
              "type": "reasoning-start",
            },
            {
              "id": "reasoning-0",
              "providerMetadata": undefined,
              "text": "ana",
              "type": "reasoning-delta",
            },
            {
              "id": "reasoning-0",
              "providerMetadata": undefined,
              "text": "lyzing the request
          ",
              "type": "reasoning-delta",
            },
            {
              "id": "reasoning-0",
              "type": "reasoning-end",
            },
            {
              "id": "1",
              "type": "text-start",
            },
            {
              "id": "1",
              "type": "text-end",
            },
            {
              "finishReason": "stop",
              "providerMetadata": undefined,
              "rawFinishReason": "stop",
              "response": {
                "headers": undefined,
                "id": "id-0",
                "modelId": "mock-model-id",
                "timestamp": 1970-01-01T00:00:00.000Z,
              },
              "type": "finish-step",
              "usage": {
                "cachedInputTokens": 0,
                "inputTokenDetails": {
                  "cacheReadTokens": 0,
                  "cacheWriteTokens": 0,
                  "noCacheTokens": 5,
                },
                "inputTokens": 5,
                "outputTokenDetails": {
                  "reasoningTokens": 3,
                  "textTokens": 10,
                },
                "outputTokens": 10,
                "raw": undefined,
                "reasoningTokens": 3,
                "totalTokens": 15,
              },
            },
            {
              "finishReason": "stop",
              "rawFinishReason": "stop",
              "totalUsage": {
                "cachedInputTokens": 0,
                "inputTokenDetails": {
                  "cacheReadTokens": 0,
                  "cacheWriteTokens": 0,
                  "noCacheTokens": 5,
                },
                "inputTokens": 5,
                "outputTokenDetails": {
                  "reasoningTokens": 3,
                  "textTokens": 10,
                },
                "outputTokens": 10,
                "reasoningTokens": 3,
                "totalTokens": 15,
              },
              "type": "finish",
            },
          ]
        `);
    });

    // Streaming counterpart of the startWithReasoning test: with the flag the
    // leading chunks become reasoning; without it everything (including the
    // stray closing tag) stays plain text.
    it('should prepend <think> tag if startWithReasoning is true', async () => {
      const mockModel = new MockLanguageModelV3({
        async doStream() {
          return {
            stream: convertArrayToReadableStream([
              {
                type: 'response-metadata',
                id: 'id-0',
                modelId: 'mock-model-id',
                timestamp: new Date(0),
              },
              { type: 'text-start', id: '1' },
              { type: 'text-delta', id: '1', delta: 'ana' },
              { type: 'text-delta', id: '1', delta: 'lyzing the request\n' },
              { type: 'text-delta', id: '1', delta: '</think>' },
              { type: 'text-delta', id: '1', delta: 'this is the response' },
              { type: 'text-end', id: '1' },
              {
                type: 'finish',
                finishReason: { unified: 'stop', raw: 'stop' },
                usage: testUsage,
              },
            ]),
          };
        },
      });

      const resultTrue = streamText({
        model: wrapLanguageModel({
          model: mockModel,
          middleware: extractReasoningMiddleware({
            tagName: 'think',
            startWithReasoning: true,
          }),
        }),
        prompt: 'Hello, how can I help?',
      });

      const resultFalse = streamText({
        model: wrapLanguageModel({
          model: mockModel,
          middleware: extractReasoningMiddleware({ tagName: 'think' }),
        }),
        prompt: 'Hello, how can I help?',
      });

      expect(await convertAsyncIterableToArray(resultTrue.fullStream))
        .toMatchInlineSnapshot(`
          [
            {
              "type": "start",
            },
            {
              "request": {},
              "type": "start-step",
              "warnings": [],
            },
            {
              "id": "reasoning-0",
              "type": "reasoning-start",
            },
            {
              "id": "reasoning-0",
              "providerMetadata": undefined,
              "text": "ana",
              "type": "reasoning-delta",
            },
            {
              "id": "reasoning-0",
              "providerMetadata": undefined,
              "text": "lyzing the request
          ",
              "type": "reasoning-delta",
            },
            {
              "id": "reasoning-0",
              "type": "reasoning-end",
            },
            {
              "id": "1",
              "type": "text-start",
            },
            {
              "id": "1",
              "providerMetadata": undefined,
              "text": "this is the response",
              "type": "text-delta",
            },
            {
              "id": "1",
              "type": "text-end",
            },
            {
              "finishReason": "stop",
              "providerMetadata": undefined,
              "rawFinishReason": "stop",
              "response": {
                "headers": undefined,
                "id": "id-0",
                "modelId": "mock-model-id",
                "timestamp": 1970-01-01T00:00:00.000Z,
              },
              "type": "finish-step",
              "usage": {
                "cachedInputTokens": 0,
                "inputTokenDetails": {
                  "cacheReadTokens": 0,
                  "cacheWriteTokens": 0,
                  "noCacheTokens": 5,
                },
                "inputTokens": 5,
                "outputTokenDetails": {
                  "reasoningTokens": 3,
                  "textTokens": 10,
                },
                "outputTokens": 10,
                "raw": undefined,
                "reasoningTokens": 3,
                "totalTokens": 15,
              },
            },
            {
              "finishReason": "stop",
              "rawFinishReason": "stop",
              "totalUsage": {
                "cachedInputTokens": 0,
                "inputTokenDetails": {
                  "cacheReadTokens": 0,
                  "cacheWriteTokens": 0,
                  "noCacheTokens": 5,
                },
                "inputTokens": 5,
                "outputTokenDetails": {
                  "reasoningTokens": 3,
                  "textTokens": 10,
                },
                "outputTokens": 10,
                "reasoningTokens": 3,
                "totalTokens": 15,
              },
              "type": "finish",
            },
          ]
        `);

      expect(await convertAsyncIterableToArray(resultFalse.fullStream))
        .toMatchInlineSnapshot(`
          [
            {
              "type": "start",
            },
            {
              "request": {},
              "type": "start-step",
              "warnings": [],
            },
            {
              "id": "1",
              "type": "text-start",
            },
            {
              "id": "1",
              "providerMetadata": undefined,
              "text": "ana",
              "type": "text-delta",
            },
            {
              "id": "1",
              "providerMetadata": undefined,
              "text": "lyzing the request
          ",
              "type": "text-delta",
            },
            {
              "id": "1",
              "providerMetadata": undefined,
              "text": "</think>",
              "type": "text-delta",
            },
            {
              "id": "1",
              "providerMetadata": undefined,
              "text": "this is the response",
              "type": "text-delta",
            },
            {
              "id": "1",
              "type": "text-end",
            },
            {
              "finishReason": "stop",
              "providerMetadata": undefined,
              "rawFinishReason": "stop",
              "response": {
                "headers": undefined,
                "id": "id-0",
                "modelId": "mock-model-id",
                "timestamp": 1970-01-01T00:00:00.000Z,
              },
              "type": "finish-step",
              "usage": {
                "cachedInputTokens": 0,
                "inputTokenDetails": {
                  "cacheReadTokens": 0,
                  "cacheWriteTokens": 0,
                  "noCacheTokens": 5,
                },
                "inputTokens": 5,
                "outputTokenDetails": {
                  "reasoningTokens": 3,
                  "textTokens": 10,
                },
                "outputTokens": 10,
                "raw": undefined,
                "reasoningTokens": 3,
                "totalTokens": 15,
              },
            },
            {
              "finishReason": "stop",
              "rawFinishReason": "stop",
              "totalUsage": {
                "cachedInputTokens": 0,
                "inputTokenDetails": {
                  "cacheReadTokens": 0,
                  "cacheWriteTokens": 0,
                  "noCacheTokens": 5,
                },
                "inputTokens": 5,
                "outputTokenDetails": {
                  "reasoningTokens": 3,
                  "textTokens": 10,
                },
                "outputTokens": 10,
                "reasoningTokens": 3,
                "totalTokens": 15,
              },
              "type": "finish",
            },
          ]
        `);
    });

    // Pass-through case: no tag anywhere, so the stream is forwarded
    // unchanged (no reasoning parts).
    it('should keep original text when <think> tag is not present', async () => {
      const mockModel = new MockLanguageModelV3({
        async doStream() {
          return {
            stream: convertArrayToReadableStream([
              {
                type: 'response-metadata',
                id: 'id-0',
                modelId: 'mock-model-id',
                timestamp: new Date(0),
              },
              { type: 'text-start', id: '1' },
              { type: 'text-delta', id: '1', delta: 'this is the response' },
              { type: 'text-end', id: '1' },
              {
                type: 'finish',
                finishReason: { unified: 'stop', raw: 'stop' },
                usage: testUsage,
              },
            ]),
          };
        },
      });

      const result = streamText({
        model: wrapLanguageModel({
          model: mockModel,
          middleware: extractReasoningMiddleware({ tagName: 'think' }),
        }),
        prompt: 'Hello, how can I help?',
      });

      expect(await convertAsyncIterableToArray(result.fullStream))
        .toMatchInlineSnapshot(`
          [
            {
              "type": "start",
            },
            {
              "request": {},
              "type": "start-step",
              "warnings": [],
            },
            {
              "id": "1",
              "type": "text-start",
            },
            {
              "id": "1",
              "providerMetadata": undefined,
              "text": "this is the response",
              "type": "text-delta",
            },
            {
              "id": "1",
              "type": "text-end",
            },
            {
              "finishReason": "stop",
              "providerMetadata": undefined,
              "rawFinishReason": "stop",
              "response": {
                "headers": undefined,
                "id": "id-0",
                "modelId": "mock-model-id",
                "timestamp": 1970-01-01T00:00:00.000Z,
              },
              "type": "finish-step",
              "usage": {
                "cachedInputTokens": 0,
                "inputTokenDetails": {
                  "cacheReadTokens": 0,
                  "cacheWriteTokens": 0,
                  "noCacheTokens": 5,
                },
                "inputTokens": 5,
                "outputTokenDetails": {
                  "reasoningTokens": 3,
                  "textTokens": 10,
                },
                "outputTokens": 10,
                "raw": undefined,
                "reasoningTokens": 3,
                "totalTokens": 15,
              },
            },
            {
              "finishReason": "stop",
              "rawFinishReason": "stop",
              "totalUsage": {
                "cachedInputTokens": 0,
                "inputTokenDetails": {
                  "cacheReadTokens": 0,
                  "cacheWriteTokens": 0,
                  "noCacheTokens": 5,
                },
                "inputTokens": 5,
                "outputTokenDetails": {
                  "reasoningTokens": 3,
                  "textTokens": 10,
                },
                "outputTokens": 10,
                "reasoningTokens": 3,
                "totalTokens": 15,
              },
              "type": "finish",
            },
          ]
        `);
    });
  });
});