diff --git a/src/modules/database/base/BaseSeeder.ts b/src/modules/database/base/BaseSeeder.ts index 8d553a5..9712cfd 100644 --- a/src/modules/database/base/BaseSeeder.ts +++ b/src/modules/database/base/BaseSeeder.ts @@ -6,12 +6,14 @@ import { DataSource, EntityManager, EntityTarget, ObjectLiteral } from 'typeorm' import { Configure } from '@/modules/config/configure'; import { panic } from '@/modules/core/helpers'; import { + DBFactory, Seeder, SeederConstructor, SeederLoadParams, SeederOptions, } from '@/modules/database/commands/types'; -import { DBOptions } from '@/modules/database/types'; +import { DBFactoryOption, DBOptions } from '@/modules/database/types'; +import { factoryBuilder } from '@/modules/database/utils'; /** * 数据填充基类 @@ -23,6 +25,7 @@ export abstract class BaseSeeder implements Seeder { protected configure: Configure; protected ignoreLock: boolean; protected truncates: EntityTarget[] = []; + protected factories: { [entityName: string]: DBFactoryOption }; constructor( protected readonly spinner: Ora, @@ -34,12 +37,13 @@ export abstract class BaseSeeder implements Seeder { * @param params */ async load(params: SeederLoadParams): Promise { - const { connection, dataSource, em, configure, ignoreLock } = params; + const { connection, dataSource, em, configure, ignoreLock, factory, factories } = params; this.connection = connection; this.dataSource = dataSource; this.em = em; this.configure = configure; this.ignoreLock = ignoreLock; + this.factories = factories; if (this.ignoreLock) { for (const option of this.truncates) { @@ -47,16 +51,21 @@ export abstract class BaseSeeder implements Seeder { } } - return this.run(this.dataSource); + return this.run(factory, this.dataSource); } /** * 运行seeder的关键方法 + * @param factory * @param dataSource * @param em * @protected */ - protected abstract run(dataSource: DataSource, em?: EntityManager): Promise; + protected abstract run( + factory?: DBFactory, + dataSource?: DataSource, + em?: EntityManager, + ): Promise; protected async getDBConfig() { const { connections = [] }: DBOptions = await this.configure.get('database'); @@ -80,6 +89,8 @@ export abstract class BaseSeeder implements Seeder { em: this.em, configure: this.configure, ignoreLock: this.ignoreLock, + factories: this.factories, + factory: factoryBuilder(this.configure, this.dataSource, this.factories), }); } } diff --git a/src/modules/database/commands/types.ts b/src/modules/database/commands/types.ts index c5d9ca1..2986a58 100644 --- a/src/modules/database/commands/types.ts +++ b/src/modules/database/commands/types.ts @@ -1,8 +1,10 @@ import { Ora } from 'ora'; -import { DataSource, EntityManager } from 'typeorm'; +import { DataSource, EntityManager, EntityTarget } from 'typeorm'; import { Arguments } from 'yargs'; import { Configure } from '@/modules/config/configure'; +import { DataFactory } from '@/modules/database/resolver/data.factory'; +import { DBFactoryOption } from '@/modules/database/types'; /** * 基础数据库命令参数类型 @@ -121,9 +123,40 @@ export interface SeederLoadParams { * 是否忽略锁定 */ ignoreLock: boolean; + /** + * Factory解析器 + */ + factory?: DBFactory; + /** + * Factory函数列表 + */ + factories: FactoryOptions; } /** * 数据填充命令参数 */ export type SeederArguments = TypeOrmArguments & SeederOptions; + +/** + * Factory解析器 + */ +export interface DBFactory { +

(entity: EntityTarget

): (options?: T) => DataFactory; +} + +/** + * 数据填充函数映射对象 + */ +export type FactoryOptions = { + [entityName: string]: DBFactoryOption; +}; + +/** + * Factory构造器 + */ +export type DBFactoryBuilder = ( + configure: Configure, + dataSource: DataSource, + factories: { [entityName: string]: DBFactoryOption }, +) => DBFactory; diff --git a/src/modules/database/config.ts b/src/modules/database/config.ts index 72e2549..db46ce1 100644 --- a/src/modules/database/config.ts +++ b/src/modules/database/config.ts @@ -14,7 +14,13 @@ export const createDBConfig: ( register, hook: (configure, value) => createDBOptions(value), defaultRegister: () => ({ - common: { charset: 'utf8mb4', logging: ['error'], seeders: [], seedRunner: SeederRunner }, + common: { + charset: 'utf8mb4', + logging: ['error'], + seeders: [], + seedRunner: SeederRunner, + factories: [], + }, connections: [], }), }); diff --git a/src/modules/database/resolver/data.factory.ts b/src/modules/database/resolver/data.factory.ts new file mode 100644 index 0000000..668c481 --- /dev/null +++ b/src/modules/database/resolver/data.factory.ts @@ -0,0 +1,107 @@ +import { isPromise } from 'node:util/types'; + +import { isNil } from 'lodash'; +import { EntityManager, EntityTarget } from 'typeorm'; + +import { panic } from '@/modules/core/helpers'; +import { DBFactoryHandler, FactoryOverride } from '@/modules/database/types'; + +export class DataFactory { + private mapFunction!: (entity: P) => Promise

; + + constructor( + public name: string, + public config: Configure, + public entity: EntityTarget

, + protected em: EntityManager, + protected factory: DBFactoryHandler, + protected settings: T, + ) {} + + map(mapFunction: (entity: P) => Promise

): DataFactory { + this.mapFunction = mapFunction; + return this; + } + + async make(params: FactoryOverride

= {}): Promise

{ + if (this.factory) { + let entity: P = await this.resolveEntity( + await this.factory(this.configure, this.settings), + ); + if (this.mapFunction) { + entity = await this.mapFunction(entity); + } + for (const key in params) { + if (params[key]) { + entity[key] = params[key]; + } + } + return entity; + } + throw new Error('Could not found entity'); + } + + async create(params: FactoryOverride

= {}, existsCheck?: string): Promise

{ + try { + const entity = await this.make(params); + if (!isNil(existsCheck)) { + const repo = this.em.getRepository(this.entity); + const value = (entity as any)[existsCheck]; + if (!isNil(value)) { + const item = await repo.findOneBy({ [existsCheck]: value } as any); + if (isNil(item)) { + return await this.em.save(entity); + } + return item; + } + } + return await this.em.save(entity); + } catch (error) { + const message = 'Could not save entity'; + await panic({ message, error }); + throw new Error(message); + } + } + + async makeMany(amount: number, params: FactoryOverride

= {}): Promise { + const list = []; + for (let i = 0; i < amount; i++) { + list[i] = await this.make(params); + } + return list; + } + + async createMany( + amount: number, + params: FactoryOverride

= {}, + existsCheck?: string, + ): Promise { + const list = []; + for (let i = 0; i < amount; i++) { + list[i] = await this.create(params, existsCheck); + } + return list; + } + + private async resolveEntity(entity: P): Promise

{ + for (const attr in entity) { + if (entity[attr]) { + if (isPromise(entity[attr])) { + entity[attr] = await entity[attr]; + } else if (typeof entity[attr] === 'object' && !(entity[attr] instanceof Date)) { + const item = entity[attr]; + try { + if (typeof (item as any).make === 'function') { + entity[attr] = await (item as any).make(); + } + } catch (error) { + const message = `Could not make ${(subEntityFactory as any).name}`; + await panic({ message, error }); + throw new Error(message); + } + } + } + } + return entity; + } +} diff --git a/src/modules/database/resolver/seeder.runner.ts b/src/modules/database/resolver/seeder.runner.ts index 4fb82e1..9f687e6 100644 --- a/src/modules/database/resolver/seeder.runner.ts +++ b/src/modules/database/resolver/seeder.runner.ts @@ -7,12 +7,17 @@ import { DataSource, EntityManager } from 'typeorm'; import YAML from 'yaml'; import { BaseSeeder } from '@/modules/database/base/BaseSeeder'; +import { DBFactory } from '@/modules/database/commands/types'; /** * 默认的Seed Runner */ export class SeederRunner extends BaseSeeder { - protected async run(dataSource: DataSource, em?: EntityManager): Promise { + protected async run( + factory: DBFactory, + dataSource: DataSource, + em: EntityManager, + ): Promise { let seeders: Type[] = ((await this.getDBConfig()) as any).seeders ?? []; const seedLockFile = resolve(__dirname, '../../../..', 'seed-lock.yml'); ensureFileSync(seedLockFile); diff --git a/src/modules/database/types.ts b/src/modules/database/types.ts index 2018f6e..a9eac37 100644 --- a/src/modules/database/types.ts +++ b/src/modules/database/types.ts @@ -2,11 +2,13 @@ import { TypeOrmModuleOptions } from '@nestjs/typeorm'; import { FindTreeOptions, ObjectLiteral, + ObjectType, Repository, SelectQueryBuilder, TreeRepository, } from 'typeorm'; +import { Configure } from '@/modules/config/configure'; import { SeederConstructor } from '@/modules/database/commands/types'; import { OrderType, SelectTrashMode } from '@/modules/database/constants'; @@ -102,4 +104,24 @@ type DBAdditionalOption = { * 数据填充入口类 */ seedRunner?: SeederConstructor; + /** + * 定义数据工厂列表 + */ + factories?: (() => DBFactoryOption)[]; +}; + +export type DBFactoryHandler = (configure: Configure, options: T) => Promise

; + +export type DBFactoryOption = { + entity: ObjectType

; + handler: DBFactoryHandler; +}; + +export type DefineFactory = ( + entity: ObjectType

, + handler: DBFactoryHandler, +) => () => DBFactoryOption; + +export type FactoryOverride = { + [Property in keyof Entity]: Entity[Property]; }; diff --git a/src/modules/database/utils.ts b/src/modules/database/utils.ts index 856bc5f..e99ade7 100644 --- a/src/modules/database/utils.ts +++ b/src/modules/database/utils.ts @@ -7,6 +7,7 @@ import { DataSource, DataSourceOptions, EntityManager, + EntityTarget, ObjectLiteral, ObjectType, Repository, @@ -14,9 +15,17 @@ import { } from 'typeorm'; import { Configure } from '@/modules/config/configure'; -import { Seeder, SeederConstructor, SeederOptions } from '@/modules/database/commands/types'; +import { + DBFactoryBuilder, + FactoryOptions, + Seeder, + SeederConstructor, + SeederOptions, +} from '@/modules/database/commands/types'; +import { DataFactory } from '@/modules/database/resolver/data.factory'; import { DBOptions, + DefineFactory, OrderQueryType, PaginateOptions, PaginateReturn, @@ -206,6 +215,13 @@ export async function runSeeder( const dataSource: DataSource = new DataSource({ ...dbConfig } as DataSourceOptions); await dataSource.initialize(); + + const factoryMaps: FactoryOptions = {}; + for (const factory of dbConfig.factories) { + const { entity, handler } = factory(); + factoryMaps[entity.name] = { entity, handler }; + } + if (typeof args.transaction === 'boolean' && !args.transaction) { const em = await resetForeignKey(dataSource.manager, dataSource.options.type); await seeder.load({ @@ -214,6 +230,8 @@ export async function runSeeder( configure, connection: args.connection ?? 'default', ignoreLock: args.ignorelock, + factory: factoryBuilder(configure, dataSource, factoryMaps), + factories: factoryMaps, }); await resetForeignKey(em, dataSource.options.type, false); } else { @@ -229,6 +247,8 @@ export async function runSeeder( configure, connection: args.connection ?? 'default', ignoreLock: args.ignorelock, + factory: factoryBuilder(configure, dataSource, factoryMaps), + factories: factoryMaps, }); await resetForeignKey(em, dataSource.options.type, false); await queryRunner.commitTransaction(); @@ -245,3 +265,40 @@ export async function runSeeder( } return dataSource; } + +/** + * 定义factory用于生成数据 + * @param entity + * @param handler + */ +export const defineFactory: DefineFactory = (entity, handler) => () => ({ entity, handler }); + +/** + * 获取Entity类名 + * @param entity + */ +export function entityName(entity: EntityTarget): string { + if (isNil(entity)) { + throw new Error('Entity is not defined'); + } + if (entity instanceof Function) { + return entity.name; + } + return new (entity as any)().constructor.name; +} + +export const factoryBuilder: DBFactoryBuilder = + (configure, dataSource, factories) => (entity) => (settings) => { + const name = entityName(entity); + if (!factories[name]) { + throw new Error(`has none factory for entity named ${name}`); + } + return new DataFactory( + name, + configure, + entity, + dataSource.createEntityManager(), + factories[name].handler, + settings, + ); + };