/*
 * @mondaydotcomorg/atp-client
 * Version:
 * Client SDK for Agent Tool Protocol
 * 195 lines (168 loc) • 5.41 kB
 * text/typescript
 */
import { CallbackType, ToolOperation } from '@mondaydotcomorg/atp-protocol';
import type {
ClientLLMHandler,
ClientApprovalHandler,
ClientEmbeddingHandler,
ClientServiceProviders,
ClientTool,
ClientToolDefinition,
ClientToolHandler,
} from '@mondaydotcomorg/atp-protocol';
/**
 * Values of `payload.operation` that `ServiceProviders.handleCallback`
 * recognizes for LLM callbacks.
 */
const LLMOperation = {
// Routed to the LLM handler's `call(prompt, options)`.
CALL: 'call',
// Routed to the handler's optional `extract(prompt, schema, options)`.
EXTRACT: 'extract',
// Routed to the handler's optional `classify(text, categories, options)`.
CLASSIFY: 'classify',
} as const;
/**
 * Values of `payload.operation` that `ServiceProviders.handleCallback`
 * recognizes for embedding callbacks.
 *
 * NOTE: a third operation, `'similarity'`, is also handled there but is
 * matched against a string literal rather than an entry of this object.
 */
const EmbeddingOperation = {
// Embeds `payload.text` via the embedding handler's `embed`.
EMBED: 'embed',
// Embeds `payload.query` and returns that embedding; no search is run on the client.
SEARCH: 'search',
} as const;
/**
 * Holds the services (LLM, approval, embedding) and tools supplied by the
 * client, and dispatches incoming callbacks to them.
 *
 * Tool handlers are indexed by tool name in a private map that is always kept
 * in sync with the current tool list.
 */
export class ServiceProviders {
  private providers: ClientServiceProviders;
  private toolHandlers: Map<string, ClientToolHandler> = new Map();

  /**
   * @param providers Optional initial set of services and tools. The object is
   *   stored by reference, so later `provide*` calls write into it.
   */
  constructor(providers?: ClientServiceProviders) {
    this.providers = providers ?? {};
    this.registerToolHandlers(this.providers.tools);
  }

  /** Registers (or replaces) the LLM handler. */
  provideLLM(handler: ClientLLMHandler): void {
    this.providers.llm = handler;
  }

  /** Registers (or replaces) the approval handler. */
  provideApproval(handler: ClientApprovalHandler): void {
    this.providers.approval = handler;
  }

  /** Registers (or replaces) the embedding handler. */
  provideEmbedding(handler: ClientEmbeddingHandler): void {
    this.providers.embedding = handler;
  }

  /**
   * Replaces the client tool list.
   *
   * The name-to-handler index is rebuilt from `tools`, so handlers belonging to
   * a previously provided list are no longer callable. This keeps
   * `getToolDefinitions()` / `hasTools()` consistent with what
   * `handleCallback` can actually execute.
   */
  provideTools(tools: ClientTool[]): void {
    this.providers.tools = tools;
    this.registerToolHandlers(tools);
  }

  getLLM(): ClientLLMHandler | undefined {
    return this.providers.llm;
  }

  getApproval(): ClientApprovalHandler | undefined {
    return this.providers.approval;
  }

  getEmbedding(): ClientEmbeddingHandler | undefined {
    return this.providers.embedding;
  }

  getTools(): ClientTool[] | undefined {
    return this.providers.tools;
  }

  /**
   * Get tool definitions (without handlers) for sending to server.
   *
   * @returns One definition per tool with the `handler` property removed;
   *   an empty array when no tools are registered.
   */
  getToolDefinitions(): ClientToolDefinition[] {
    return (this.providers.tools ?? []).map(
      ({ handler: _handler, ...definition }) => definition
    );
  }

  /**
   * Check if client has tools.
   */
  hasTools(): boolean {
    return (this.providers.tools?.length ?? 0) > 0;
  }

  /**
   * Check if client has any services or tools.
   */
  hasAnyServices(): boolean {
    return Boolean(
      this.providers.llm ||
        this.providers.approval ||
        this.providers.embedding ||
        this.hasTools()
    );
  }

  /**
   * Check if client has a service for a specific callback type.
   *
   * @returns `false` for callback types this class does not know about.
   */
  hasServiceForCallback(callbackType: CallbackType): boolean {
    switch (callbackType) {
      case CallbackType.LLM:
        return !!this.providers.llm;
      case CallbackType.APPROVAL:
        return !!this.providers.approval;
      case CallbackType.EMBEDDING:
        return !!this.providers.embedding;
      case CallbackType.TOOL:
        return this.hasTools();
      default:
        return false;
    }
  }

  /**
   * Dispatches a callback to the matching client service or tool.
   *
   * A payload whose `operation` is `'batch_parallel'` and whose `calls` is an
   * array is fanned out: each entry is handled concurrently through this same
   * method (using the entry's own `type`, `operation` and `payload`) and the
   * results are returned as an array in the same order.
   *
   * @param callbackType Which kind of service the callback targets.
   * @param payload Callback data; the fields read depend on the type/operation.
   * @returns Whatever the selected handler resolves to.
   * @throws Error when the payload is missing, the required service or tool is
   *   not registered, or the operation / callback type is not supported.
   */
  async handleCallback(callbackType: CallbackType, payload: any): Promise<any> {
    if (payload == null) {
      // Fail with a descriptive message instead of a TypeError on `payload.operation`.
      throw new Error(`Missing payload for callback type: ${callbackType}`);
    }

    // Array.isArray (rather than truthiness) so a malformed `calls` value falls
    // through to the normal "unsupported operation" errors instead of crashing on `.map`.
    if (payload.operation === 'batch_parallel' && Array.isArray(payload.calls)) {
      return Promise.all(
        payload.calls.map((call: any) =>
          this.handleCallback(call.type, {
            ...call.payload,
            operation: call.operation,
          })
        )
      );
    }

    switch (callbackType) {
      case CallbackType.LLM:
        return this.handleLLMCallback(payload);
      case CallbackType.APPROVAL:
        return this.handleApprovalCallback(payload);
      case CallbackType.EMBEDDING:
        return this.handleEmbeddingCallback(payload);
      case CallbackType.TOOL:
        return this.handleToolCallback(payload);
      default:
        throw new Error(`Unknown callback type: ${callbackType}`);
    }
  }

  /** Rebuilds the name-to-handler index from `tools`, dropping any previous entries. */
  private registerToolHandlers(tools: ClientTool[] | undefined): void {
    this.toolHandlers.clear();
    for (const tool of tools ?? []) {
      this.toolHandlers.set(tool.name, tool.handler);
    }
  }

  /** Routes an LLM callback to `call`, `extract` or `classify` on the LLM handler. */
  private async handleLLMCallback(payload: any): Promise<any> {
    const llm = this.providers.llm;
    if (!llm) {
      throw new Error('LLM service not provided by client');
    }
    const operation = payload.operation;
    if (operation === LLMOperation.CALL) {
      return llm.call(payload.prompt, payload.options);
    }
    // `extract` and `classify` are optional on the handler; when absent the
    // operation is reported as unsupported.
    if (operation === LLMOperation.EXTRACT && llm.extract) {
      return llm.extract(payload.prompt, payload.schema, payload.options);
    }
    if (operation === LLMOperation.CLASSIFY && llm.classify) {
      return llm.classify(payload.text, payload.categories, payload.options);
    }
    throw new Error(`Unsupported LLM operation: ${operation}`);
  }

  /** Forwards an approval request, adding `executionId` to the supplied context. */
  private async handleApprovalCallback(payload: any): Promise<any> {
    const approval = this.providers.approval;
    if (!approval) {
      throw new Error('Approval service not provided by client');
    }
    // Spreading a missing context yields no properties, so this covers both the
    // "context present" and "context absent" cases.
    const contextWithExecutionId = {
      ...payload.context,
      executionId: payload.executionId,
    };
    return approval.request(payload.message, contextWithExecutionId);
  }

  /** Routes an embedding callback to `embed` or `similarity` on the embedding handler. */
  private async handleEmbeddingCallback(payload: any): Promise<any> {
    const embedding = this.providers.embedding;
    if (!embedding) {
      throw new Error('Embedding service not provided by client');
    }
    const operation = payload.operation;
    if (operation === EmbeddingOperation.EMBED) {
      return embedding.embed(payload.text);
    }
    if (operation === EmbeddingOperation.SEARCH) {
      // Only the query embedding is produced here; no search is run by this class.
      return embedding.embed(payload.query);
    }
    if (operation === 'similarity' && embedding.similarity) {
      return embedding.similarity(payload.text1, payload.text2);
    }
    throw new Error(`Unsupported embedding operation: ${operation}`);
  }

  /** Looks up the tool named by `payload.toolName` and invokes it with `payload.input`. */
  private async handleToolCallback(payload: any): Promise<any> {
    if (payload.operation !== ToolOperation.CALL) {
      throw new Error(`Unsupported tool operation: ${payload.operation}`);
    }
    const toolName = payload.toolName;
    const handler = this.toolHandlers.get(toolName);
    if (!handler) {
      throw new Error(`Tool '${toolName}' not found in client tools`);
    }
    return handler(payload.input);
  }
}