Enforce restrictions based on organization.status (#4652)
* feat: does not allow change of organziation.status unless from stripe * feat: restrict apikey when organization.status is not active
This commit is contained in:
parent
407c8bb1a8
commit
4a2ea0a425
|
|
@ -1,12 +1,13 @@
|
|||
import { Request, Response, NextFunction } from 'express'
|
||||
import { NextFunction, Request, Response } from 'express'
|
||||
import { StatusCodes } from 'http-status-codes'
|
||||
import { OrganizationErrorMessage, OrganizationService } from '../services/organization.service'
|
||||
import { getRunningExpressApp } from '../../utils/getRunningExpressApp'
|
||||
import { QueryRunner } from 'typeorm'
|
||||
import { InternalFlowiseError } from '../../errors/internalFlowiseError'
|
||||
import { Organization } from '../database/entities/organization.entity'
|
||||
import { GeneralErrorMessage } from '../../utils/constants'
|
||||
import { OrganizationUserService } from '../services/organization-user.service'
|
||||
import { getRunningExpressApp } from '../../utils/getRunningExpressApp'
|
||||
import { getCurrentUsage } from '../../utils/quotaUsage'
|
||||
import { Organization } from '../database/entities/organization.entity'
|
||||
import { OrganizationUserService } from '../services/organization-user.service'
|
||||
import { OrganizationErrorMessage, OrganizationService } from '../services/organization.service'
|
||||
|
||||
export class OrganizationController {
|
||||
public async create(req: Request, res: Response, next: NextFunction) {
|
||||
|
|
@ -47,12 +48,18 @@ export class OrganizationController {
|
|||
}
|
||||
|
||||
public async update(req: Request, res: Response, next: NextFunction) {
|
||||
let queryRunner: QueryRunner | undefined
|
||||
try {
|
||||
queryRunner = getRunningExpressApp().AppDataSource.createQueryRunner()
|
||||
await queryRunner.connect()
|
||||
const organizationService = new OrganizationService()
|
||||
const organization = await organizationService.updateOrganization(req.body)
|
||||
const organization = await organizationService.updateOrganization(req.body, queryRunner)
|
||||
return res.status(StatusCodes.OK).json(organization)
|
||||
} catch (error) {
|
||||
if (queryRunner && queryRunner.isTransactionActive) await queryRunner.rollbackTransaction()
|
||||
next(error)
|
||||
} finally {
|
||||
if (queryRunner && !queryRunner.isReleased) await queryRunner.release()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -4,12 +4,13 @@ import { InternalFlowiseError } from '../../errors/internalFlowiseError'
|
|||
import { generateId } from '../../utils'
|
||||
import { getRunningExpressApp } from '../../utils/getRunningExpressApp'
|
||||
import { Telemetry } from '../../utils/telemetry'
|
||||
import { Organization, OrganizationName } from '../database/entities/organization.entity'
|
||||
import { Organization, OrganizationName, OrganizationStatus } from '../database/entities/organization.entity'
|
||||
import { isInvalidName, isInvalidUUID } from '../utils/validation.util'
|
||||
import { UserErrorMessage, UserService } from './user.service'
|
||||
|
||||
export const enum OrganizationErrorMessage {
|
||||
INVALID_ORGANIZATION_ID = 'Invalid Organization Id',
|
||||
INVALID_ORGANIZATION_STATUS = 'Invalid Organization Status',
|
||||
INVALID_ORGANIZATION_NAME = 'Invalid Organization Name',
|
||||
ORGANIZATION_NOT_FOUND = 'Organization Not Found',
|
||||
ORGANIZATION_FOUND_MULTIPLE = 'Organization Found Multiple',
|
||||
|
|
@ -32,6 +33,12 @@ export class OrganizationService {
|
|||
if (isInvalidUUID(id)) throw new InternalFlowiseError(StatusCodes.BAD_REQUEST, OrganizationErrorMessage.INVALID_ORGANIZATION_ID)
|
||||
}
|
||||
|
||||
public validateOrganizationStatus(status: string | undefined) {
|
||||
if (status && !Object.values(OrganizationStatus).includes(status as OrganizationStatus)) {
|
||||
throw new InternalFlowiseError(StatusCodes.BAD_REQUEST, OrganizationErrorMessage.INVALID_ORGANIZATION_STATUS)
|
||||
}
|
||||
}
|
||||
|
||||
public async readOrganizationById(id: string | undefined, queryRunner: QueryRunner) {
|
||||
this.validateOrganizationId(id)
|
||||
return await queryRunner.manager.findOneBy(Organization, { id })
|
||||
|
|
@ -59,6 +66,8 @@ export class OrganizationService {
|
|||
|
||||
public createNewOrganization(data: Partial<Organization>, queryRunner: QueryRunner, isRegister: boolean = false) {
|
||||
this.validateOrganizationName(data.name, isRegister)
|
||||
// REMARK: status is not allowed to be set when creating a new organization
|
||||
if (data.status) delete data.status
|
||||
data.updatedBy = data.createdBy
|
||||
data.id = generateId()
|
||||
|
||||
|
|
@ -91,30 +100,20 @@ export class OrganizationService {
|
|||
return newOrganization
|
||||
}
|
||||
|
||||
public async updateOrganization(newOrganizationData: Partial<Organization>) {
|
||||
const queryRunner = this.dataSource.createQueryRunner()
|
||||
await queryRunner.connect()
|
||||
|
||||
public async updateOrganization(newOrganizationData: Partial<Organization>, queryRunner: QueryRunner, fromStripe: boolean = false) {
|
||||
const oldOrganizationData = await this.readOrganizationById(newOrganizationData.id, queryRunner)
|
||||
if (!oldOrganizationData) throw new InternalFlowiseError(StatusCodes.NOT_FOUND, OrganizationErrorMessage.ORGANIZATION_NOT_FOUND)
|
||||
const user = await this.userService.readUserById(newOrganizationData.updatedBy, queryRunner)
|
||||
if (!user) throw new InternalFlowiseError(StatusCodes.NOT_FOUND, UserErrorMessage.USER_NOT_FOUND)
|
||||
if (newOrganizationData.name) {
|
||||
this.validateOrganizationName(newOrganizationData.name)
|
||||
}
|
||||
if (newOrganizationData.name) this.validateOrganizationName(newOrganizationData.name)
|
||||
// TODO: allow flowise's employees to modify organization status
|
||||
// REMARK: status is only allowed to be set when updating an organization from stripe
|
||||
if (fromStripe === true && newOrganizationData.status) this.validateOrganizationStatus(newOrganizationData.status)
|
||||
else if (newOrganizationData.status) delete newOrganizationData.status
|
||||
newOrganizationData.createdBy = oldOrganizationData.createdBy
|
||||
|
||||
let updateOrganization = queryRunner.manager.merge(Organization, oldOrganizationData, newOrganizationData)
|
||||
try {
|
||||
await queryRunner.startTransaction()
|
||||
await this.saveOrganization(updateOrganization, queryRunner)
|
||||
await queryRunner.commitTransaction()
|
||||
} catch (error) {
|
||||
await queryRunner.rollbackTransaction()
|
||||
throw error
|
||||
} finally {
|
||||
await queryRunner.release()
|
||||
}
|
||||
await this.saveOrganization(updateOrganization, queryRunner)
|
||||
|
||||
return updateOrganization
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,39 +1,40 @@
|
|||
import express, { Request, Response } from 'express'
|
||||
import path from 'path'
|
||||
import cors from 'cors'
|
||||
import http from 'http'
|
||||
import cookieParser from 'cookie-parser'
|
||||
import cors from 'cors'
|
||||
import express, { Request, Response } from 'express'
|
||||
import 'global-agent/bootstrap'
|
||||
import http from 'http'
|
||||
import path from 'path'
|
||||
import { DataSource, IsNull } from 'typeorm'
|
||||
import { MODE, Platform } from './Interface'
|
||||
import { getNodeModulesPackagePath, getEncryptionKey } from './utils'
|
||||
import logger, { expressRequestLogger } from './utils/logger'
|
||||
import { getDataSource } from './DataSource'
|
||||
import { NodesPool } from './NodesPool'
|
||||
import { ChatFlow } from './database/entities/ChatFlow'
|
||||
import { CachePool } from './CachePool'
|
||||
import { AbortControllerPool } from './AbortControllerPool'
|
||||
import { RateLimiterManager } from './utils/rateLimit'
|
||||
import { getAllowedIframeOrigins, getCorsOptions, sanitizeMiddleware } from './utils/XSS'
|
||||
import { Telemetry } from './utils/telemetry'
|
||||
import flowiseApiV1Router from './routes'
|
||||
import errorHandlerMiddleware from './middlewares/errors'
|
||||
import { WHITELIST_URLS } from './utils/constants'
|
||||
import { initializeJwtCookieMiddleware, verifyToken } from './enterprise/middleware/passport'
|
||||
import { IdentityManager } from './IdentityManager'
|
||||
import { SSEStreamer } from './utils/SSEStreamer'
|
||||
import { getAPIKeyWorkspaceID, validateAPIKey } from './utils/validateKey'
|
||||
import { CachePool } from './CachePool'
|
||||
import { ChatFlow } from './database/entities/ChatFlow'
|
||||
import { getDataSource } from './DataSource'
|
||||
import { Organization, OrganizationStatus } from './enterprise/database/entities/organization.entity'
|
||||
import { GeneralRole, Role } from './enterprise/database/entities/role.entity'
|
||||
import { Workspace } from './enterprise/database/entities/workspace.entity'
|
||||
import { LoggedInUser } from './enterprise/Interface.Enterprise'
|
||||
import { initializeJwtCookieMiddleware, verifyToken } from './enterprise/middleware/passport'
|
||||
import { handleStripeWebhook } from './enterprise/webhooks/stripe'
|
||||
import { IdentityManager } from './IdentityManager'
|
||||
import { MODE, Platform } from './Interface'
|
||||
import { IMetricsProvider } from './Interface.Metrics'
|
||||
import { Prometheus } from './metrics/Prometheus'
|
||||
import { OpenTelemetry } from './metrics/OpenTelemetry'
|
||||
import { Prometheus } from './metrics/Prometheus'
|
||||
import errorHandlerMiddleware from './middlewares/errors'
|
||||
import { NodesPool } from './NodesPool'
|
||||
import { QueueManager } from './queue/QueueManager'
|
||||
import { RedisEventSubscriber } from './queue/RedisEventSubscriber'
|
||||
import 'global-agent/bootstrap'
|
||||
import flowiseApiV1Router from './routes'
|
||||
import { UsageCacheManager } from './UsageCacheManager'
|
||||
import { Workspace } from './enterprise/database/entities/workspace.entity'
|
||||
import { Organization } from './enterprise/database/entities/organization.entity'
|
||||
import { GeneralRole, Role } from './enterprise/database/entities/role.entity'
|
||||
import { getEncryptionKey, getNodeModulesPackagePath } from './utils'
|
||||
import { migrateApiKeysFromJsonToDb } from './utils/apiKey'
|
||||
import { WHITELIST_URLS } from './utils/constants'
|
||||
import logger, { expressRequestLogger } from './utils/logger'
|
||||
import { RateLimiterManager } from './utils/rateLimit'
|
||||
import { SSEStreamer } from './utils/SSEStreamer'
|
||||
import { Telemetry } from './utils/telemetry'
|
||||
import { getAPIKeyWorkspaceID, validateAPIKey } from './utils/validateKey'
|
||||
import { getAllowedIframeOrigins, getCorsOptions, sanitizeMiddleware } from './utils/XSS'
|
||||
import { StripeWebhooks } from './enterprise/webhooks/stripe'
|
||||
|
||||
declare global {
|
||||
|
|
@ -252,6 +253,10 @@ export class App {
|
|||
if (!org) {
|
||||
return res.status(401).json({ error: 'Unauthorized Access' })
|
||||
}
|
||||
if (org.status == OrganizationStatus.PAST_DUE)
|
||||
return res.status(402).json({ error: 'Access denied. Your organization has past due payments.' })
|
||||
if (org.status == OrganizationStatus.UNDER_REVIEW)
|
||||
return res.status(403).json({ error: 'Access denied. Your organization is under review.' })
|
||||
const subscriptionId = org.subscriptionId as string
|
||||
const customerId = org.customerId as string
|
||||
const features = await this.identityManager.getFeaturesByPlan(subscriptionId)
|
||||
|
|
|
|||
Loading…
Reference in New Issue