diff --git a/packages/server/package.json b/packages/server/package.json index 4d50ddc4e..9b9b7dbd0 100644 --- a/packages/server/package.json +++ b/packages/server/package.json @@ -52,6 +52,7 @@ "dotenv": "^16.0.0", "express": "^4.17.3", "express-basic-auth": "^1.2.1", + "express-rate-limit": "^6.9.0", "flowise-components": "*", "flowise-ui": "*", "moment-timezone": "^0.5.34", diff --git a/packages/server/src/index.ts b/packages/server/src/index.ts index 66c7b0004..15709ad94 100644 --- a/packages/server/src/index.ts +++ b/packages/server/src/index.ts @@ -1,4 +1,4 @@ -import express, { Request, Response } from 'express' +import express, { NextFunction, Request, Response } from 'express' import multer from 'multer' import path from 'path' import cors from 'cors' @@ -54,6 +54,7 @@ import { Credential } from './entity/Credential' import { Tool } from './entity/Tool' import { ChatflowPool } from './ChatflowPool' import { ICommonObject, INodeOptionsValue } from 'flowise-components' +import { createRateLimiter, getRateLimiter } from './utils/rateLimit' export class App { app: express.Application @@ -654,6 +655,21 @@ export class App { // Prediction // ---------------------------------------- + this.app.get( + '/api/v1/rate-limit/:id', + upload.array('files'), + (req: Request, res: Response, next: NextFunction) => getRateLimiter(req, res, next), + // specificRouteLimiter, + async (req: Request, res: Response) => { + res.send("you're fine") + } + ) + + this.app.post('/api/v1/rate-limit/', async (req: Request, res: Response) => { + createRateLimiter(req) + res.send('Created/Updated rate limit') + }) + // Send input message and get prediction result (External) this.app.post('/api/v1/prediction/:id', upload.array('files'), async (req: Request, res: Response) => { await this.processPrediction(req, res, socketIO) diff --git a/packages/server/src/utils/rateLimit.ts b/packages/server/src/utils/rateLimit.ts new file mode 100644 index 000000000..0bd5be98c --- /dev/null +++ b/packages/server/src/utils/rateLimit.ts @@ -0,0 +1,56 @@ +import { NextFunction, Request, Response } from 'express' +import { rateLimit, RateLimitRequestHandler } from 'express-rate-limit' + +interface RateLimit { + id: string + rateLimitObj: RateLimitRequestHandler +} + +export const specificRouteLimiter: RateLimitRequestHandler = rateLimit({ + windowMs: 1 * 60 * 1000, // 15 minutes + max: 1, // Limit each IP to 100 requests per windowMs + message: 'Too many requests, please try again later.' +}) + +let rateLimiters: RateLimit[] = [] + +export function createRateLimiter(req: Request) { + const id = req.body.id + const duration = req.body.duration + const limit = req.body.limit + const message = req.body.message + + const rateLimitObj: RateLimitRequestHandler = rateLimit({ + windowMs: Number(duration), + max: limit, + handler: (req, res) => { + res.status(429).json({ error: message }) + } + }) + + const existingIndex: number = rateLimiters.findIndex((rateLimit) => rateLimit.id === id) + + if (existingIndex === -1) { + rateLimiters.push({ + id, + rateLimitObj + }) + } else { + rateLimiters[existingIndex] = { + id, + rateLimitObj + } + } +} + +export function getRateLimiter(req: Request, res: Response, next: NextFunction) { + const id = req.params.id + + const ratelimiter = rateLimiters.find((rateLimit) => rateLimit.id === id) + + if (!ratelimiter) return next() + + const idRateLimiter = ratelimiter.rateLimitObj + + return idRateLimiter(req, res, next) +}