diff --git a/packages/server/src/enterprise/controllers/organization.controller.ts b/packages/server/src/enterprise/controllers/organization.controller.ts index b7ca0a6d7..fc57184dd 100644 --- a/packages/server/src/enterprise/controllers/organization.controller.ts +++ b/packages/server/src/enterprise/controllers/organization.controller.ts @@ -1,12 +1,13 @@ -import { Request, Response, NextFunction } from 'express' +import { NextFunction, Request, Response } from 'express' import { StatusCodes } from 'http-status-codes' -import { OrganizationErrorMessage, OrganizationService } from '../services/organization.service' -import { getRunningExpressApp } from '../../utils/getRunningExpressApp' +import { QueryRunner } from 'typeorm' import { InternalFlowiseError } from '../../errors/internalFlowiseError' -import { Organization } from '../database/entities/organization.entity' import { GeneralErrorMessage } from '../../utils/constants' -import { OrganizationUserService } from '../services/organization-user.service' +import { getRunningExpressApp } from '../../utils/getRunningExpressApp' import { getCurrentUsage } from '../../utils/quotaUsage' +import { Organization } from '../database/entities/organization.entity' +import { OrganizationUserService } from '../services/organization-user.service' +import { OrganizationErrorMessage, OrganizationService } from '../services/organization.service' export class OrganizationController { public async create(req: Request, res: Response, next: NextFunction) { @@ -47,12 +48,18 @@ export class OrganizationController { } public async update(req: Request, res: Response, next: NextFunction) { + let queryRunner: QueryRunner | undefined try { + queryRunner = getRunningExpressApp().AppDataSource.createQueryRunner() + await queryRunner.connect() const organizationService = new OrganizationService() - const organization = await organizationService.updateOrganization(req.body) + const organization = await organizationService.updateOrganization(req.body, queryRunner) return res.status(StatusCodes.OK).json(organization) } catch (error) { + if (queryRunner && queryRunner.isTransactionActive) await queryRunner.rollbackTransaction() next(error) + } finally { + if (queryRunner && !queryRunner.isReleased) await queryRunner.release() } } diff --git a/packages/server/src/enterprise/services/organization.service.ts b/packages/server/src/enterprise/services/organization.service.ts index 9ee115467..85380d6d1 100644 --- a/packages/server/src/enterprise/services/organization.service.ts +++ b/packages/server/src/enterprise/services/organization.service.ts @@ -4,12 +4,13 @@ import { InternalFlowiseError } from '../../errors/internalFlowiseError' import { generateId } from '../../utils' import { getRunningExpressApp } from '../../utils/getRunningExpressApp' import { Telemetry } from '../../utils/telemetry' -import { Organization, OrganizationName } from '../database/entities/organization.entity' +import { Organization, OrganizationName, OrganizationStatus } from '../database/entities/organization.entity' import { isInvalidName, isInvalidUUID } from '../utils/validation.util' import { UserErrorMessage, UserService } from './user.service' export const enum OrganizationErrorMessage { INVALID_ORGANIZATION_ID = 'Invalid Organization Id', + INVALID_ORGANIZATION_STATUS = 'Invalid Organization Status', INVALID_ORGANIZATION_NAME = 'Invalid Organization Name', ORGANIZATION_NOT_FOUND = 'Organization Not Found', ORGANIZATION_FOUND_MULTIPLE = 'Organization Found Multiple', @@ -32,6 +33,12 @@ export class OrganizationService { if (isInvalidUUID(id)) throw new InternalFlowiseError(StatusCodes.BAD_REQUEST, OrganizationErrorMessage.INVALID_ORGANIZATION_ID) } + public validateOrganizationStatus(status: string | undefined) { + if (status && !Object.values(OrganizationStatus).includes(status as OrganizationStatus)) { + throw new InternalFlowiseError(StatusCodes.BAD_REQUEST, OrganizationErrorMessage.INVALID_ORGANIZATION_STATUS) + } + } + public async readOrganizationById(id: string | undefined, queryRunner: QueryRunner) { this.validateOrganizationId(id) return await queryRunner.manager.findOneBy(Organization, { id }) @@ -59,6 +66,8 @@ export class OrganizationService { public createNewOrganization(data: Partial, queryRunner: QueryRunner, isRegister: boolean = false) { this.validateOrganizationName(data.name, isRegister) + // REMARK: status is not allowed to be set when creating a new organization + if (data.status) delete data.status data.updatedBy = data.createdBy data.id = generateId() @@ -91,30 +100,20 @@ export class OrganizationService { return newOrganization } - public async updateOrganization(newOrganizationData: Partial) { - const queryRunner = this.dataSource.createQueryRunner() - await queryRunner.connect() - + public async updateOrganization(newOrganizationData: Partial, queryRunner: QueryRunner, fromStripe: boolean = false) { const oldOrganizationData = await this.readOrganizationById(newOrganizationData.id, queryRunner) if (!oldOrganizationData) throw new InternalFlowiseError(StatusCodes.NOT_FOUND, OrganizationErrorMessage.ORGANIZATION_NOT_FOUND) const user = await this.userService.readUserById(newOrganizationData.updatedBy, queryRunner) if (!user) throw new InternalFlowiseError(StatusCodes.NOT_FOUND, UserErrorMessage.USER_NOT_FOUND) - if (newOrganizationData.name) { - this.validateOrganizationName(newOrganizationData.name) - } + if (newOrganizationData.name) this.validateOrganizationName(newOrganizationData.name) + // TODO: allow flowise's employees to modify organization status + // REMARK: status is only allowed to be set when updating an organization from stripe + if (fromStripe === true && newOrganizationData.status) this.validateOrganizationStatus(newOrganizationData.status) + else if (newOrganizationData.status) delete newOrganizationData.status newOrganizationData.createdBy = oldOrganizationData.createdBy let updateOrganization = queryRunner.manager.merge(Organization, oldOrganizationData, newOrganizationData) - try { - await queryRunner.startTransaction() - await this.saveOrganization(updateOrganization, queryRunner) - await queryRunner.commitTransaction() - } catch (error) { - await queryRunner.rollbackTransaction() - throw error - } finally { - await queryRunner.release() - } + await this.saveOrganization(updateOrganization, queryRunner) return updateOrganization } diff --git a/packages/server/src/index.ts b/packages/server/src/index.ts index b37a2ca0f..f656e884b 100644 --- a/packages/server/src/index.ts +++ b/packages/server/src/index.ts @@ -1,39 +1,40 @@ -import express, { Request, Response } from 'express' -import path from 'path' -import cors from 'cors' -import http from 'http' import cookieParser from 'cookie-parser' +import cors from 'cors' +import express, { Request, Response } from 'express' +import 'global-agent/bootstrap' +import http from 'http' +import path from 'path' import { DataSource, IsNull } from 'typeorm' -import { MODE, Platform } from './Interface' -import { getNodeModulesPackagePath, getEncryptionKey } from './utils' -import logger, { expressRequestLogger } from './utils/logger' -import { getDataSource } from './DataSource' -import { NodesPool } from './NodesPool' -import { ChatFlow } from './database/entities/ChatFlow' -import { CachePool } from './CachePool' import { AbortControllerPool } from './AbortControllerPool' -import { RateLimiterManager } from './utils/rateLimit' -import { getAllowedIframeOrigins, getCorsOptions, sanitizeMiddleware } from './utils/XSS' -import { Telemetry } from './utils/telemetry' -import flowiseApiV1Router from './routes' -import errorHandlerMiddleware from './middlewares/errors' -import { WHITELIST_URLS } from './utils/constants' -import { initializeJwtCookieMiddleware, verifyToken } from './enterprise/middleware/passport' -import { IdentityManager } from './IdentityManager' -import { SSEStreamer } from './utils/SSEStreamer' -import { getAPIKeyWorkspaceID, validateAPIKey } from './utils/validateKey' +import { CachePool } from './CachePool' +import { ChatFlow } from './database/entities/ChatFlow' +import { getDataSource } from './DataSource' +import { Organization, OrganizationStatus } from './enterprise/database/entities/organization.entity' +import { GeneralRole, Role } from './enterprise/database/entities/role.entity' +import { Workspace } from './enterprise/database/entities/workspace.entity' import { LoggedInUser } from './enterprise/Interface.Enterprise' +import { initializeJwtCookieMiddleware, verifyToken } from './enterprise/middleware/passport' +import { handleStripeWebhook } from './enterprise/webhooks/stripe' +import { IdentityManager } from './IdentityManager' +import { MODE, Platform } from './Interface' import { IMetricsProvider } from './Interface.Metrics' -import { Prometheus } from './metrics/Prometheus' import { OpenTelemetry } from './metrics/OpenTelemetry' +import { Prometheus } from './metrics/Prometheus' +import errorHandlerMiddleware from './middlewares/errors' +import { NodesPool } from './NodesPool' import { QueueManager } from './queue/QueueManager' import { RedisEventSubscriber } from './queue/RedisEventSubscriber' -import 'global-agent/bootstrap' +import flowiseApiV1Router from './routes' import { UsageCacheManager } from './UsageCacheManager' -import { Workspace } from './enterprise/database/entities/workspace.entity' -import { Organization } from './enterprise/database/entities/organization.entity' -import { GeneralRole, Role } from './enterprise/database/entities/role.entity' +import { getEncryptionKey, getNodeModulesPackagePath } from './utils' import { migrateApiKeysFromJsonToDb } from './utils/apiKey' +import { WHITELIST_URLS } from './utils/constants' +import logger, { expressRequestLogger } from './utils/logger' +import { RateLimiterManager } from './utils/rateLimit' +import { SSEStreamer } from './utils/SSEStreamer' +import { Telemetry } from './utils/telemetry' +import { getAPIKeyWorkspaceID, validateAPIKey } from './utils/validateKey' +import { getAllowedIframeOrigins, getCorsOptions, sanitizeMiddleware } from './utils/XSS' import { StripeWebhooks } from './enterprise/webhooks/stripe' declare global { @@ -252,6 +253,10 @@ export class App { if (!org) { return res.status(401).json({ error: 'Unauthorized Access' }) } + if (org.status == OrganizationStatus.PAST_DUE) + return res.status(402).json({ error: 'Access denied. Your organization has past due payments.' }) + if (org.status == OrganizationStatus.UNDER_REVIEW) + return res.status(403).json({ error: 'Access denied. Your organization is under review.' }) const subscriptionId = org.subscriptionId as string const customerId = org.customerId as string const features = await this.identityManager.getFeaturesByPlan(subscriptionId)