mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-17 15:01:10 +08:00
feat: 合并可选是否允许重复键
This commit is contained in:
parent
f169167409
commit
87f48dc8ee
@@ -10,6 +10,6 @@ const [output, ...datasets] = process.argv.slice(2);
|
||||
return JSON.parse(file) as DatasetMergable<any>;
|
||||
})
|
||||
);
|
||||
const merged = mergeDataset(...data);
|
||||
const merged = mergeDataset(true, ...data);
|
||||
await writeFile(output, JSON.stringify(merged), 'utf-8');
|
||||
})();
|
||||
|
||||
@@ -26,13 +26,13 @@ function getNum() {
|
||||
);
|
||||
const targetFile = await readFile(target, 'utf-8');
|
||||
const targetData = JSON.parse(targetFile) as DatasetMergable<any>;
|
||||
const merged = mergeDataset(...datas);
|
||||
const merged = mergeDataset(true, ...datas);
|
||||
const keys = Object.keys(merged.data);
|
||||
const toReview = chooseFrom(keys, n);
|
||||
const reviewData: DatasetMergable<any> = {
|
||||
datasetId: Math.floor(Math.random() * 1e12),
|
||||
data: Object.fromEntries(toReview.map(v => [v, merged.data[v]]))
|
||||
};
|
||||
const reviewed = mergeDataset(targetData, reviewData);
|
||||
const reviewed = mergeDataset(false, targetData, reviewData);
|
||||
await writeFile(target, JSON.stringify(reviewed), 'utf-8');
|
||||
})();
|
||||
|
||||
@@ -15,6 +15,7 @@ export interface FloorData {
|
||||
}
|
||||
|
||||
export function mergeDataset<T>(
|
||||
allowDuplicateKeys: boolean,
|
||||
...datasets: DatasetMergable<T>[]
|
||||
): DatasetMergable<T> {
|
||||
if (datasets.length === 1) {
|
||||
@@ -24,7 +25,7 @@ export function mergeDataset<T>(
|
||||
const data: Record<string, T> = {};
|
||||
datasets.forEach(v => {
|
||||
for (const [key, value] of Object.entries(v.data)) {
|
||||
if (usedKeys.has(key)) {
|
||||
if (usedKeys.has(key) && allowDuplicateKeys) {
|
||||
const dataKey = `${v.datasetId}/${key}`;
|
||||
data[dataKey] = value;
|
||||
usedKeys.add(dataKey);
|
||||
|
||||
@@ -61,12 +61,20 @@ def train():
|
||||
if args.load_optim:
|
||||
optimizer.load_state_dict(data["optimizer_state"])
|
||||
print("Train from loaded state.")
|
||||
|
||||
else:
|
||||
# 从头开始训练的话,初始时先把 minamo 损失值权重改为 0
|
||||
criterion.weight[0] = 0.0
|
||||
|
||||
# 开始训练
|
||||
for epoch in tqdm(range(args.epochs)):
|
||||
model.train()
|
||||
total_loss = 0
|
||||
|
||||
# 从头开始训练的,在第 10 个 epoch 将 minamo 损失值权重改回来
|
||||
if not args.resume and epoch == 10:
|
||||
criterion.weight[0] = 0.5
|
||||
|
||||
for batch in dataloader:
|
||||
# 数据迁移到设备
|
||||
target = batch["target"].to(device)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user