feat: 合并可选是否允许重复键

This commit is contained in:
unanmed 2025-03-24 17:16:57 +08:00
parent f169167409
commit 87f48dc8ee
4 changed files with 13 additions and 4 deletions

View File

@@ -10,6 +10,6 @@ const [output, ...datasets] = process.argv.slice(2);
             return JSON.parse(file) as DatasetMergable<any>;
         })
     );
-    const merged = mergeDataset(...data);
+    const merged = mergeDataset(true, ...data);
     await writeFile(output, JSON.stringify(merged), 'utf-8');
 })();

View File

@@ -26,13 +26,13 @@ function getNum() {
     );
     const targetFile = await readFile(target, 'utf-8');
     const targetData = JSON.parse(targetFile) as DatasetMergable<any>;
-    const merged = mergeDataset(...datas);
+    const merged = mergeDataset(true, ...datas);
     const keys = Object.keys(merged.data);
     const toReview = chooseFrom(keys, n);
     const reviewData: DatasetMergable<any> = {
         datasetId: Math.floor(Math.random() * 1e12),
         data: Object.fromEntries(toReview.map(v => [v, merged.data[v]]))
     };
-    const reviewed = mergeDataset(targetData, reviewData);
+    const reviewed = mergeDataset(false, targetData, reviewData);
     await writeFile(target, JSON.stringify(reviewed), 'utf-8');
 })();

View File

@@ -15,6 +15,7 @@ export interface FloorData {
 }
 export function mergeDataset<T>(
+    allowDuplicateKeys: boolean,
     ...datasets: DatasetMergable<T>[]
 ): DatasetMergable<T> {
     if (datasets.length === 1) {
@@ -24,7 +25,7 @@ export function mergeDataset<T>(
     const data: Record<string, T> = {};
     datasets.forEach(v => {
         for (const [key, value] of Object.entries(v.data)) {
-            if (usedKeys.has(key)) {
+            if (usedKeys.has(key) && allowDuplicateKeys) {
                 const dataKey = `${v.datasetId}/${key}`;
                 data[dataKey] = value;
                 usedKeys.add(dataKey);

View File

@@ -62,11 +62,19 @@ def train():
         optimizer.load_state_dict(data["optimizer_state"])
         print("Train from loaded state.")
+    else:
+        # 从头开始训练的话,初始时先把 minamo 损失值权重改为 0
+        criterion.weight[0] = 0.0
     # 开始训练
     for epoch in tqdm(range(args.epochs)):
         model.train()
         total_loss = 0
+        # 从头开始训练的,在第 10 个 epoch 将 minamo 损失值权重改回来
+        if not args.resume and epoch == 10:
+            criterion.weight[0] = 0.5
         for batch in dataloader:
             # 数据迁移到设备
             target = batch["target"].to(device)