feat: 合并时可选是否允许重复键

This commit is contained in:
unanmed 2025-03-24 17:16:57 +08:00
parent f169167409
commit 87f48dc8ee
4 changed files with 13 additions and 4 deletions

View File

@ -10,6 +10,6 @@ const [output, ...datasets] = process.argv.slice(2);
return JSON.parse(file) as DatasetMergable<any>;
})
);
const merged = mergeDataset(...data);
const merged = mergeDataset(true, ...data);
await writeFile(output, JSON.stringify(merged), 'utf-8');
})();

View File

@ -26,13 +26,13 @@ function getNum() {
);
const targetFile = await readFile(target, 'utf-8');
const targetData = JSON.parse(targetFile) as DatasetMergable<any>;
const merged = mergeDataset(...datas);
const merged = mergeDataset(true, ...datas);
const keys = Object.keys(merged.data);
const toReview = chooseFrom(keys, n);
const reviewData: DatasetMergable<any> = {
datasetId: Math.floor(Math.random() * 1e12),
data: Object.fromEntries(toReview.map(v => [v, merged.data[v]]))
};
const reviewed = mergeDataset(targetData, reviewData);
const reviewed = mergeDataset(false, targetData, reviewData);
await writeFile(target, JSON.stringify(reviewed), 'utf-8');
})();

View File

@ -15,6 +15,7 @@ export interface FloorData {
}
export function mergeDataset<T>(
allowDuplicateKeys: boolean,
...datasets: DatasetMergable<T>[]
): DatasetMergable<T> {
if (datasets.length === 1) {
@ -24,7 +25,7 @@ export function mergeDataset<T>(
const data: Record<string, T> = {};
datasets.forEach(v => {
for (const [key, value] of Object.entries(v.data)) {
if (usedKeys.has(key)) {
if (usedKeys.has(key) && allowDuplicateKeys) {
const dataKey = `${v.datasetId}/${key}`;
data[dataKey] = value;
usedKeys.add(dataKey);

View File

@ -61,12 +61,20 @@ def train():
if args.load_optim:
optimizer.load_state_dict(data["optimizer_state"])
print("Train from loaded state.")
else:
# 从头开始训练的话,初始时先把 minamo 损失值权重改为 0
criterion.weight[0] = 0.0
# 开始训练
for epoch in tqdm(range(args.epochs)):
model.train()
total_loss = 0
# 从头开始训练的,在第 10 个 epoch 将 minamo 损失值权重改回来
if not args.resume and epoch == 10:
criterion.weight[0] = 0.5
for batch in dataloader:
# 数据迁移到设备
target = batch["target"].to(device)