590 lines
24 KiB
TypeScript
590 lines
24 KiB
TypeScript
/**
 * Copyright (c) 2023-present FlowiseAI, Inc.
 *
 * The Enterprise and Cloud versions of Flowise are licensed under the [Commercial License](https://github.com/FlowiseAI/Flowise/tree/main/packages/server/src/enterprise/LICENSE.md).
 * Unauthorized copying, modification, distribution, or use of the Enterprise and Cloud versions is strictly prohibited without a valid license agreement from FlowiseAI, Inc.
 *
 * The Open Source version is licensed under the Apache License, Version 2.0 (the "License")
 *
 * For information about licensing of the Enterprise and Cloud versions, please contact:
 * security@flowiseai.com
 */
|
|
|
|
import axios from 'axios'
|
|
import express, { Application, NextFunction, Request, Response } from 'express'
|
|
import * as fs from 'fs'
|
|
import { StatusCodes } from 'http-status-codes'
|
|
import jwt from 'jsonwebtoken'
|
|
import path from 'path'
|
|
import { LoginMethodStatus } from './enterprise/database/entities/login-method.entity'
|
|
import { ErrorMessage, LoggedInUser } from './enterprise/Interface.Enterprise'
|
|
import { Permissions } from './enterprise/rbac/Permissions'
|
|
import { LoginMethodService } from './enterprise/services/login-method.service'
|
|
import { Organization, OrganizationStatus } from './enterprise/database/entities/organization.entity'
|
|
import { OrganizationService } from './enterprise/services/organization.service'
|
|
import Auth0SSO from './enterprise/sso/Auth0SSO'
|
|
import AzureSSO from './enterprise/sso/AzureSSO'
|
|
import GithubSSO from './enterprise/sso/GithubSSO'
|
|
import GoogleSSO from './enterprise/sso/GoogleSSO'
|
|
import SSOBase from './enterprise/sso/SSOBase'
|
|
import { InternalFlowiseError } from './errors/internalFlowiseError'
|
|
import { Platform, UserPlan } from './Interface'
|
|
import { StripeManager } from './StripeManager'
|
|
import { UsageCacheManager } from './UsageCacheManager'
|
|
import { GeneralErrorMessage, LICENSE_QUOTAS } from './utils/constants'
|
|
import { getRunningExpressApp } from './utils/getRunningExpressApp'
|
|
import { ENTERPRISE_FEATURE_FLAGS } from './utils/quotaUsage'
|
|
import Stripe from 'stripe'
|
|
|
|
// Names of every SSO provider this file can construct. Iterated when registering
// providers that have no stored configuration.
const allSSOProviders = ['azure', 'google', 'auth0', 'github']
|
|
export class IdentityManager {
|
|
/** Process-wide singleton, created on first call to getInstance(). */
private static instance: IdentityManager
/** Only assigned during initialize() when STRIPE_SECRET_KEY is set. */
private stripeManager?: StripeManager
/** Result of the most recent license validation. */
licenseValid = false
/** Assigned during initialize(). */
permissions: Permissions
ssoProviderName = ''
/** Platform detected by license validation; open source until proven otherwise. */
currentInstancePlatform: Platform = Platform.OPEN_SOURCE
/** Registered SSO providers, keyed by provider name. */
ssoProviders = new Map<string, SSOBase>()
|
|
|
|
/**
 * Returns the shared IdentityManager, creating and initializing it on first use.
 *
 * The new instance is stored on the class before initialize() is awaited.
 */
public static async getInstance(): Promise<IdentityManager> {
    if (IdentityManager.instance) {
        return IdentityManager.instance
    }

    const created = new IdentityManager()
    IdentityManager.instance = created
    await created.initialize()
    return IdentityManager.instance
}
|
|
|
|
/**
 * Validates the license key, creates the permissions registry and, when
 * STRIPE_SECRET_KEY is set, obtains the shared StripeManager.
 */
public async initialize() {
    await this._validateLicenseKey()
    this.permissions = new Permissions()

    const stripeConfigured = Boolean(process.env.STRIPE_SECRET_KEY)
    if (!stripeConfigured) return

    this.stripeManager = await StripeManager.getInstance()
}
|
|
|
|
/** Returns the platform determined by license validation. */
public getPlatformType = () => this.currentInstancePlatform
|
|
|
|
/** Returns the Permissions object created in initialize(). */
public getPermissions = () => this.permissions
|
|
|
|
/** True when the current platform is Platform.ENTERPRISE. */
public isEnterprise = () => this.currentInstancePlatform === Platform.ENTERPRISE
|
|
|
|
/** True when the current platform is Platform.CLOUD. */
public isCloud = () => this.currentInstancePlatform === Platform.CLOUD
|
|
|
|
/** True when the current platform is Platform.OPEN_SOURCE. */
public isOpenSource = () => this.currentInstancePlatform === Platform.OPEN_SOURCE
|
|
|
|
/** Returns the outcome of the most recent license validation. */
public isLicenseValid = () => this.licenseValid
|
|
|
|
/**
 * Verifies a license JWT against the bundled RS256 public key without any
 * network access.
 *
 * @param licenseKey - The license token to verify.
 * @returns The decoded token payload, or null when the key file cannot be
 *          read or verification fails (the error is logged).
 */
private _offlineVerifyLicense(licenseKey: string): any {
    try {
        const publicKeyPath = path.join(__dirname, '../', 'src/enterprise/license/public.pem')
        const publicKey = fs.readFileSync(publicKeyPath, 'utf8')
        return jwt.verify(licenseKey, publicKey, { algorithms: ['RS256'] })
    } catch (error) {
        console.error('Error verifying license key:', error)
        return null
    }
}
|
|
|
|
/**
 * Determines `licenseValid` and `currentInstancePlatform` from the environment.
 *
 * - No FLOWISE_EE_LICENSE_KEY: open source, license invalid.
 * - OFFLINE === 'true': the key is verified locally; it is valid while the
 *   current time is not past `iat` + `expiryDurationInMonths`. Platform is
 *   ENTERPRISE once that check has run.
 * - LICENSE_URL set: the key is posted to `${LICENSE_URL}/enterprise/verify`
 *   and the platform is derived from the URL / response.
 * - Key set but neither OFFLINE nor LICENSE_URL: nothing is changed.
 *
 * Never throws; every failure is logged and leaves the license invalid.
 */
private _validateLicenseKey = async () => {
    const LICENSE_URL = process.env.LICENSE_URL
    const FLOWISE_EE_LICENSE_KEY = process.env.FLOWISE_EE_LICENSE_KEY

    // First check if license key is missing
    if (!FLOWISE_EE_LICENSE_KEY) {
        this.licenseValid = false
        this.currentInstancePlatform = Platform.OPEN_SOURCE
        return
    }

    try {
        if (process.env.OFFLINE === 'true') {
            const decodedLicense = this._offlineVerifyLicense(FLOWISE_EE_LICENSE_KEY)
            const issuedAtSeconds = decodedLicense?.iat

            if (!issuedAtSeconds) {
                // Verification failed or the token carries no issue time.
                this.licenseValid = false
            } else {
                const expiryDurationInMonths = decodedLicense.expiryDurationInMonths || 0
                const expiryDate = new Date(issuedAtSeconds * 1000)
                expiryDate.setMonth(expiryDate.getMonth() + expiryDurationInMonths)

                // Fail closed: an unparseable expiry (Invalid Date) must not be
                // treated as "never expires", which a bare `now > expiry` would do.
                const expiryTime = expiryDate.getTime()
                this.licenseValid = !Number.isNaN(expiryTime) && Date.now() <= expiryTime
            }
            this.currentInstancePlatform = Platform.ENTERPRISE
        } else if (LICENSE_URL) {
            try {
                const response = await axios.post(`${LICENSE_URL}/enterprise/verify`, { license: FLOWISE_EE_LICENSE_KEY })
                // Normalize to a real boolean; the field is declared `boolean`
                // but the response value may be missing.
                this.licenseValid = Boolean(response.data?.valid)

                if (!LICENSE_URL.includes('api') || LICENSE_URL.includes('v1')) {
                    this.currentInstancePlatform = Platform.ENTERPRISE
                } else if (LICENSE_URL.includes('v2')) {
                    // NOTE(review): the platform is taken from the response as-is;
                    // confirm the server always supplies a valid Platform value.
                    this.currentInstancePlatform = response.data?.platform
                } else {
                    throw new InternalFlowiseError(StatusCodes.INTERNAL_SERVER_ERROR, GeneralErrorMessage.UNHANDLED_EDGE_CASE)
                }
            } catch (error) {
                console.error('Error verifying license key:', error)
                this.licenseValid = false
                this.currentInstancePlatform = Platform.ENTERPRISE
                return
            }
        }
    } catch (error) {
        // Previously swallowed silently; log so a broken license setup is diagnosable.
        console.error('Error verifying license key:', error)
        this.licenseValid = false
    }
}
|
|
|
|
/**
 * Registers SSO providers on the given Express app.
 *
 * On cloud and enterprise platforms the stored login methods are read and
 * every enabled one is decrypted and initialized. On enterprise the first
 * organization's login methods are used; with no organizations only
 * unconfigured providers are registered. In all cases, providers that were
 * not registered from stored configuration are registered without one.
 *
 * @param app - Express application handed to each provider.
 */
public initializeSSO = async (app: express.Application) => {
    const platform = this.getPlatformType()
    const supportsStoredLoginMethods = platform === Platform.CLOUD || platform === Platform.ENTERPRISE

    if (supportsStoredLoginMethods) {
        const loginMethodService = new LoginMethodService()
        let queryRunner
        try {
            queryRunner = getRunningExpressApp().AppDataSource.createQueryRunner()
            await queryRunner.connect()

            let organizationId
            if (platform === Platform.ENTERPRISE) {
                const organizations = await new OrganizationService().readOrganization(queryRunner)
                if (organizations.length === 0) {
                    // Nothing to load yet; the finally block still releases the runner.
                    this.initializeEmptySSO(app)
                    return
                }
                organizationId = organizations[0].id
            }

            const loginMethods = await loginMethodService.readLoginMethodByOrganizationId(organizationId, queryRunner)
            for (const method of loginMethods ?? []) {
                if (method.status !== LoginMethodStatus.ENABLE) continue

                const decryptedConfig = await loginMethodService.decryptLoginMethodConfig(method.config)
                method.config = JSON.parse(decryptedConfig)
                this.initializeSsoProvider(app, method.name, method.config)
            }
        } finally {
            if (queryRunner) await queryRunner.release()
        }
    }

    // Register whatever providers are still missing, without configuration.
    this.initializeEmptySSO(app)
}
|
|
|
|
/**
 * Registers every known SSO provider that is not yet in `ssoProviders`,
 * passing no configuration.
 *
 * @param app - Express application handed to each provider.
 */
initializeEmptySSO(app: Application) {
    for (const providerName of allSSOProviders) {
        if (this.ssoProviders.has(providerName)) continue
        this.initializeSsoProvider(app, providerName, undefined)
    }
}
|
|
|
|
/**
 * Creates or reconfigures the SSO provider registered under `providerName`.
 *
 * An already-registered provider receives the new config and is
 * re-initialized only when `providerConfig.configEnabled === true`;
 * otherwise its config is cleared. An unregistered provider is constructed,
 * initialized and then stored in `ssoProviders`.
 *
 * @param app - Express application passed to a newly constructed provider.
 * @param providerName - One of 'azure', 'google', 'auth0', 'github'.
 * @param providerConfig - Provider configuration, or undefined for none.
 * @throws Error when `providerName` is not registered and not a known provider.
 */
initializeSsoProvider(app: Application, providerName: string, providerConfig: any) {
    if (this.ssoProviders.has(providerName)) {
        const existing = this.ssoProviders.get(providerName)
        if (!existing) return

        const isEnabled = providerConfig && providerConfig.configEnabled === true
        if (isEnabled) {
            existing.setSSOConfig(providerConfig)
            existing.initialize()
        } else {
            // if false, disable the provider
            existing.setSSOConfig(undefined)
        }
        return
    }

    // The switch only picks the concrete class; initialization and
    // registration are identical for every provider.
    let created: SSOBase
    switch (providerName) {
        case 'azure':
            created = new AzureSSO(app, providerConfig)
            break
        case 'google':
            created = new GoogleSSO(app, providerConfig)
            break
        case 'auth0':
            created = new Auth0SSO(app, providerConfig)
            break
        case 'github':
            created = new GithubSSO(app, providerConfig)
            break
        default:
            throw new Error(`SSO Provider ${providerName} not found`)
    }

    created.initialize()
    this.ssoProviders.set(providerName, created)
}
|
|
|
|
/**
 * Asks the named SSO provider to refresh a token.
 *
 * @param providerName - Key of the provider in `ssoProviders`.
 * @param ssoRefreshToken - Refresh token forwarded to the provider.
 * @returns Whatever the provider's refreshToken() resolves to.
 * @throws Error when no provider is registered under `providerName`.
 */
async getRefreshToken(providerName: any, ssoRefreshToken: string) {
    if (this.ssoProviders.has(providerName)) {
        const provider = this.ssoProviders.get(providerName) as SSOBase
        return await provider.refreshToken(ssoRefreshToken)
    }
    throw new Error(`SSO Provider ${providerName} not found`)
}
|
|
|
|
/**
 * Looks up the product ID for a subscription via the StripeManager.
 *
 * @param subscriptionId - Subscription to look up.
 * @returns '' when `subscriptionId` is empty, otherwise the StripeManager result.
 * @throws Error when the StripeManager has not been initialized.
 */
public async getProductIdFromSubscription(subscriptionId: string) {
    if (!subscriptionId) {
        return ''
    }

    const stripeManager = this.stripeManager
    if (!stripeManager) throw new Error('Stripe manager is not initialized')

    return await stripeManager.getProductIdFromSubscription(subscriptionId)
}
|
|
|
|
/**
 * Resolves the feature flags available to a subscription.
 *
 * Enterprise: every flag in ENTERPRISE_FEATURE_FLAGS mapped to 'true'.
 * Cloud: the StripeManager's answer, or {} when the StripeManager or the
 * subscription ID is missing. Any other platform: {}.
 *
 * @param subscriptionId - Subscription whose plan is inspected (cloud only).
 * @param withoutCache - Forwarded to StripeManager.getFeaturesByPlan.
 */
public async getFeaturesByPlan(subscriptionId: string, withoutCache: boolean = false) {
    if (this.isEnterprise()) {
        const enabledFlags: Record<string, string> = {}
        for (const flag of ENTERPRISE_FEATURE_FLAGS) {
            enabledFlags[flag] = 'true'
        }
        return enabledFlags
    }

    if (!this.isCloud()) return {}

    const stripeManager = this.stripeManager
    if (!stripeManager || !subscriptionId) return {}

    return await stripeManager.getFeaturesByPlan(subscriptionId, withoutCache)
}
|
|
|
|
/**
 * Builds Express middleware that lets a request through only when
 * `req.user.features[feature]` is the string 'true'.
 *
 * A missing user, missing or empty feature set, or any other value yields a
 * 403 response carrying ErrorMessage.FORBIDDEN.
 *
 * @param feature - Feature flag name that must be enabled.
 */
public static checkFeatureByPlan(feature: string) {
    return (req: Request, res: Response, next: NextFunction) => {
        const features = req.user?.features
        const isGranted = !!features && Object.keys(features).includes(feature) && features[feature] === 'true'

        if (isGranted) return next()

        return res.status(StatusCodes.FORBIDDEN).json({ message: ErrorMessage.FORBIDDEN })
    }
}
|
|
|
|
/**
 * Delegates to StripeManager.createStripeCustomerPortalSession.
 *
 * @param req - Incoming request, forwarded unchanged.
 * @throws Error when the StripeManager has not been initialized.
 */
public async createStripeCustomerPortalSession(req: Request) {
    const stripeManager = this.stripeManager
    if (!stripeManager) throw new Error('Stripe manager is not initialized')

    return await stripeManager.createStripeCustomerPortalSession(req)
}
|
|
|
|
/**
 * Delegates to StripeManager.getAdditionalSeatsQuantity.
 *
 * @param subscriptionId - Subscription to inspect.
 * @returns {} when `subscriptionId` is empty, otherwise the StripeManager result.
 * @throws Error when the StripeManager has not been initialized.
 */
public async getAdditionalSeatsQuantity(subscriptionId: string) {
    if (!subscriptionId) {
        return {}
    }

    const stripeManager = this.stripeManager
    if (!stripeManager) throw new Error('Stripe manager is not initialized')

    return await stripeManager.getAdditionalSeatsQuantity(subscriptionId)
}
|
|
|
|
/**
 * Delegates to StripeManager.getCustomerWithDefaultSource.
 *
 * @param customerId - Stripe customer to fetch.
 * @returns undefined when `customerId` is empty, otherwise the StripeManager result.
 * @throws Error when the StripeManager has not been initialized.
 */
public async getCustomerWithDefaultSource(customerId: string) {
    if (!customerId) {
        return
    }

    const stripeManager = this.stripeManager
    if (!stripeManager) throw new Error('Stripe manager is not initialized')

    return await stripeManager.getCustomerWithDefaultSource(customerId)
}
|
|
|
|
/**
 * Delegates to StripeManager.getAdditionalSeatsProration.
 *
 * @param subscriptionId - Subscription being changed.
 * @param newQuantity - Seat quantity to price.
 * @returns {} when `subscriptionId` is empty, otherwise the StripeManager result.
 * @throws Error when the StripeManager has not been initialized.
 */
public async getAdditionalSeatsProration(subscriptionId: string, newQuantity: number) {
    if (!subscriptionId) {
        return {}
    }

    const stripeManager = this.stripeManager
    if (!stripeManager) throw new Error('Stripe manager is not initialized')

    return await stripeManager.getAdditionalSeatsProration(subscriptionId, newQuantity)
}
|
|
|
|
/**
 * Changes the additional-seat quantity on a subscription and refreshes the
 * cached subscription data (features, quotas, subscription details).
 *
 * Quotas are read from the metadata of the product behind the first
 * subscription item (keys starting with 'quota:'), and the additional-seats
 * quota is set to `quantity`.
 *
 * @param subscriptionId - Subscription to update.
 * @param quantity - New additional-seat quantity.
 * @param prorationDate - Forwarded to StripeManager.updateAdditionalSeats.
 * @param increase - Forwarded to StripeManager.updateAdditionalSeats.
 * @returns {} when `subscriptionId` is empty; otherwise the update outcome,
 *          with `paymentError` reduced to a message (or null when payment did not fail).
 * @throws Error when the StripeManager is not initialized or the subscription has no items.
 */
public async updateAdditionalSeats(subscriptionId: string, quantity: number, prorationDate: number, increase: boolean) {
    if (!subscriptionId) return {}

    const stripeManager = this.stripeManager
    if (!stripeManager) throw new Error('Stripe manager is not initialized')

    const updateResult = await stripeManager.updateAdditionalSeats(subscriptionId, quantity, prorationDate, increase)
    const { success, subscription, invoice, paymentFailed, paymentError } = updateResult

    // The product behind the first subscription item carries the quotas.
    const subscriptionItems = subscription.items.data
    if (subscriptionItems.length === 0) throw new Error('No subscription items found')

    const productId = subscriptionItems[0].price.product as string
    const { metadata } = await stripeManager.getStripe().products.retrieve(productId)

    const quotas: Record<string, number> = {}
    for (const [metadataKey, rawValue] of Object.entries(metadata)) {
        if (!metadataKey.startsWith('quota:')) continue
        quotas[metadataKey] = parseInt(rawValue)
    }
    quotas[LICENSE_QUOTAS.ADDITIONAL_SEATS_LIMIT] = quantity

    // Bypass the cache so the features reflect the updated subscription.
    const features = await this.getFeaturesByPlan(subscription.id, true)

    const cacheManager = await UsageCacheManager.getInstance()
    await cacheManager.updateSubscriptionDataToCache(subscriptionId, {
        features,
        quotas,
        subsriptionDetails: stripeManager.getSubscriptionObject(subscription)
    })

    const paymentErrorMessage = paymentFailed ? paymentError?.message || 'Payment failed' : null
    return {
        success,
        subscription,
        invoice,
        paymentFailed,
        paymentError: paymentErrorMessage
    }
}
|
|
|
|
/**
 * Delegates to StripeManager.getPlanProration.
 *
 * @param subscriptionId - Subscription being changed.
 * @param newPlanId - Target plan (product) ID.
 * @returns {} when either ID is empty, otherwise the StripeManager result.
 * @throws Error when the StripeManager has not been initialized.
 */
public async getPlanProration(subscriptionId: string, newPlanId: string) {
    const hasBothIds = Boolean(subscriptionId) && Boolean(newPlanId)
    if (!hasBothIds) return {}

    const stripeManager = this.stripeManager
    if (!stripeManager) throw new Error('Stripe manager is not initialized')

    return await stripeManager.getPlanProration(subscriptionId, newPlanId)
}
|
|
|
|
/**
 * Switches a subscription to a new plan and brings cache, request user and
 * session in line with it.
 *
 * On success this:
 *  1. reactivates the organization when Stripe reports the
 *     'downgrade_from_past_due' special case,
 *  2. rebuilds quotas from the new product's metadata ('quota:' keys) plus
 *     the additional-seats quantity found on the subscription,
 *  3. refreshes the cached subscription data,
 *  4. updates `req.user` and the passport session user, and waits for the
 *     session to be saved.
 *
 * @param req - Authenticated request; `req.user` and `req.session` are updated.
 * @param subscriptionId - Subscription to update.
 * @param newPlanId - Target plan (product) ID.
 * @param prorationDate - Forwarded to StripeManager.updateSubscriptionPlan.
 * @returns {} when either ID is empty; `{ status: 'success', ... }` on
 *          success; `{ status: 'error', message }` when Stripe reports no success.
 * @throws Error when the StripeManager is not initialized.
 * @throws InternalFlowiseError UNAUTHORIZED when the request has no user,
 *         PAYMENT_REQUIRED for declined/invalid payment methods, BAD_REQUEST
 *         when the session cannot be saved. Other errors are re-thrown as-is.
 */
public async updateSubscriptionPlan(req: Request, subscriptionId: string, newPlanId: string, prorationDate: number) {
    if (!subscriptionId || !newPlanId) return {}

    if (!this.stripeManager) {
        throw new Error('Stripe manager is not initialized')
    }
    if (!req.user) {
        throw new InternalFlowiseError(StatusCodes.UNAUTHORIZED, GeneralErrorMessage.UNAUTHORIZED)
    }

    try {
        const result = await this.stripeManager.updateSubscriptionPlan(subscriptionId, newPlanId, prorationDate)
        const { success, subscription, special_case, paymentFailed, paymentError } = result

        if (!success) {
            return {
                status: 'error',
                message: 'Payment or subscription update not completed'
            }
        }

        // Handle special case: downgrade from past_due to free plan
        if (special_case === 'downgrade_from_past_due') {
            // Update organization status to active using OrganizationService
            const queryRunner = getRunningExpressApp().AppDataSource.createQueryRunner()
            await queryRunner.connect()

            try {
                const organizationService = new OrganizationService()

                // Find organization by subscriptionId
                const organization = await queryRunner.manager.findOne(Organization, {
                    where: { subscriptionId }
                })

                if (organization) {
                    await organizationService.updateOrganization(
                        {
                            id: organization.id,
                            status: OrganizationStatus.ACTIVE,
                            updatedBy: req.user.id
                        },
                        queryRunner,
                        true // fromStripe = true to allow status updates
                    )
                }
            } finally {
                await queryRunner.release()
            }
        }

        // Fetch product details to get quotas
        const product = await this.stripeManager.getStripe().products.retrieve(newPlanId)
        const productMetadata = product.metadata

        // Extract quotas from metadata
        const quotas: Record<string, number> = {}
        for (const key in productMetadata) {
            if (key.startsWith('quota:')) {
                quotas[key] = parseInt(productMetadata[key])
            }
        }

        const additionalSeatsItem = subscription.items.data.find(
            (item: any) => (item.price.product as string) === process.env.ADDITIONAL_SEAT_ID
        )
        quotas[LICENSE_QUOTAS.ADDITIONAL_SEATS_LIMIT] = additionalSeatsItem?.quantity || 0

        // Get features from Stripe, bypassing the cache
        const features = await this.getFeaturesByPlan(subscription.id, true)

        // Only the three standard cloud plans are recorded as the active product.
        const isStandardCloudPlan = [process.env.CLOUD_FREE_ID, process.env.CLOUD_STARTER_ID, process.env.CLOUD_PRO_ID].includes(
            newPlanId
        )

        // Update the cache with new subscription data including quotas
        const cacheManager = await UsageCacheManager.getInstance()
        const updateCacheData: Record<string, any> = {
            features,
            quotas,
            subsriptionDetails: this.stripeManager.getSubscriptionObject(subscription)
        }
        if (isStandardCloudPlan) {
            updateCacheData.productId = newPlanId
        }
        await cacheManager.updateSubscriptionDataToCache(subscriptionId, updateCacheData)

        const loggedInUser: LoggedInUser = {
            ...req.user,
            activeOrganizationSubscriptionId: subscription.id,
            features
        }
        if (isStandardCloudPlan) {
            loggedInUser.activeOrganizationProductId = newPlanId
        }

        req.user = {
            ...req.user,
            ...loggedInUser
        }

        // Update passport session
        // @ts-ignore
        req.session.passport.user = {
            ...req.user,
            ...loggedInUser
        }

        // Wait for the session to be persisted. Throwing from inside the save
        // callback would escape this try/catch (the callback runs later, outside
        // this call stack) and surface as an uncaught exception, so the failure
        // is routed through a rejected promise instead.
        await new Promise<void>((resolve, reject) => {
            req.session.save((err) => {
                if (err) {
                    reject(new InternalFlowiseError(StatusCodes.BAD_REQUEST, GeneralErrorMessage.UNHANDLED_EDGE_CASE))
                    return
                }
                resolve()
            })
        })

        return {
            status: 'success',
            user: loggedInUser,
            paymentFailed,
            paymentError: paymentFailed ? paymentError?.message || 'Payment failed' : null
        }
    } catch (error: any) {
        // Enhanced error handling for payment method failures
        if (error.type === 'StripeCardError' || error.code === 'card_declined') {
            throw new InternalFlowiseError(
                StatusCodes.PAYMENT_REQUIRED,
                'Your payment method was declined. Please update your payment method and try again.'
            )
        }

        if (error.type === 'StripeInvalidRequestError' && error.message?.includes('payment_method')) {
            throw new InternalFlowiseError(
                StatusCodes.PAYMENT_REQUIRED,
                'There was an issue with your payment method. Please update your payment method and try again.'
            )
        }

        // Re-throw other errors
        throw error
    }
}
|
|
|
|
/**
 * Creates a Stripe customer for `email` and subscribes it to the product
 * configured for `userPlan`.
 *
 * The plan's product ID (from CLOUD_STARTER_ID / CLOUD_PRO_ID / CLOUD_FREE_ID)
 * and its active price are resolved before the customer is created, so a
 * misconfigured plan fails without leaving an orphaned customer in Stripe.
 *
 * @param email - Customer email.
 * @param userPlan - Plan to subscribe to.
 * @param referral - Optional referral code stored in the customer's metadata.
 * @returns The IDs of the created customer and subscription.
 * @throws Error when the StripeManager is not initialized, no product is
 *         configured for the plan, or the product has no active price. Errors
 *         from Stripe are logged and re-thrown.
 */
public async createStripeUserAndSubscribe({ email, userPlan, referral }: { email: string; userPlan: UserPlan; referral?: string }) {
    if (!this.stripeManager) {
        throw new Error('Stripe manager is not initialized')
    }

    try {
        const stripe = this.stripeManager.getStripe()

        let productId: string | undefined
        switch (userPlan) {
            case UserPlan.STARTER:
                productId = process.env.CLOUD_STARTER_ID
                break
            case UserPlan.PRO:
                productId = process.env.CLOUD_PRO_ID
                break
            case UserPlan.FREE:
                productId = process.env.CLOUD_FREE_ID
                break
        }
        // Covers both an unrecognized plan and a missing environment variable.
        if (!productId) {
            throw new Error(`No Stripe product configured for plan: ${String(userPlan)}`)
        }

        // Get the default price ID for the product
        const prices = await stripe.prices.list({
            product: productId,
            active: true,
            limit: 1
        })
        if (!prices.data.length) {
            throw new Error('No active price found for the product')
        }

        // Create a customer in Stripe (only once the price is known to exist)
        const customer = await stripe.customers.create(referral ? { email, metadata: { referral } } : { email })

        // Create the subscription
        const subscription = await stripe.subscriptions.create({
            customer: customer.id,
            items: [{ price: prices.data[0].id }]
        })

        return {
            customerId: customer.id,
            subscriptionId: subscription.id
        }
    } catch (error) {
        console.error('Error creating Stripe user and subscription:', error)
        throw error
    }
}
|
|
}
|