mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-22 11:01:12 +08:00
feat: 合并可选是否允许重复键
This commit is contained in:
parent
f169167409
commit
87f48dc8ee
@ -10,6 +10,6 @@ const [output, ...datasets] = process.argv.slice(2);
|
|||||||
return JSON.parse(file) as DatasetMergable<any>;
|
return JSON.parse(file) as DatasetMergable<any>;
|
||||||
})
|
})
|
||||||
);
|
);
|
||||||
const merged = mergeDataset(...data);
|
const merged = mergeDataset(true, ...data);
|
||||||
await writeFile(output, JSON.stringify(merged), 'utf-8');
|
await writeFile(output, JSON.stringify(merged), 'utf-8');
|
||||||
})();
|
})();
|
||||||
|
|||||||
@ -26,13 +26,13 @@ function getNum() {
|
|||||||
);
|
);
|
||||||
const targetFile = await readFile(target, 'utf-8');
|
const targetFile = await readFile(target, 'utf-8');
|
||||||
const targetData = JSON.parse(targetFile) as DatasetMergable<any>;
|
const targetData = JSON.parse(targetFile) as DatasetMergable<any>;
|
||||||
const merged = mergeDataset(...datas);
|
const merged = mergeDataset(true, ...datas);
|
||||||
const keys = Object.keys(merged.data);
|
const keys = Object.keys(merged.data);
|
||||||
const toReview = chooseFrom(keys, n);
|
const toReview = chooseFrom(keys, n);
|
||||||
const reviewData: DatasetMergable<any> = {
|
const reviewData: DatasetMergable<any> = {
|
||||||
datasetId: Math.floor(Math.random() * 1e12),
|
datasetId: Math.floor(Math.random() * 1e12),
|
||||||
data: Object.fromEntries(toReview.map(v => [v, merged.data[v]]))
|
data: Object.fromEntries(toReview.map(v => [v, merged.data[v]]))
|
||||||
};
|
};
|
||||||
const reviewed = mergeDataset(targetData, reviewData);
|
const reviewed = mergeDataset(false, targetData, reviewData);
|
||||||
await writeFile(target, JSON.stringify(reviewed), 'utf-8');
|
await writeFile(target, JSON.stringify(reviewed), 'utf-8');
|
||||||
})();
|
})();
|
||||||
|
|||||||
@ -15,6 +15,7 @@ export interface FloorData {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export function mergeDataset<T>(
|
export function mergeDataset<T>(
|
||||||
|
allowDuplicateKeys: boolean,
|
||||||
...datasets: DatasetMergable<T>[]
|
...datasets: DatasetMergable<T>[]
|
||||||
): DatasetMergable<T> {
|
): DatasetMergable<T> {
|
||||||
if (datasets.length === 1) {
|
if (datasets.length === 1) {
|
||||||
@ -24,7 +25,7 @@ export function mergeDataset<T>(
|
|||||||
const data: Record<string, T> = {};
|
const data: Record<string, T> = {};
|
||||||
datasets.forEach(v => {
|
datasets.forEach(v => {
|
||||||
for (const [key, value] of Object.entries(v.data)) {
|
for (const [key, value] of Object.entries(v.data)) {
|
||||||
if (usedKeys.has(key)) {
|
if (usedKeys.has(key) && allowDuplicateKeys) {
|
||||||
const dataKey = `${v.datasetId}/${key}`;
|
const dataKey = `${v.datasetId}/${key}`;
|
||||||
data[dataKey] = value;
|
data[dataKey] = value;
|
||||||
usedKeys.add(dataKey);
|
usedKeys.add(dataKey);
|
||||||
|
|||||||
@ -62,11 +62,19 @@ def train():
|
|||||||
optimizer.load_state_dict(data["optimizer_state"])
|
optimizer.load_state_dict(data["optimizer_state"])
|
||||||
print("Train from loaded state.")
|
print("Train from loaded state.")
|
||||||
|
|
||||||
|
else:
|
||||||
|
# 从头开始训练的话,初始时先把 minamo 损失值权重改为 0
|
||||||
|
criterion.weight[0] = 0.0
|
||||||
|
|
||||||
# 开始训练
|
# 开始训练
|
||||||
for epoch in tqdm(range(args.epochs)):
|
for epoch in tqdm(range(args.epochs)):
|
||||||
model.train()
|
model.train()
|
||||||
total_loss = 0
|
total_loss = 0
|
||||||
|
|
||||||
|
# 从头开始训练的,在第 10 个 epoch 将 minamo 损失值权重改回来
|
||||||
|
if not args.resume and epoch == 10:
|
||||||
|
criterion.weight[0] = 0.5
|
||||||
|
|
||||||
for batch in dataloader:
|
for batch in dataloader:
|
||||||
# 数据迁移到设备
|
# 数据迁移到设备
|
||||||
target = batch["target"].to(device)
|
target = batch["target"].to(device)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user