From 87f48dc8eee9f123fbe24406c1d365a410f589e4 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Mon, 24 Mar 2025 17:16:57 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=90=88=E5=B9=B6=E5=8F=AF=E9=80=89?= =?UTF-8?q?=E6=98=AF=E5=90=A6=E5=85=81=E8=AE=B8=E9=87=8D=E5=A4=8D=E9=94=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- data/src/merge.ts | 2 +- data/src/review.ts | 4 ++-- data/src/utils.ts | 3 ++- ginka/train.py | 8 ++++++++ 4 files changed, 13 insertions(+), 4 deletions(-) diff --git a/data/src/merge.ts b/data/src/merge.ts index b9f649c..73681d6 100644 --- a/data/src/merge.ts +++ b/data/src/merge.ts @@ -10,6 +10,6 @@ const [output, ...datasets] = process.argv.slice(2); return JSON.parse(file) as DatasetMergable; }) ); - const merged = mergeDataset(...data); + const merged = mergeDataset(true, ...data); await writeFile(output, JSON.stringify(merged), 'utf-8'); })(); diff --git a/data/src/review.ts b/data/src/review.ts index 5ba0df6..989da4a 100644 --- a/data/src/review.ts +++ b/data/src/review.ts @@ -26,13 +26,13 @@ function getNum() { ); const targetFile = await readFile(target, 'utf-8'); const targetData = JSON.parse(targetFile) as DatasetMergable; - const merged = mergeDataset(...datas); + const merged = mergeDataset(true, ...datas); const keys = Object.keys(merged.data); const toReview = chooseFrom(keys, n); const reviewData: DatasetMergable = { datasetId: Math.floor(Math.random() * 1e12), data: Object.fromEntries(toReview.map(v => [v, merged.data[v]])) }; - const reviewed = mergeDataset(targetData, reviewData); + const reviewed = mergeDataset(false, targetData, reviewData); await writeFile(target, JSON.stringify(reviewed), 'utf-8'); })(); diff --git a/data/src/utils.ts b/data/src/utils.ts index 0f0dde7..5040106 100644 --- a/data/src/utils.ts +++ b/data/src/utils.ts @@ -15,6 +15,7 @@ export interface FloorData { } export function mergeDataset( + allowDuplicateKeys: boolean, ...datasets: DatasetMergable[] ): DatasetMergable { if (datasets.length === 1) { @@ -24,7 +25,7 @@ export function mergeDataset( const data: Record = {}; datasets.forEach(v => { for (const [key, value] of Object.entries(v.data)) { - if (usedKeys.has(key)) { + if (usedKeys.has(key) && allowDuplicateKeys) { const dataKey = `${v.datasetId}/${key}`; data[dataKey] = value; usedKeys.add(dataKey); diff --git a/ginka/train.py b/ginka/train.py index 96d5d35..97e824e 100644 --- a/ginka/train.py +++ b/ginka/train.py @@ -61,12 +61,20 @@ def train(): if args.load_optim: optimizer.load_state_dict(data["optimizer_state"]) print("Train from loaded state.") + + else: + # 从头开始训练的话,初始时先把 minamo 损失值权重改为 0 + criterion.weight[0] = 0.0 # 开始训练 for epoch in tqdm(range(args.epochs)): model.train() total_loss = 0 + # 从头开始训练的,在第 10 个 epoch 将 minamo 损失值权重改回来 + if not args.resume and epoch == 10: + criterion.weight[0] = 0.5 + for batch in dataloader: # 数据迁移到设备 target = batch["target"].to(device)