UNPKG

@aj-archipelago/cortex

Version:

Cortex is a GraphQL API for AI. It provides a simple, extensible interface for using AI services from OpenAI, Azure and others.

281 lines (235 loc) 9.82 kB
// graphql.js
// Setup the Apollo server and Express middleware

import { ApolloServer } from '@apollo/server';
import { expressMiddleware } from '@apollo/server/express4';
import { ApolloServerPluginDrainHttpServer } from '@apollo/server/plugin/drainHttpServer';
import { ApolloServerPluginLandingPageLocalDefault } from '@apollo/server/plugin/landingPage/default';
import { makeExecutableSchema } from '@graphql-tools/schema';
import crypto from 'crypto';
import express from 'express';
import { useServer } from 'graphql-ws/lib/use/ws';
import http from 'http';
import Keyv from 'keyv';
import { WebSocketServer } from 'ws';
// eslint-disable-next-line import/no-extraneous-dependencies
import responseCachePlugin from '@apollo/server-plugin-response-cache';
import { KeyvAdapter } from '@apollo/utils.keyvadapter';
import cors from 'cors';
import { buildModels, buildPathways } from '../config.js';
import logger from '../lib/logger.js';
import { buildModelEndpoints } from '../lib/requestExecutor.js';
import { startTestServer } from '../tests/server.js';
import { requestState } from './requestState.js';
import { cancelRequestResolver } from './resolver.js';
import subscriptions from './subscriptions.js';
import { getMessageTypeDefs, getPathwayTypeDef, userPathwayInputParameters } from './typeDef.js';
import { buildRestEndpoints } from './rest.js';

// Utility functions

/**
 * Compare a configured API key with a client-supplied one in constant time.
 *
 * Both values are SHA-256 hashed first so the buffers passed to
 * `crypto.timingSafeEqual` always have equal length (it throws otherwise),
 * and the comparison time does not reveal how many leading characters match.
 *
 * @param {string} expected - A key from the `cortexApiKeys` config.
 * @param {string} provided - The key supplied with the request.
 * @returns {boolean} True when both keys are identical.
 */
const apiKeysMatch = (expected, provided) => {
    const digest = (value) => crypto.createHash('sha256').update(value).digest();
    return crypto.timingSafeEqual(digest(expected), digest(provided));
};

/**
 * Pull the API key out of a request.
 *
 * Looks at the `cortex-api-key` header, then the `cortex-api-key` query
 * parameter, and finally the `Authorization` header (stripping a leading
 * `Bearer ` prefix if present).
 *
 * @param {import('express').Request} req
 * @returns {*} The raw key value (may be undefined or a non-string, e.g. a
 *   repeated query parameter); callers must validate its type.
 */
const extractApiKey = (req) => {
    const directKey = req.headers['cortex-api-key'] || req.query['cortex-api-key'];
    if (directKey) {
        return directKey;
    }
    const authorization = req.headers['authorization'];
    return authorization?.startsWith('Bearer ') ? authorization.slice(7) : authorization;
};

/**
 * Decide whether a rejected request should receive a GraphQL-shaped error body.
 *
 * The auth middleware is mounted at the application root, where `req.baseUrl`
 * is always empty, so the request path is checked instead.
 *
 * @param {import('express').Request} req
 * @returns {boolean}
 */
const isGraphqlRequest = (req) => {
    const isGraphqlPath = req.path === '/graphql' || req.path.startsWith('/graphql/');
    // startsWith tolerates parameters such as "; charset=utf-8".
    const contentType = req.headers['content-type'] ?? '';
    return isGraphqlPath || contentType.startsWith('application/graphql');
};

/**
 * Build an Express middleware that rejects requests lacking a valid API key.
 *
 * Unauthorized GraphQL requests get a 401 with a GraphQL `errors` envelope
 * (code `UNAUTHENTICATED`); all other requests get a plain-text 401.
 *
 * @param {string[]} cortexApiKeys - The accepted keys.
 * @returns {import('express').RequestHandler}
 */
const createApiKeyMiddleware = (cortexApiKeys) => (req, res, next) => {
    const providedApiKey = extractApiKey(req);
    const isAuthorized = typeof providedApiKey === 'string'
        && cortexApiKeys.some((key) => typeof key === 'string' && apiKeysMatch(key, providedApiKey));

    if (isAuthorized) {
        next();
        return;
    }

    res.status(401)
        .set('WWW-Authenticate', 'Cortex-Api-Key')
        .set('X-Cortex-Api-Key-Info', 'Server requires Cortex API Key');

    if (isGraphqlRequest(req)) {
        res.json({
            errors: [
                {
                    message: 'Unauthorized',
                    extensions: {
                        code: 'UNAUTHENTICATED',
                    },
                },
            ],
        });
    } else {
        res.send('Unauthorized');
    }
};

// Server plugins

/**
 * Assemble the Apollo server plugins.
 *
 * Always includes the embedded local landing page. When both
 * `enableGraphqlCache` and `storageConnectionString` are configured, a
 * Keyv-backed response cache plugin is added as well (the log message assumes
 * the connection string points at Redis).
 *
 * @param {object} config - Config object exposing `get(key)`.
 * @returns {{ plugins: object[], cache: (KeyvAdapter|undefined) }}
 */
const getPlugins = (config) => {
    const plugins = [
        ApolloServerPluginLandingPageLocalDefault({ embed: true }), // For local development.
    ];

    // If cache is enabled and Redis is available, use it.
    let cache;
    if (config.get('enableGraphqlCache') && config.get('storageConnectionString')) {
        cache = new KeyvAdapter(new Keyv(config.get('storageConnectionString'), {
            ssl: true,
            abortConnect: false,
        }));
        // TODO: caching similar strings, embedding hashing, ... #delta similarity
        // TODO: custom cache key:
        // https://www.apollographql.com/docs/apollo-server/performance/cache-backends#implementing-your-own-cache-backend
        plugins.push(responseCachePlugin({ cache }));
        logger.info('Using Redis for GraphQL cache');
    }
    return { plugins, cache };
};

// Type Definitions for GraphQL

/**
 * Build the complete GraphQL SDL for the server.
 *
 * Concatenates the built-in definitions (cache-control directive, base
 * Query/Mutation, `executeWorkspace`, request-progress subscription), the
 * pathway manager's definitions (if any), and the definitions of every
 * pathway that is not disabled.
 *
 * @param {Object<string, object>} pathways - Pathways keyed by name.
 * @param {object} [pathwayManager] - Optional manager contributing extra types.
 * @returns {string} The joined SDL.
 */
const getTypedefs = (pathways, pathwayManager) => {
    const defaultTypeDefs = `#graphql
    ${getMessageTypeDefs()}

    enum CacheControlScope {
        PUBLIC
        PRIVATE
    }

    directive @cacheControl(
        maxAge: Int
        scope: CacheControlScope
        inheritMaxAge: Boolean
    ) on FIELD_DEFINITION | OBJECT | INTERFACE | UNION

    type Query {
        _ : Boolean
    }

    type Mutation {
        cancelRequest(requestId: String!): Boolean
    }

    ${getPathwayTypeDef('ExecuteWorkspace', 'String')}

    extend type Query {
        executeWorkspace(userId: String!, pathwayName: String!, ${userPathwayInputParameters}): ExecuteWorkspace
    }

    type RequestSubscription {
        requestId: String
        progress: Float
        status: String
        data: String
        info: String
        error: String
    }

    type Subscription {
        requestProgress(requestIds: [String!]): RequestSubscription
    }
    `;

    const pathwayManagerTypeDefs = pathwayManager?.getTypeDefs() || '';

    const pathwayTypeDefs = Object.values(pathways)
        .filter((p) => !p.disabled)
        .map((p) => p.typeDef(p).gqlDefinition);

    const typeDefs = [defaultTypeDefs, pathwayManagerTypeDefs, ...pathwayTypeDefs];

    return typeDefs.join('\n');
};

// Resolvers for GraphQL

/**
 * Build the GraphQL resolver map.
 *
 * Every enabled pathway gets a Query resolver that stores the pathway and
 * config on `contextValue` before delegating to `pathway.rootResolver`.
 * `executeWorkspace` looks up a per-user pathway through `pathwayManager`.
 *
 * @param {object} config - Config object, attached to each request context.
 * @param {Object<string, object>} pathways - Pathways keyed by name.
 * @param {object} [pathwayManager] - Optional manager for user pathways.
 * @returns {object} Resolvers for Query, Mutation and Subscription.
 */
const getResolvers = (config, pathways, pathwayManager) => {
    const resolverFunctions = {};
    for (const [name, pathway] of Object.entries(pathways)) {
        if (pathway.disabled) continue;
        resolverFunctions[name] = (parent, args, contextValue, info) => {
            // add shared state to contextValue
            contextValue.pathway = pathway;
            contextValue.config = config;
            return pathway.rootResolver(parent, args, contextValue, info);
        };
    }

    const pathwayManagerResolvers = pathwayManager?.getResolvers() || {};

    const executeWorkspaceResolver = async (_, args, contextValue, info) => {
        const { userId, pathwayName, ...pathwayArgs } = args;

        // pathwayManager is optional elsewhere in this file; fail with a clear
        // message instead of a TypeError when it is absent.
        if (!pathwayManager) {
            throw new Error('executeWorkspace is unavailable: no pathway manager is configured');
        }

        const userPathway = await pathwayManager.getPathway(userId, pathwayName);
        // NOTE(review): unclear from here whether getPathway can return a
        // falsy value rather than throwing; guard so the error is explicit.
        if (!userPathway) {
            throw new Error(`Pathway "${pathwayName}" not found`);
        }

        contextValue.pathway = userPathway;
        contextValue.config = config;

        return userPathway.rootResolver(null, pathwayArgs, contextValue, info);
    };

    const resolvers = {
        Query: {
            ...resolverFunctions,
            executeWorkspace: executeWorkspaceResolver,
        },
        Mutation: {
            cancelRequest: cancelRequestResolver,
            ...pathwayManagerResolvers.Mutation,
        },
        Subscription: subscriptions,
    };

    return resolvers;
};

// Build the server including the GraphQL schema and REST endpoints

/**
 * Build the Cortex server: config, schema, Apollo server, HTTP/WebSocket
 * servers, auth middleware, and a `startServer` function that mounts the
 * GraphQL and REST endpoints and begins listening on `config.get('PORT')`.
 *
 * @param {object} config - Config object exposing `get(key)`.
 * @returns {Promise<{ server: ApolloServer, startServer: () => Promise<void>,
 *   startTestServer: Function, cache: (KeyvAdapter|undefined),
 *   plugins: object[], typeDefs: string, resolvers: object }>}
 */
const build = async (config) => {
    // First perform config build
    const { pathwayManager } = await buildPathways(config);
    buildModels(config);

    // build model API endpoints and limiters
    buildModelEndpoints(config);

    // build api
    const pathways = config.get('pathways');

    const typeDefs = getTypedefs(pathways, pathwayManager);
    const resolvers = getResolvers(config, pathways, pathwayManager);

    const schema = makeExecutableSchema({ typeDefs, resolvers });

    const { plugins, cache } = getPlugins(config);

    const app = express();

    // Single JSON body parser for both GraphQL and REST endpoints.
    app.use(express.json({ limit: '200mb' }));

    const httpServer = http.createServer(app);

    // Creating the WebSocket server
    const wsServer = new WebSocketServer({
        // This is the `httpServer` we created in a previous step.
        server: httpServer,
        // Pass a different path here if your ApolloServer serves at
        // a different path.
        path: '/graphql',
    });

    // Hand in the schema we just created and have the
    // WebSocketServer start listening.
    // Respects the keep alive setting in config in case you want to
    // turn it off for deployments that don't route the ping/pong frames.
    // NOTE(review): the WebSocket server is attached directly to httpServer,
    // so the Express API-key middleware below does not run for subscription
    // connections — confirm whether subscriptions should also require a key.
    const keepAlive = config.get('subscriptionKeepAlive');
    logger.info(`Starting web socket server with subscription keep alive: ${keepAlive}`);
    const serverCleanup = useServer({ schema }, wsServer, keepAlive);

    const server = new ApolloServer({
        schema,
        introspection: config.get('env') === 'development',
        csrfPrevention: true,
        plugins: plugins.concat([
            // Proper shutdown for the HTTP server.
            ApolloServerPluginDrainHttpServer({ httpServer }),

            // Proper shutdown for the WebSocket server.
            {
                async serverWillStart() {
                    return {
                        async drainServer() {
                            await serverCleanup.dispose();
                        },
                    };
                },
            },
        ]),
    });

    // Healthcheck endpoint is valid regardless of auth
    app.get('/healthcheck', (req, res) => {
        res.status(200).send('OK');
    });

    // If CORTEX_API_KEY is set, we roll our own auth middleware - usually not
    // used if you're being fronted by a proxy.
    // NOTE(review): this runs before the per-route cors() on /graphql, so
    // browser CORS preflight (OPTIONS) requests without a key receive 401.
    const cortexApiKeys = config.get('cortexApiKeys');
    if (cortexApiKeys && Array.isArray(cortexApiKeys)) {
        app.use(createApiKeyMiddleware(cortexApiKeys));
    }

    // Server Startup Function
    const startServer = async () => {
        // Start only the main server
        await server.start();

        app.use(
            '/graphql',
            cors(),
            expressMiddleware(server, {
                context: async ({ req, res }) => ({ req, res, config, requestState }),
            }),
        );

        // add the REST endpoints
        buildRestEndpoints(pathways, app, server, config);

        // Now that our HTTP server is fully set up, we can listen to it.
        httpServer.listen(config.get('PORT'), () => {
            logger.info(`🚀 Server is now running at http://localhost:${config.get('PORT')}/graphql`);
        });
    };

    return { server, startServer, startTestServer, cache, plugins, typeDefs, resolvers };
};

export { build };