UNPKG

mcpresso

Version:

TypeScript package for Model Context Protocol (MCP) utilities and tools

505 lines (439 loc) 15.1 kB
import { Pool } from 'pg';
import type { PoolClient, QueryResultRow } from 'pg';
import type {
  MCPOAuthStorage,
  OAuthClient,
  OAuthUser,
  AuthorizationCode,
  AccessToken,
  RefreshToken,
} from 'mcpresso-oauth-server';

/** Row counts reported by `PostgresStorage.getStats()`. */
type StorageStats = {
  clients: number;
  users: number;
  authorizationCodes: number;
  accessTokens: number;
  refreshTokens: number;
};

/**
 * DDL executed, in order, by `PostgresStorage.initialize()`.
 *
 * New tables use TIMESTAMPTZ so stored instants do not depend on the time
 * zone of either the Node process or the database session. Because every
 * statement is `IF NOT EXISTS`, tables created by earlier versions (plain
 * TIMESTAMP columns) are left as they are; the expiry queries in the class
 * are written to stay correct for both column types.
 *
 * PostgreSQL already creates an index for every PRIMARY KEY and UNIQUE
 * constraint, so explicit indexes are only added for the non-key columns
 * used in lookups and expiry cleanup.
 */
const SCHEMA_STATEMENTS: readonly string[] = [
  `CREATE TABLE IF NOT EXISTS oauth_clients (
    id VARCHAR(255) PRIMARY KEY,
    secret VARCHAR(255) NOT NULL,
    name VARCHAR(255) NOT NULL,
    type VARCHAR(50) NOT NULL,
    redirect_uris TEXT[] NOT NULL,
    scopes TEXT[] NOT NULL,
    grant_types TEXT[] NOT NULL,
    created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP,
    updated_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP
  )`,
  `CREATE TABLE IF NOT EXISTS oauth_users (
    id VARCHAR(255) PRIMARY KEY,
    username VARCHAR(255) UNIQUE NOT NULL,
    email VARCHAR(255) UNIQUE,
    hashed_password VARCHAR(255),
    scopes TEXT[] NOT NULL,
    profile JSONB,
    created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP,
    updated_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP
  )`,
  `CREATE TABLE IF NOT EXISTS oauth_authorization_codes (
    code VARCHAR(255) PRIMARY KEY,
    client_id VARCHAR(255) NOT NULL,
    user_id VARCHAR(255) NOT NULL,
    redirect_uri VARCHAR(500) NOT NULL,
    scope VARCHAR(500) NOT NULL,
    resource VARCHAR(500),
    code_challenge VARCHAR(255),
    code_challenge_method VARCHAR(10),
    expires_at TIMESTAMPTZ NOT NULL,
    created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP
  )`,
  `CREATE TABLE IF NOT EXISTS oauth_access_tokens (
    access_token VARCHAR(255) PRIMARY KEY,
    client_id VARCHAR(255) NOT NULL,
    user_id VARCHAR(255) NOT NULL,
    scope VARCHAR(500) NOT NULL,
    expires_at TIMESTAMPTZ NOT NULL,
    created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP
  )`,
  `CREATE TABLE IF NOT EXISTS oauth_refresh_tokens (
    refresh_token VARCHAR(255) PRIMARY KEY,
    access_token_id VARCHAR(255) NOT NULL,
    client_id VARCHAR(255) NOT NULL,
    user_id VARCHAR(255) NOT NULL,
    scope VARCHAR(500) NOT NULL,
    expires_at TIMESTAMPTZ NOT NULL,
    created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP
  )`,
  // Used by deleteRefreshTokensByAccessToken().
  'CREATE INDEX IF NOT EXISTS idx_oauth_refresh_tokens_access_token_id ON oauth_refresh_tokens(access_token_id)',
  // Used by the cleanupExpired*() range deletes.
  'CREATE INDEX IF NOT EXISTS idx_oauth_authorization_codes_expires_at ON oauth_authorization_codes(expires_at)',
  'CREATE INDEX IF NOT EXISTS idx_oauth_access_tokens_expires_at ON oauth_access_tokens(expires_at)',
  'CREATE INDEX IF NOT EXISTS idx_oauth_refresh_tokens_expires_at ON oauth_refresh_tokens(expires_at)',
];

/** Table names that `PostgresStorage.updateRow` may target (never caller-supplied). */
type UpdatableTable = 'oauth_clients' | 'oauth_users';

// ----- Row mappers -----------------------------------------------------------
// Translate snake_case columns into the camelCase shapes the OAuth server
// uses. They rely on pg's default type parsers to have already decoded
// TEXT[] columns to arrays, JSONB to objects and timestamps to Date.

/** Maps an `oauth_clients` row to an `OAuthClient`. */
function toClient(row: QueryResultRow): OAuthClient {
  return {
    id: row.id,
    secret: row.secret,
    name: row.name,
    type: row.type,
    redirectUris: row.redirect_uris,
    scopes: row.scopes,
    grantTypes: row.grant_types,
    createdAt: row.created_at,
    updatedAt: row.updated_at,
  };
}

/** Maps an `oauth_users` row to an `OAuthUser`. */
function toUser(row: QueryResultRow): OAuthUser {
  return {
    id: row.id,
    username: row.username,
    email: row.email,
    hashedPassword: row.hashed_password,
    scopes: row.scopes,
    profile: row.profile,
    createdAt: row.created_at,
    updatedAt: row.updated_at,
  };
}

/** Maps an `oauth_authorization_codes` row to an `AuthorizationCode`. */
function toAuthorizationCode(row: QueryResultRow): AuthorizationCode {
  // NOTE(review): unset optional columns come back as `null`, not `undefined`;
  // confirm consumers of resource / codeChallenge* treat both the same.
  return {
    code: row.code,
    clientId: row.client_id,
    userId: row.user_id,
    redirectUri: row.redirect_uri,
    scope: row.scope,
    resource: row.resource,
    codeChallenge: row.code_challenge,
    codeChallengeMethod: row.code_challenge_method,
    expiresAt: row.expires_at,
    createdAt: row.created_at,
  };
}

/** Maps an `oauth_access_tokens` row to an `AccessToken`. */
function toAccessToken(row: QueryResultRow): AccessToken {
  return {
    token: row.access_token,
    clientId: row.client_id,
    userId: row.user_id,
    scope: row.scope,
    expiresAt: row.expires_at,
    createdAt: row.created_at,
  };
}

/** Maps an `oauth_refresh_tokens` row to a `RefreshToken`. */
function toRefreshToken(row: QueryResultRow): RefreshToken {
  return {
    token: row.refresh_token,
    accessTokenId: row.access_token_id,
    clientId: row.client_id,
    userId: row.user_id,
    scope: row.scope,
    expiresAt: row.expires_at,
    createdAt: row.created_at,
  };
}

/**
 * PostgreSQL-backed implementation of `MCPOAuthStorage`.
 *
 * Call `initialize()` once before use to create the schema, and `close()` on
 * shutdown to end the connection pool.
 *
 * Expiry checks compare `expires_at` against a `Date` bound from this process
 * rather than the database's `CURRENT_TIMESTAMP`. On tables created with plain
 * TIMESTAMP columns, PostgreSQL discards the UTC offset that `pg` attaches to
 * Date parameters, so `expires_at` holds this process's local wall-clock time.
 * Comparing that against CURRENT_TIMESTAMP, which is resolved through the
 * session TimeZone, would skew expiry by the difference between the two zones.
 * Binding the comparison value the same way the stored value was bound keeps
 * both sides consistent for TIMESTAMP and TIMESTAMPTZ columns alike.
 */
export class PostgresStorage implements MCPOAuthStorage {
  private readonly pool: Pool;

  /** Snapshot returned by getStats(); all zeros until refreshStats() runs. */
  private cachedStats: StorageStats = {
    clients: 0,
    users: 0,
    authorizationCodes: 0,
    accessTokens: 0,
    refreshTokens: 0,
  };

  /**
   * @param databaseUrl PostgreSQL connection string passed to `pg.Pool`.
   */
  constructor(databaseUrl: string) {
    this.pool = new Pool({
      connectionString: databaseUrl,
      // NOTE(review): `rejectUnauthorized: false` disables TLS certificate
      // verification in production, which permits man-in-the-middle attacks.
      // Kept for compatibility with hosts that use self-signed certificates;
      // prefer supplying the server CA instead.
      ssl: process.env.NODE_ENV === 'production' ? { rejectUnauthorized: false } : false,
    });
    // node-postgres emits 'error' on the pool when an idle client's backend
    // connection fails; with no listener, that event crashes the process.
    // The pool discards the broken client itself, so logging is sufficient.
    this.pool.on('error', (err) => {
      console.error('PostgresStorage: unexpected error on idle client', err);
    });
  }

  /** Creates the OAuth tables and indexes if they do not already exist. */
  async initialize(): Promise<void> {
    const client: PoolClient = await this.pool.connect();
    try {
      // Sequential on purpose: the statements share one client connection.
      for (const statement of SCHEMA_STATEMENTS) {
        await client.query(statement);
      }
    } finally {
      client.release();
    }
  }

  // ===== CLIENT MANAGEMENT =====

  /** Inserts a client, or overwrites every mutable column if the id already exists. */
  async createClient(client: OAuthClient): Promise<void> {
    const query = `
      INSERT INTO oauth_clients (id, secret, name, type, redirect_uris, scopes, grant_types)
      VALUES ($1, $2, $3, $4, $5, $6, $7)
      ON CONFLICT (id) DO UPDATE SET
        secret = EXCLUDED.secret,
        name = EXCLUDED.name,
        type = EXCLUDED.type,
        redirect_uris = EXCLUDED.redirect_uris,
        scopes = EXCLUDED.scopes,
        grant_types = EXCLUDED.grant_types,
        updated_at = CURRENT_TIMESTAMP
    `;
    await this.pool.query(query, [
      client.id,
      client.secret,
      client.name,
      client.type,
      client.redirectUris,
      client.scopes,
      client.grantTypes,
    ]);
  }

  /** @returns the client with the given id, or `null` if none exists. */
  async getClient(clientId: string): Promise<OAuthClient | null> {
    const result = await this.pool.query('SELECT * FROM oauth_clients WHERE id = $1', [clientId]);
    const [row] = result.rows;
    return row ? toClient(row) : null;
  }

  /** @returns all clients, newest first. */
  async listClients(): Promise<OAuthClient[]> {
    const result = await this.pool.query('SELECT * FROM oauth_clients ORDER BY created_at DESC');
    return result.rows.map(toClient);
  }

  /** Updates only the fields present (not `undefined`) in `updates`, and always bumps `updated_at`. */
  async updateClient(clientId: string, updates: Partial<OAuthClient>): Promise<void> {
    const assignments: Array<[string, unknown]> = [];
    if (updates.secret !== undefined) assignments.push(['secret', updates.secret]);
    if (updates.name !== undefined) assignments.push(['name', updates.name]);
    if (updates.type !== undefined) assignments.push(['type', updates.type]);
    if (updates.redirectUris !== undefined) assignments.push(['redirect_uris', updates.redirectUris]);
    if (updates.scopes !== undefined) assignments.push(['scopes', updates.scopes]);
    if (updates.grantTypes !== undefined) assignments.push(['grant_types', updates.grantTypes]);
    await this.updateRow('oauth_clients', clientId, assignments);
  }

  async deleteClient(clientId: string): Promise<void> {
    await this.pool.query('DELETE FROM oauth_clients WHERE id = $1', [clientId]);
  }

  // ===== USER MANAGEMENT =====

  /** Inserts a user, or overwrites every mutable column if the id already exists. */
  async createUser(user: OAuthUser): Promise<void> {
    const query = `
      INSERT INTO oauth_users (id, username, email, hashed_password, scopes, profile)
      VALUES ($1, $2, $3, $4, $5, $6)
      ON CONFLICT (id) DO UPDATE SET
        username = EXCLUDED.username,
        email = EXCLUDED.email,
        hashed_password = EXCLUDED.hashed_password,
        scopes = EXCLUDED.scopes,
        profile = EXCLUDED.profile,
        updated_at = CURRENT_TIMESTAMP
    `;
    await this.pool.query(query, [
      user.id,
      user.username,
      user.email,
      user.hashedPassword,
      user.scopes,
      user.profile ? JSON.stringify(user.profile) : null,
    ]);
  }

  /** @returns the user with the given id, or `null` if none exists. */
  async getUser(userId: string): Promise<OAuthUser | null> {
    const result = await this.pool.query('SELECT * FROM oauth_users WHERE id = $1', [userId]);
    const [row] = result.rows;
    return row ? toUser(row) : null;
  }

  /** @returns the user with the given username, or `null` if none exists. */
  async getUserByUsername(username: string): Promise<OAuthUser | null> {
    const result = await this.pool.query('SELECT * FROM oauth_users WHERE username = $1', [username]);
    const [row] = result.rows;
    return row ? toUser(row) : null;
  }

  /** @returns all users, newest first. */
  async listUsers(): Promise<OAuthUser[]> {
    const result = await this.pool.query('SELECT * FROM oauth_users ORDER BY created_at DESC');
    return result.rows.map(toUser);
  }

  /** Updates only the fields present (not `undefined`) in `updates`, and always bumps `updated_at`. */
  async updateUser(userId: string, updates: Partial<OAuthUser>): Promise<void> {
    const assignments: Array<[string, unknown]> = [];
    if (updates.username !== undefined) assignments.push(['username', updates.username]);
    if (updates.email !== undefined) assignments.push(['email', updates.email]);
    if (updates.hashedPassword !== undefined) assignments.push(['hashed_password', updates.hashedPassword]);
    if (updates.scopes !== undefined) assignments.push(['scopes', updates.scopes]);
    if (updates.profile !== undefined) {
      assignments.push(['profile', updates.profile ? JSON.stringify(updates.profile) : null]);
    }
    await this.updateRow('oauth_users', userId, assignments);
  }

  async deleteUser(userId: string): Promise<void> {
    await this.pool.query('DELETE FROM oauth_users WHERE id = $1', [userId]);
  }

  // ===== AUTHORIZATION CODES =====

  async createAuthorizationCode(code: AuthorizationCode): Promise<void> {
    const query = `
      INSERT INTO oauth_authorization_codes
        (code, client_id, user_id, redirect_uri, scope, resource, code_challenge, code_challenge_method, expires_at)
      VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
    `;
    await this.pool.query(query, [
      code.code,
      code.clientId,
      code.userId,
      code.redirectUri,
      code.scope,
      code.resource,
      code.codeChallenge,
      code.codeChallengeMethod,
      code.expiresAt,
    ]);
  }

  /** @returns the code if it exists and has not expired, otherwise `null`. */
  async getAuthorizationCode(code: string): Promise<AuthorizationCode | null> {
    const result = await this.pool.query(
      'SELECT * FROM oauth_authorization_codes WHERE code = $1 AND expires_at > $2',
      [code, new Date()],
    );
    const [row] = result.rows;
    return row ? toAuthorizationCode(row) : null;
  }

  async deleteAuthorizationCode(code: string): Promise<void> {
    await this.pool.query('DELETE FROM oauth_authorization_codes WHERE code = $1', [code]);
  }

  async cleanupExpiredCodes(): Promise<void> {
    await this.pool.query('DELETE FROM oauth_authorization_codes WHERE expires_at <= $1', [new Date()]);
  }

  // ===== ACCESS TOKENS =====

  async createAccessToken(token: AccessToken): Promise<void> {
    const query = `
      INSERT INTO oauth_access_tokens (access_token, client_id, user_id, scope, expires_at)
      VALUES ($1, $2, $3, $4, $5)
    `;
    await this.pool.query(query, [token.token, token.clientId, token.userId, token.scope, token.expiresAt]);
  }

  /** @returns the token if it exists and has not expired, otherwise `null`. */
  async getAccessToken(token: string): Promise<AccessToken | null> {
    const result = await this.pool.query(
      'SELECT * FROM oauth_access_tokens WHERE access_token = $1 AND expires_at > $2',
      [token, new Date()],
    );
    const [row] = result.rows;
    return row ? toAccessToken(row) : null;
  }

  async deleteAccessToken(token: string): Promise<void> {
    await this.pool.query('DELETE FROM oauth_access_tokens WHERE access_token = $1', [token]);
  }

  async cleanupExpiredTokens(): Promise<void> {
    await this.pool.query('DELETE FROM oauth_access_tokens WHERE expires_at <= $1', [new Date()]);
  }

  // ===== REFRESH TOKENS =====

  async createRefreshToken(token: RefreshToken): Promise<void> {
    const query = `
      INSERT INTO oauth_refresh_tokens (refresh_token, access_token_id, client_id, user_id, scope, expires_at)
      VALUES ($1, $2, $3, $4, $5, $6)
    `;
    await this.pool.query(query, [
      token.token,
      token.accessTokenId,
      token.clientId,
      token.userId,
      token.scope,
      token.expiresAt,
    ]);
  }

  /** @returns the refresh token if it exists and has not expired, otherwise `null`. */
  async getRefreshToken(token: string): Promise<RefreshToken | null> {
    const result = await this.pool.query(
      'SELECT * FROM oauth_refresh_tokens WHERE refresh_token = $1 AND expires_at > $2',
      [token, new Date()],
    );
    const [row] = result.rows;
    return row ? toRefreshToken(row) : null;
  }

  async deleteRefreshToken(token: string): Promise<void> {
    await this.pool.query('DELETE FROM oauth_refresh_tokens WHERE refresh_token = $1', [token]);
  }

  async deleteRefreshTokensByAccessToken(accessTokenId: string): Promise<void> {
    await this.pool.query('DELETE FROM oauth_refresh_tokens WHERE access_token_id = $1', [accessTokenId]);
  }

  async cleanupExpiredRefreshTokens(): Promise<void> {
    await this.pool.query('DELETE FROM oauth_refresh_tokens WHERE expires_at <= $1', [new Date()]);
  }

  // ===== UTILITY METHODS =====

  /**
   * Returns the row counts captured by the most recent `refreshStats()` call.
   * This method is synchronous, so it cannot query the database itself. It
   * reports all zeros until `refreshStats()` has completed at least once.
   */
  getStats(): StorageStats {
    return { ...this.cachedStats };
  }

  /**
   * Counts all rows (expired or not) in each OAuth table, caches the result
   * for `getStats()`, and returns it.
   */
  async refreshStats(): Promise<StorageStats> {
    const result = await this.pool.query(`
      SELECT
        (SELECT COUNT(*) FROM oauth_clients) AS clients,
        (SELECT COUNT(*) FROM oauth_users) AS users,
        (SELECT COUNT(*) FROM oauth_authorization_codes) AS authorization_codes,
        (SELECT COUNT(*) FROM oauth_access_tokens) AS access_tokens,
        (SELECT COUNT(*) FROM oauth_refresh_tokens) AS refresh_tokens
    `);
    const [row] = result.rows;
    // COUNT(*) is BIGINT, which pg returns as a string by default.
    this.cachedStats = {
      clients: Number(row.clients),
      users: Number(row.users),
      authorizationCodes: Number(row.authorization_codes),
      accessTokens: Number(row.access_tokens),
      refreshTokens: Number(row.refresh_tokens),
    };
    return this.getStats();
  }

  /** Ends every pooled connection; the instance is unusable afterwards. */
  async close(): Promise<void> {
    await this.pool.end();
  }

  /**
   * Applies `column = value` assignments plus `updated_at = CURRENT_TIMESTAMP`
   * to the row with the given `id`.
   *
   * Table and column names come only from literals in this class, never from
   * callers, so interpolating them is safe; every value is bound as a parameter.
   */
  private async updateRow(
    table: UpdatableTable,
    id: string,
    assignments: ReadonlyArray<readonly [string, unknown]>,
  ): Promise<void> {
    const setClauses = assignments.map(([column], i) => `${column} = $${i + 1}`);
    setClauses.push('updated_at = CURRENT_TIMESTAMP');
    // The id is bound last, so its placeholder index equals the value count.
    const values = [...assignments.map(([, value]) => value), id];
    await this.pool.query(`UPDATE ${table} SET ${setClauses.join(', ')} WHERE id = $${values.length}`, values);
  }
}