import { Injectable, Logger } from '@nestjs/common';
import { Knex, knex } from 'knex';
import { getCentralPrisma } from '../prisma/central-prisma.service';
import * as crypto from 'crypto';

/**
 * Creates, caches and tears down Knex connection pools for tenant databases.
 *
 * Tenant connection settings are read from the central database via
 * `getCentralPrisma()`. Pools are cached in memory under keys of the form
 * `domain:<domain>` (see getTenantKnexByDomain) or `id:<tenantId>`
 * (see getTenantKnexById). The two key spaces are independent, so a tenant
 * reached both ways holds two separate pools.
 */
@Injectable()
export class TenantDatabaseService {
  /** Cipher used for tenant DB passwords stored in the central database. */
  private static readonly CIPHER_ALGORITHM = 'aes-256-cbc';

  /** aes-256 requires a 256-bit (32-byte) key. */
  private static readonly KEY_LENGTH_BYTES = 32;

  private readonly logger = new Logger(TenantDatabaseService.name);

  /** Cached pools keyed by `domain:<domain>` or `id:<tenantId>`. */
  private readonly tenantConnections = new Map<string, Knex>();

  /**
   * Get tenant database connection by domain (for subdomain-based authentication).
   * This is used when users log in via tenant subdomains.
   *
   * A cached pool is re-validated against the central database on every call:
   * if the domain no longer exists or its tenant is no longer active, the pool
   * is destroyed and an error is thrown. If the validation query itself fails,
   * the cached pool is still returned (best effort) and a warning is logged.
   *
   * @param domain Domain the request arrived on.
   * @returns A Knex instance connected to the tenant's database.
   * @throws Error if the domain is unknown, the tenant is not active, or the
   *   tenant database cannot be reached.
   */
  async getTenantKnexByDomain(domain: string): Promise<Knex> {
    const cacheKey = `domain:${domain}`;
    const centralPrisma = getCentralPrisma();

    const cached = this.tenantConnections.get(cacheKey);
    if (cached) {
      // Only the lookup is guarded: the "not found"/"not active" errors below
      // must propagate, while a central-DB failure falls back to the cache.
      let cachedRecord;
      try {
        cachedRecord = await centralPrisma.domain.findUnique({
          where: { domain },
          include: { tenant: true },
        });
      } catch (error: unknown) {
        this.logger.warn(
          `Error validating domain ${domain}: ${TenantDatabaseService.describeError(error)}`,
        );
        return cached;
      }

      if (!cachedRecord) {
        this.logger.warn(`Domain ${domain} no longer exists, removing cached connection`);
        await this.disconnectTenant(cacheKey);
        throw new Error(`Domain ${domain} not found`);
      }

      // Keep the cached path consistent with the uncached one: a tenant that
      // was deactivated after its pool was cached must not keep being served.
      if (cachedRecord.tenant.status !== 'active') {
        this.logger.warn(
          `Tenant ${cachedRecord.tenant.name} for domain ${domain} is no longer active, removing cached connection`,
        );
        await this.disconnectTenant(cacheKey);
        throw new Error(`Tenant ${cachedRecord.tenant.name} is not active`);
      }

      return cached;
    }

    // Find tenant by domain
    const domainRecord = await centralPrisma.domain.findUnique({
      where: { domain },
      include: { tenant: true },
    });

    if (!domainRecord) {
      throw new Error(`Domain ${domain} not found`);
    }

    const tenant = domainRecord.tenant;
    this.logger.log(`Found tenant by domain: ${domain} -> ${tenant.name}`);

    if (tenant.status !== 'active') {
      throw new Error(`Tenant ${tenant.name} is not active`);
    }

    const tenantKnex = await this.createTenantConnection(tenant);
    return this.cacheConnection(cacheKey, tenantKnex);
  }

  /**
   * Get tenant database connection by tenant ID (for central admin operations).
   * This is used when central admin needs to access tenant databases.
   *
   * Cached pools are returned without re-validating the tenant.
   *
   * @param tenantId Tenant primary key in the central database.
   * @returns A Knex instance connected to the tenant's database.
   * @throws Error if the tenant is unknown, not active, or its database
   *   cannot be reached.
   */
  async getTenantKnexById(tenantId: string): Promise<Knex> {
    const cacheKey = `id:${tenantId}`;

    const cached = this.tenantConnections.get(cacheKey);
    if (cached) {
      return cached;
    }

    const centralPrisma = getCentralPrisma();
    const tenant = await centralPrisma.tenant.findUnique({
      where: { id: tenantId },
    });

    if (!tenant) {
      throw new Error(`Tenant ${tenantId} not found`);
    }

    if (tenant.status !== 'active') {
      throw new Error(`Tenant ${tenant.name} is not active`);
    }

    this.logger.log(`Connecting to tenant database by ID: ${tenant.name}`);

    const tenantKnex = await this.createTenantConnection(tenant);
    return this.cacheConnection(cacheKey, tenantKnex);
  }

  /**
   * Legacy method - delegates to domain-based lookup.
   *
   * The argument is always treated as a domain. Despite the parameter name,
   * tenant IDs and slugs are not detected or resolved here.
   *
   * @deprecated Use getTenantKnexByDomain or getTenantKnexById instead
   */
  async getTenantKnex(tenantIdOrSlug: string): Promise<Knex> {
    return this.getTenantKnexByDomain(tenantIdOrSlug);
  }

  /**
   * Store a freshly created pool under `cacheKey` and return the pool callers
   * should use.
   *
   * Two concurrent cache misses for the same key can both create a pool. If
   * another caller cached one while this pool was being created, this pool is
   * destroyed (instead of being silently overwritten and leaked) and the
   * already-cached pool is returned.
   */
  private async cacheConnection(cacheKey: string, tenantKnex: Knex): Promise<Knex> {
    const existing = this.tenantConnections.get(cacheKey);
    if (existing) {
      await tenantKnex.destroy();
      return existing;
    }
    this.tenantConnections.set(cacheKey, tenantKnex);
    return tenantKnex;
  }

  /**
   * Create a new Knex connection to a tenant database and verify it with
   * `SELECT 1`.
   *
   * @param tenant Tenant record; reads dbHost, dbPort, dbUsername, dbPassword
   *   (encrypted, see encryptPassword) and dbName.
   * @throws The underlying driver error if the probe query fails. The pool is
   *   destroyed first so a failed attempt does not leak connections.
   */
  // eslint-disable-next-line @typescript-eslint/no-explicit-any -- Prisma tenant record; its generated type is not imported in this file
  private async createTenantConnection(tenant: any): Promise<Knex> {
    const decryptedPassword = this.decryptPassword(tenant.dbPassword);

    const tenantKnex = knex({
      client: 'mysql2',
      connection: {
        host: tenant.dbHost,
        port: tenant.dbPort,
        user: tenant.dbUsername,
        password: decryptedPassword,
        database: tenant.dbName,
      },
      pool: {
        min: 2,
        max: 10,
      },
    });

    try {
      await tenantKnex.raw('SELECT 1');
      this.logger.log(`Connected to tenant database: ${tenant.dbName}`);
    } catch (error: unknown) {
      this.logger.error(
        `Failed to connect to tenant database: ${tenant.dbName}`,
        error instanceof Error ? error.stack : String(error),
      );
      // Release the pool; it is never cached, so nothing else would destroy it.
      try {
        await tenantKnex.destroy();
      } catch {
        // Ignore: the original connection error is the one worth surfacing.
      }
      throw error;
    }

    return tenantKnex;
  }

  /**
   * Look up the tenant that owns `domain`.
   *
   * The return type is inferred from the central Prisma client (the domain
   * record's `tenant` relation).
   *
   * @throws Error if the domain is unknown or its tenant is not active.
   */
  async getTenantByDomain(domain: string) {
    const centralPrisma = getCentralPrisma();

    const domainRecord = await centralPrisma.domain.findUnique({
      where: { domain },
      include: { tenant: true },
    });

    if (!domainRecord) {
      throw new Error(`Domain ${domain} not found`);
    }

    if (domainRecord.tenant.status !== 'active') {
      throw new Error(`Tenant for domain ${domain} is not active`);
    }

    return domainRecord.tenant;
  }

  /**
   * Resolve tenant by ID or slug.
   * Tries ID first, then falls back to slug.
   *
   * @returns The tenant's ID.
   * @throws Error if no tenant matches or the tenant is not active.
   */
  async resolveTenantId(idOrSlug: string): Promise<string> {
    const centralPrisma = getCentralPrisma();

    const tenant =
      (await centralPrisma.tenant.findUnique({ where: { id: idOrSlug } })) ??
      (await centralPrisma.tenant.findUnique({ where: { slug: idOrSlug } }));

    if (!tenant) {
      throw new Error(`Tenant ${idOrSlug} not found`);
    }

    if (tenant.status !== 'active') {
      throw new Error(`Tenant ${tenant.name} is not active`);
    }

    return tenant.id;
  }

  /**
   * Destroy and evict the pool cached under `cacheKey` (e.g. `domain:<domain>`
   * or `id:<tenantId>`). Does nothing if no pool is cached under that key.
   */
  async disconnectTenant(cacheKey: string): Promise<void> {
    const connection = this.tenantConnections.get(cacheKey);
    if (!connection) {
      return;
    }
    // Evict before destroying so a failing destroy() cannot leave a dead pool cached.
    this.tenantConnections.delete(cacheKey);
    await connection.destroy();
    this.logger.log(`Disconnected tenant: ${cacheKey}`);
  }

  /**
   * Evict the pool cached under `cacheKey` WITHOUT destroying it. The caller is
   * responsible for the pool's lifetime; use disconnectTenant to also close it.
   */
  removeTenantConnection(cacheKey: string): void {
    this.tenantConnections.delete(cacheKey);
    this.logger.log(`Removed tenant connection from cache: ${cacheKey}`);
  }

  /**
   * Destroy every cached pool and clear the cache.
   *
   * All pools are destroyed even if some fail. If any destroy() rejects, the
   * first error is rethrown after every pool has been attempted.
   */
  async disconnectAll(): Promise<void> {
    const connections = Array.from(this.tenantConnections.values());
    this.tenantConnections.clear();

    const failures: unknown[] = [];
    await Promise.all(
      connections.map(async (connection) => {
        try {
          await connection.destroy();
        } catch (error: unknown) {
          failures.push(error);
        }
      }),
    );

    if (failures.length > 0) {
      this.logger.error(`Failed to destroy ${failures.length} tenant connection(s)`);
      throw failures[0];
    }

    this.logger.log('Disconnected all tenant connections');
  }

  /**
   * Encrypt a tenant DB password for storage.
   *
   * @returns `<ivHex>:<cipherHex>`, using a fresh random 16-byte IV per call.
   * @throws Error if ENCRYPTION_KEY is missing or not 32 bytes of hex.
   */
  encryptPassword(password: string): string {
    const iv = crypto.randomBytes(16);
    const cipher = crypto.createCipheriv(
      TenantDatabaseService.CIPHER_ALGORITHM,
      this.getEncryptionKey(),
      iv,
    );
    const encrypted = cipher.update(password, 'utf8', 'hex') + cipher.final('hex');
    return iv.toString('hex') + ':' + encrypted;
  }

  /**
   * Reverse of encryptPassword.
   *
   * @param encryptedPassword Value of the form `<ivHex>:<cipherHex>`.
   * @throws Error if the value is malformed, ENCRYPTION_KEY is invalid, or
   *   decryption fails (e.g. wrong key).
   */
  private decryptPassword(encryptedPassword: string): string {
    const [ivHex, encrypted] = encryptedPassword.split(':');
    if (!ivHex || !encrypted) {
      throw new Error('Encrypted password is malformed; expected "<ivHex>:<cipherHex>"');
    }

    const decipher = crypto.createDecipheriv(
      TenantDatabaseService.CIPHER_ALGORITHM,
      this.getEncryptionKey(),
      Buffer.from(ivHex, 'hex'),
    );
    return decipher.update(encrypted, 'hex', 'utf8') + decipher.final('utf8');
  }

  /**
   * Read and validate the hex-encoded ENCRYPTION_KEY environment variable.
   * Without this check an unset variable surfaces as an opaque TypeError from
   * Buffer.from, and invalid hex silently decodes to a short key.
   */
  private getEncryptionKey(): Buffer {
    const hexKey = process.env.ENCRYPTION_KEY;
    if (!hexKey) {
      throw new Error('ENCRYPTION_KEY environment variable is not set');
    }
    const key = Buffer.from(hexKey, 'hex');
    if (key.length !== TenantDatabaseService.KEY_LENGTH_BYTES) {
      throw new Error(
        `ENCRYPTION_KEY must be ${TenantDatabaseService.KEY_LENGTH_BYTES * 2} hex characters (${TenantDatabaseService.KEY_LENGTH_BYTES} bytes)`,
      );
    }
    return key;
  }

  /** Extract a printable message from an unknown thrown value. */
  private static describeError(error: unknown): string {
    return error instanceof Error ? error.message : String(error);
  }
}