diff --git a/packages/jsonrpc/src/ensure.ts b/packages/jsonrpc/src/ensure.ts new file mode 100644 index 0000000..4487637 --- /dev/null +++ b/packages/jsonrpc/src/ensure.ts @@ -0,0 +1,31 @@ +import 'reflect-metadata' + +export const ensureKey = Symbol('ensure') + +export type Validate = (context: Context) => boolean + +export function ensure( + validate: Validate, + message: string = 'Validation failed', +) { + return ( + target: any, + propertyKey: string, + // tslint:disable-next-line + descriptor: PropertyDescriptor, + ) => { + const validators: Array> = + getValidatorsForMethod(target, propertyKey) + + validators.push(validate) + + Reflect.defineMetadata(ensureKey, validators, target, propertyKey) + } +} + +export function getValidatorsForMethod( + target: any, + method: string, +): Array> { + return Reflect.getOwnMetadata(ensureKey, target, method) || [] +} diff --git a/packages/jsonrpc/src/express.test.ts b/packages/jsonrpc/src/express.test.ts index 2e09a38..951a647 100644 --- a/packages/jsonrpc/src/express.test.ts +++ b/packages/jsonrpc/src/express.test.ts @@ -4,6 +4,7 @@ import request from 'supertest' import {createClient} from './supertest' import {jsonrpc} from './express' import {noopLogger} from './test-utils' +import {ensure} from './ensure' describe('jsonrpc', () => { @@ -22,6 +23,8 @@ describe('jsonrpc', () => { addWithContext2(a: number, b: number): Promise } + const ensureLoggedIn = ensure(c => !!c.userId) + class Service implements IService { constructor(readonly time: number) {} add(a: number, b: number) { @@ -52,16 +55,20 @@ describe('jsonrpc', () => { addWithContext = (a: number, b: number) => (ctx: IContext): number => { return a + b + ctx.userId } + + @ensureLoggedIn addWithContext2(a: number, b: number, ctx?: IContext) { return Promise.resolve(a + b + ctx!.userId) } } + let userId: number | undefined = 1000 function createApp() { + userId = 1000 const app = express() app.use(bodyParser.json()) app.use('/', - jsonrpc(req => ({userId: 1000}), noopLogger) + jsonrpc(req => ({userId}), noopLogger) .addService('/myService', new Service(5), [ 'add', 'delay', @@ -78,7 +85,7 @@ describe('jsonrpc', () => { const client = createClient(createApp(), '/myService') - async function getError(promise: Promise) { + async function getError(promise: Promise) { let error try { await promise @@ -153,6 +160,11 @@ describe('jsonrpc', () => { const response = await client.addWithContext2(5, 7) expect(response).toEqual(1000 + 5 + 7) }) + it('can validate context using @ensure decorator', async () => { + userId = undefined + const err = await getError(client.addWithContext2(5, 7)) + expect(err.message).toMatch(/Invalid request/) + }) it('handles synchronous notifications', async () => { await request(createApp()) .post('/myService') diff --git a/packages/jsonrpc/src/express.ts b/packages/jsonrpc/src/express.ts index 87c3e52..0204d89 100644 --- a/packages/jsonrpc/src/express.ts +++ b/packages/jsonrpc/src/express.ts @@ -2,7 +2,7 @@ import express, {ErrorRequestHandler} from 'express' import {FunctionPropertyNames} from './types' import {IDEMPOTENT_METHOD_REGEX} from './idempotent' import {IErrorResponse} from './error' -import {ILogger} from '@rondo.dev/common' +import {ILogger} from '@rondo.dev/logger' import {ISuccessResponse} from './jsonrpc' import {NextFunction, Request, Response, Router} from 'express' import {createError, isJSONRPCError, IJSONRPCError, IError} from './error' diff --git a/packages/jsonrpc/src/jsonrpc.ts b/packages/jsonrpc/src/jsonrpc.ts index 7f08527..1b99a52 100644 --- a/packages/jsonrpc/src/jsonrpc.ts +++ b/packages/jsonrpc/src/jsonrpc.ts @@ -2,6 +2,7 @@ export type TId = number | string import {ArgumentTypes, FunctionPropertyNames, RetType} from './types' import {isPromise} from './isPromise' import {createError, IErrorResponse, IErrorWithData} from './error' +import {getValidatorsForMethod} from './ensure' export const ERROR_PARSE = { code: -32700, @@ -73,6 +74,7 @@ export const createRpcService = >( service: T, methods: M[], ) => { + const rpcService = pick(service, methods) return { async invoke( req: IRequest>, @@ -94,8 +96,6 @@ export const createRpcService = >( const isNotification = req.id === null || req.id === undefined - const rpcService = pick(service, methods) - if ( !rpcService.hasOwnProperty(method) || typeof rpcService[method] !== 'function' @@ -107,6 +107,20 @@ export const createRpcService = >( }) } + const validators = getValidatorsForMethod( + (service as any).__proto__, method) + + validators.forEach(v => { + const success = v(context) + if (!success) { + throw createError(ERROR_INVALID_REQUEST, { + id, + data: null, + statusCode: 400, + }) + } + }) + let retValue = (rpcService[method] as any)(...params, context) if (typeof retValue === 'function') { diff --git a/packages/jsonrpc/tsconfig.json b/packages/jsonrpc/tsconfig.json index 94e864b..301eb59 100644 --- a/packages/jsonrpc/tsconfig.json +++ b/packages/jsonrpc/tsconfig.json @@ -1,9 +1,11 @@ { "extends": "../tsconfig.common.json", "compilerOptions": { + "target": "es5", "outDir": "lib", "rootDir": "src" }, "references": [ + {"path": "../logger"} ] }