@genkit-ai/ai
Version:
Genkit AI framework generative AI APIs.
1 lines • 8.52 kB
Source Map (JSON)
{"version":3,"sources":["../src/reranker.ts"],"sourcesContent":["/**\n * Copyright 2024 Google LLC\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n */\n\nimport { Action, defineAction, z } from '@genkit-ai/core';\nimport { Registry } from '@genkit-ai/core/registry';\nimport { Part, PartSchema } from './document.js';\nimport { Document, DocumentData, DocumentDataSchema } from './retriever.js';\n\nexport type RerankerFn<RerankerOptions extends z.ZodTypeAny> = (\n query: Document,\n documents: Document[],\n queryOpts: z.infer<RerankerOptions>\n) => Promise<RerankerResponse>;\n\n/**\n * Zod schema for a reranked document metadata.\n */\nexport const RankedDocumentMetadataSchema = z\n .object({\n score: z.number(), // Enforces that 'score' must be a number\n })\n .passthrough(); // Allows other properties in 'metadata' with any type\n\nexport const RankedDocumentDataSchema = z.object({\n content: z.array(PartSchema),\n metadata: RankedDocumentMetadataSchema,\n});\n\nexport type RankedDocumentData = z.infer<typeof RankedDocumentDataSchema>;\n\nexport class RankedDocument extends Document implements RankedDocumentData {\n content: Part[];\n metadata: { score: number } & Record<string, any>;\n\n constructor(data: RankedDocumentData) {\n super(data);\n this.content = data.content;\n this.metadata = data.metadata;\n }\n /**\n * Returns the score of the document.\n * @returns The score of the document.\n */\n score(): number {\n return this.metadata.score;\n }\n}\n\nconst RerankerRequestSchema = z.object({\n query: DocumentDataSchema,\n documents: z.array(DocumentDataSchema),\n options: z.any().optional(),\n});\n\nconst RerankerResponseSchema = z.object({\n documents: z.array(RankedDocumentDataSchema),\n});\ntype RerankerResponse = z.infer<typeof RerankerResponseSchema>;\n\nexport const RerankerInfoSchema = z.object({\n label: z.string().optional(),\n /** Supported model capabilities. */\n supports: z\n .object({\n /** Model can process media as part of the prompt (multimodal input). */\n media: z.boolean().optional(),\n })\n .optional(),\n});\nexport type RerankerInfo = z.infer<typeof RerankerInfoSchema>;\n\nexport type RerankerAction<CustomOptions extends z.ZodTypeAny = z.ZodTypeAny> =\n Action<typeof RerankerRequestSchema, typeof RerankerResponseSchema> & {\n __configSchema?: CustomOptions;\n };\n\nfunction rerankerWithMetadata<\n RerankerOptions extends z.ZodTypeAny = z.ZodTypeAny,\n>(\n reranker: Action<typeof RerankerRequestSchema, typeof RerankerResponseSchema>,\n configSchema?: RerankerOptions\n): RerankerAction<RerankerOptions> {\n const withMeta = reranker as RerankerAction<RerankerOptions>;\n withMeta.__configSchema = configSchema;\n return withMeta;\n}\n\n/**\n * Creates a reranker action for the provided {@link RerankerFn} implementation.\n */\nexport function defineReranker<OptionsType extends z.ZodTypeAny = z.ZodTypeAny>(\n registry: Registry,\n options: {\n name: string;\n configSchema?: OptionsType;\n info?: RerankerInfo;\n },\n runner: RerankerFn<OptionsType>\n) {\n const reranker = defineAction(\n registry,\n {\n actionType: 'reranker',\n name: options.name,\n inputSchema: options.configSchema\n ? RerankerRequestSchema.extend({\n options: options.configSchema.optional(),\n })\n : RerankerRequestSchema,\n outputSchema: RerankerResponseSchema,\n metadata: {\n type: 'reranker',\n info: options.info,\n },\n },\n (i) =>\n runner(\n new Document(i.query),\n i.documents.map((d) => new Document(d)),\n i.options\n )\n );\n const rwm = rerankerWithMetadata(\n reranker as Action<\n typeof RerankerRequestSchema,\n typeof RerankerResponseSchema\n >,\n options.configSchema\n );\n return rwm;\n}\n\nexport interface RerankerParams<\n CustomOptions extends z.ZodTypeAny = z.ZodTypeAny,\n> {\n reranker: RerankerArgument<CustomOptions>;\n query: string | DocumentData;\n documents: DocumentData[];\n options?: z.infer<CustomOptions>;\n}\n\nexport type RerankerArgument<\n CustomOptions extends z.ZodTypeAny = z.ZodTypeAny,\n> = RerankerAction<CustomOptions> | RerankerReference<CustomOptions> | string;\n\n/**\n * Reranks documents from a {@link RerankerArgument} based on the provided query.\n */\nexport async function rerank<CustomOptions extends z.ZodTypeAny>(\n registry: Registry,\n params: RerankerParams<CustomOptions>\n): Promise<Array<RankedDocument>> {\n let reranker: RerankerAction<CustomOptions>;\n if (typeof params.reranker === 'string') {\n reranker = await registry.lookupAction(`/reranker/${params.reranker}`);\n } else if (Object.hasOwnProperty.call(params.reranker, 'info')) {\n reranker = await registry.lookupAction(`/reranker/${params.reranker.name}`);\n } else {\n reranker = params.reranker as RerankerAction<CustomOptions>;\n }\n if (!reranker) {\n throw new Error('Unable to resolve the reranker');\n }\n const response = await reranker({\n query:\n typeof params.query === 'string'\n ? Document.fromText(params.query)\n : params.query,\n documents: params.documents,\n options: params.options,\n });\n\n return response.documents.map((d) => new RankedDocument(d));\n}\n\nexport const CommonRerankerOptionsSchema = z.object({\n k: z.number().describe('Number of documents to rerank').optional(),\n});\n\nexport interface RerankerReference<CustomOptions extends z.ZodTypeAny> {\n name: string;\n configSchema?: CustomOptions;\n info?: RerankerInfo;\n}\n\n/**\n * Helper method to configure a {@link RerankerReference} to a plugin.\n */\nexport function rerankerRef<\n CustomOptionsSchema extends z.ZodTypeAny = z.ZodTypeAny,\n>(\n options: RerankerReference<CustomOptionsSchema>\n): RerankerReference<CustomOptionsSchema> {\n return { ...options };\n}\n"],"mappings":"AAgBA,SAAiB,cAAc,SAAS;AAExC,SAAe,kBAAkB;AACjC,SAAS,UAAwB,0BAA0B;AAWpD,MAAM,+BAA+B,EACzC,OAAO;AAAA,EACN,OAAO,EAAE,OAAO;AAAA;AAClB,CAAC,EACA,YAAY;AAER,MAAM,2BAA2B,EAAE,OAAO;AAAA,EAC/C,SAAS,EAAE,MAAM,UAAU;AAAA,EAC3B,UAAU;AACZ,CAAC;AAIM,MAAM,uBAAuB,SAAuC;AAAA,EACzE;AAAA,EACA;AAAA,EAEA,YAAY,MAA0B;AACpC,UAAM,IAAI;AACV,SAAK,UAAU,KAAK;AACpB,SAAK,WAAW,KAAK;AAAA,EACvB;AAAA;AAAA;AAAA;AAAA;AAAA,EAKA,QAAgB;AACd,WAAO,KAAK,SAAS;AAAA,EACvB;AACF;AAEA,MAAM,wBAAwB,EAAE,OAAO;AAAA,EACrC,OAAO;AAAA,EACP,WAAW,EAAE,MAAM,kBAAkB;AAAA,EACrC,SAAS,EAAE,IAAI,EAAE,SAAS;AAC5B,CAAC;AAED,MAAM,yBAAyB,EAAE,OAAO;AAAA,EACtC,WAAW,EAAE,MAAM,wBAAwB;AAC7C,CAAC;AAGM,MAAM,qBAAqB,EAAE,OAAO;AAAA,EACzC,OAAO,EAAE,OAAO,EAAE,SAAS;AAAA;AAAA,EAE3B,UAAU,EACP,OAAO;AAAA;AAAA,IAEN,OAAO,EAAE,QAAQ,EAAE,SAAS;AAAA,EAC9B,CAAC,EACA,SAAS;AACd,CAAC;AAQD,SAAS,qBAGP,UACA,cACiC;AACjC,QAAM,WAAW;AACjB,WAAS,iBAAiB;AAC1B,SAAO;AACT;AAKO,SAAS,eACd,UACA,SAKA,QACA;AACA,QAAM,WAAW;AAAA,IACf;AAAA,IACA;AAAA,MACE,YAAY;AAAA,MACZ,MAAM,QAAQ;AAAA,MACd,aAAa,QAAQ,eACjB,sBAAsB,OAAO;AAAA,QAC3B,SAAS,QAAQ,aAAa,SAAS;AAAA,MACzC,CAAC,IACD;AAAA,MACJ,cAAc;AAAA,MACd,UAAU;AAAA,QACR,MAAM;AAAA,QACN,MAAM,QAAQ;AAAA,MAChB;AAAA,IACF;AAAA,IACA,CAAC,MACC;AAAA,MACE,IAAI,SAAS,EAAE,KAAK;AAAA,MACpB,EAAE,UAAU,IAAI,CAAC,MAAM,IAAI,SAAS,CAAC,CAAC;AAAA,MACtC,EAAE;AAAA,IACJ;AAAA,EACJ;AACA,QAAM,MAAM;AAAA,IACV;AAAA,IAIA,QAAQ;AAAA,EACV;AACA,SAAO;AACT;AAkBA,eAAsB,OACpB,UACA,QACgC;AAChC,MAAI;AACJ,MAAI,OAAO,OAAO,aAAa,UAAU;AACvC,eAAW,MAAM,SAAS,aAAa,aAAa,OAAO,QAAQ,EAAE;AAAA,EACvE,WAAW,OAAO,eAAe,KAAK,OAAO,UAAU,MAAM,GAAG;AAC9D,eAAW,MAAM,SAAS,aAAa,aAAa,OAAO,SAAS,IAAI,EAAE;AAAA,EAC5E,OAAO;AACL,eAAW,OAAO;AAAA,EACpB;AACA,MAAI,CAAC,UAAU;AACb,UAAM,IAAI,MAAM,gCAAgC;AAAA,EAClD;AACA,QAAM,WAAW,MAAM,SAAS;AAAA,IAC9B,OACE,OAAO,OAAO,UAAU,WACpB,SAAS,SAAS,OAAO,KAAK,IAC9B,OAAO;AAAA,IACb,WAAW,OAAO;AAAA,IAClB,SAAS,OAAO;AAAA,EAClB,CAAC;AAED,SAAO,SAAS,UAAU,IAAI,CAAC,MAAM,IAAI,eAAe,CAAC,CAAC;AAC5D;AAEO,MAAM,8BAA8B,EAAE,OAAO;AAAA,EAClD,GAAG,EAAE,OAAO,EAAE,SAAS,+BAA+B,EAAE,SAAS;AACnE,CAAC;AAWM,SAAS,YAGd,SACwC;AACxC,SAAO,EAAE,GAAG,QAAQ;AACtB;","names":[]}