mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-21 18:31:13 +08:00
feat: 支持流式传输,改进数据集
This commit is contained in:
parent
8dea79a9f0
commit
325eb599c3
104
data/src/gan.ts
104
data/src/gan.ts
@ -8,7 +8,7 @@ const [refer] = process.argv.slice(2);
|
|||||||
|
|
||||||
let id = 0;
|
let id = 0;
|
||||||
|
|
||||||
function readMap(count: number, buffer: Buffer, h: number, w: number) {
|
function readMap(count: number, arr: number[], h: number, w: number) {
|
||||||
const area = w * h;
|
const area = w * h;
|
||||||
|
|
||||||
const maps: number[][][] = Array.from<number[][]>({
|
const maps: number[][][] = Array.from<number[][]>({
|
||||||
@ -19,7 +19,7 @@ function readMap(count: number, buffer: Buffer, h: number, w: number) {
|
|||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
buffer.subarray(4).forEach((v, i) => {
|
arr.forEach((v, i) => {
|
||||||
const n = Math.floor(i / area);
|
const n = Math.floor(i / area);
|
||||||
const y = Math.floor((i % area) / w);
|
const y = Math.floor((i % area) / w);
|
||||||
const x = i % w;
|
const x = i % w;
|
||||||
@ -35,7 +35,7 @@ function generateGANData(
|
|||||||
map: number[][]
|
map: number[][]
|
||||||
) {
|
) {
|
||||||
const id2 = `$${id++}`;
|
const id2 = `$${id++}`;
|
||||||
const toTrain = chooseFrom(keys, 4);
|
const toTrain = chooseFrom(keys, 30);
|
||||||
const data = toTrain.map<MinamoTrainData[]>(v => {
|
const data = toTrain.map<MinamoTrainData[]>(v => {
|
||||||
const floor = refer.get(v);
|
const floor = refer.get(v);
|
||||||
if (!floor) return [];
|
if (!floor) return [];
|
||||||
@ -48,28 +48,67 @@ function generateGANData(
|
|||||||
return data.flat();
|
return data.flat();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const enum ReceiverStatus {
|
||||||
|
Header,
|
||||||
|
Content
|
||||||
|
}
|
||||||
|
|
||||||
|
class DataReceiver {
|
||||||
|
static active?: DataReceiver
|
||||||
|
/** 接收状态 */
|
||||||
|
private status: ReceiverStatus = ReceiverStatus.Header;
|
||||||
|
|
||||||
|
private received: number[] = []
|
||||||
|
private count: number = 0;
|
||||||
|
private h: number = 0;
|
||||||
|
private w: number = 0;
|
||||||
|
|
||||||
|
receive(buf: Buffer): [number[][][], number, number, number] | null {
|
||||||
|
// 数据通讯 node 输入协议,单位字节:
|
||||||
|
// 2 - Tensor count; 1 - Map height; 1 - Map Width; N*1*H*W - Map tensor, int8 type.
|
||||||
|
switch (this.status) {
|
||||||
|
case ReceiverStatus.Header:
|
||||||
|
this.count = buf.readInt16BE();
|
||||||
|
this.h = buf.readInt8(2);
|
||||||
|
this.w = buf.readInt8(3);
|
||||||
|
this.received.push(...buf.subarray(4));
|
||||||
|
this.status = ReceiverStatus.Content;
|
||||||
|
break;
|
||||||
|
case ReceiverStatus.Content:
|
||||||
|
this.received.push(...buf);
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if (this.received.length === this.count * this.h * this.w) {
|
||||||
|
delete DataReceiver.active;
|
||||||
|
return [readMap(this.count, this.received, this.h, this.w), this.count, this.h, this.w];
|
||||||
|
} else {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static check(buf: Buffer) {
|
||||||
|
if (this.active) {
|
||||||
|
return this.active.receive(buf);
|
||||||
|
} else {
|
||||||
|
this.active = new DataReceiver();
|
||||||
|
return this.active.receive(buf);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
(async () => {
|
(async () => {
|
||||||
const referTower = await readOne(refer);
|
const referTower = await readOne(refer);
|
||||||
const keys = [...referTower.keys()];
|
const keys = [...referTower.keys()];
|
||||||
|
|
||||||
const client = createConnection(SOCKET_FILE, () => {
|
const client = createConnection(SOCKET_FILE, () => {
|
||||||
console.log(`UDS IPC connected successfully.`);
|
console.log(`UDS IPC connected successfully.`);
|
||||||
// 发送四字节数据表示连接成功
|
|
||||||
client.write(new Uint8Array([0x00, 0x00, 0x00, 0x00]));
|
|
||||||
});
|
});
|
||||||
|
|
||||||
client.on('data', buffer => {
|
client.on('data', buffer => {
|
||||||
// 暂时不考虑流式传输,如果后续数据量非常大,再考虑优化
|
const data = DataReceiver.check(buffer);
|
||||||
// 数据通讯 node 输入协议,单位字节:
|
if (!data) return;
|
||||||
// 2 - Tensor count; 1 - Map height; 1 - Map Width; N*1*H*W - Map tensor, int8 type.
|
|
||||||
const count = buffer.readInt16BE();
|
const [map, count, h, w] = data;
|
||||||
if (buffer.length - 4 !== count * 32 * 32) {
|
|
||||||
client.write(`ERROR: byte length not match.`);
|
|
||||||
return [];
|
|
||||||
}
|
|
||||||
const h = buffer.readInt8(2);
|
|
||||||
const w = buffer.readInt8(3);
|
|
||||||
const map = readMap(count, buffer, h, w);
|
|
||||||
const simData = map.map(v => generateGANData(keys, referTower, v));
|
const simData = map.map(v => generateGANData(keys, referTower, v));
|
||||||
const rc = 0;
|
const rc = 0;
|
||||||
const compareData = simData.flat();
|
const compareData = simData.flat();
|
||||||
@ -83,12 +122,23 @@ function generateGANData(
|
|||||||
const toSend = Buffer.alloc(
|
const toSend = Buffer.alloc(
|
||||||
2 + // Tensor count
|
2 + // Tensor count
|
||||||
2 + // Review count
|
2 + // Review count
|
||||||
count + // Compare count
|
1 * count + // Compare count
|
||||||
2 * (count + rc) + // Similarity data
|
2 * 4 * (compareData.length + rc) + // Similarity data
|
||||||
compareData.length * 1 * h * w + // Compare map
|
compareData.length * 1 * h * w + // Compare map
|
||||||
rc * 2 * h * w, // Review map
|
rc * 2 * h * w, // Review map
|
||||||
0
|
0
|
||||||
);
|
);
|
||||||
|
console.log(
|
||||||
|
2,
|
||||||
|
2,
|
||||||
|
count,
|
||||||
|
2 * 4 * (compareData.length + rc),
|
||||||
|
compareData.length * 1 * h * w,
|
||||||
|
rc * 2 * h * w,
|
||||||
|
compareData.length,
|
||||||
|
rc
|
||||||
|
);
|
||||||
|
|
||||||
let offset = 0;
|
let offset = 0;
|
||||||
toSend.writeInt16BE(count); // Tensor count
|
toSend.writeInt16BE(count); // Tensor count
|
||||||
toSend.writeInt16BE(0, 2); // Review count
|
toSend.writeInt16BE(0, 2); // Review count
|
||||||
@ -98,9 +148,11 @@ function generateGANData(
|
|||||||
simData.map(v => v.length),
|
simData.map(v => v.length),
|
||||||
offset
|
offset
|
||||||
);
|
);
|
||||||
offset += count;
|
offset += 1 * count;
|
||||||
// Similarity data
|
// Similarity data
|
||||||
compareData.forEach(v => {
|
compareData.forEach(v => {
|
||||||
|
// console.log(v.visionSimilarity, v.topoSimilarity);
|
||||||
|
|
||||||
toSend.writeFloatBE(v.visionSimilarity, offset);
|
toSend.writeFloatBE(v.visionSimilarity, offset);
|
||||||
offset += 4;
|
offset += 4;
|
||||||
toSend.writeFloatBE(v.topoSimilarity, offset);
|
toSend.writeFloatBE(v.topoSimilarity, offset);
|
||||||
@ -108,15 +160,17 @@ function generateGANData(
|
|||||||
});
|
});
|
||||||
// Compare map
|
// Compare map
|
||||||
toSend.set(
|
toSend.set(
|
||||||
compareData.map(v => v.map1).flat(2),
|
new Uint8Array(compareData.map(v => v.map1).flat(3)),
|
||||||
offset // Set from Compare map
|
offset // Set from Compare map
|
||||||
);
|
);
|
||||||
offset += compareData.length * 1 * h * w;
|
offset += compareData.length * 1 * h * w;
|
||||||
// Review map
|
if (reviewData.length > 0) {
|
||||||
toSend.set(
|
// Review map
|
||||||
reviewData.map(v => [v.map1, v.map2]).flat(3),
|
toSend.set(
|
||||||
offset // Set from last chunk
|
new Uint8Array(reviewData.map(v => [v.map1, v.map2]).flat(4)),
|
||||||
);
|
offset // Set from last chunk
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
client.write(toSend);
|
client.write(toSend);
|
||||||
});
|
});
|
||||||
|
|||||||
@ -136,7 +136,7 @@ function generateSimilarData(id: string, map: number[][]) {
|
|||||||
// 生成最多两个微调地图
|
// 生成最多两个微调地图
|
||||||
const width = map[0].length;
|
const width = map[0].length;
|
||||||
const height = map.length;
|
const height = map.length;
|
||||||
const num = Math.floor(Math.random() * 1);
|
const num = Math.floor(Math.random() * 2);
|
||||||
const res: [id: string, data: MinamoTrainData][] = [];
|
const res: [id: string, data: MinamoTrainData][] = [];
|
||||||
|
|
||||||
for (let i = 0; i < num; i++) {
|
for (let i = 0; i < num; i++) {
|
||||||
|
|||||||
@ -51,10 +51,14 @@ class GinkaDataset(Dataset):
|
|||||||
class MinamoGANDataset(Dataset):
|
class MinamoGANDataset(Dataset):
|
||||||
def __init__(self, refer_data_path):
|
def __init__(self, refer_data_path):
|
||||||
self.refer = load_minamo_gan_data(load_data(refer_data_path))
|
self.refer = load_minamo_gan_data(load_data(refer_data_path))
|
||||||
self.data = list().extend(self.refer)
|
self.data = list()
|
||||||
|
self.data.extend(random.sample(self.refer, 1000))
|
||||||
|
|
||||||
def set_data(self, data: list):
|
def set_data(self, data: list):
|
||||||
self.data = data.extend(self.refer)
|
self.data.clear()
|
||||||
|
self.data.extend(data)
|
||||||
|
k = min(len(data) / 4, len(self.refer))
|
||||||
|
self.data.extend(random.sample(self.refer, int(k)))
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.data)
|
return len(self.data)
|
||||||
@ -63,8 +67,8 @@ class MinamoGANDataset(Dataset):
|
|||||||
item = self.data[idx]
|
item = self.data[idx]
|
||||||
|
|
||||||
map1, map2, vis_sim, topo_sim, review = item
|
map1, map2, vis_sim, topo_sim, review = item
|
||||||
map1 = torch.ShortTensor(map1)
|
map1 = torch.LongTensor(map1)
|
||||||
map2 = torch.ShortTensor(map2)
|
map2 = torch.LongTensor(map2)
|
||||||
# 检查是否有 review 标签,没有的话说明是概率分布,不需要任何转换
|
# 检查是否有 review 标签,没有的话说明是概率分布,不需要任何转换
|
||||||
if review:
|
if review:
|
||||||
map1 = F.one_hot(map1, num_classes=32).permute(2, 0, 1).float() # [32, H, W]
|
map1 = F.one_hot(map1, num_classes=32).permute(2, 0, 1).float() # [32, H, W]
|
||||||
|
|||||||
@ -255,7 +255,7 @@ class GinkaLoss(nn.Module):
|
|||||||
minamo_loss = (1.0 - minamo_sim).mean()
|
minamo_loss = (1.0 - minamo_sim).mean()
|
||||||
|
|
||||||
tqdm.write(
|
tqdm.write(
|
||||||
f"{minamo_loss.item():.8f}, {class_loss.item():.8f}, {entrance_loss.item():.8f}, {count_loss.item():.8f}"
|
f"{minamo_loss.item():.12f}, {class_loss.item():.12f}, {entrance_loss.item():.12f}, {count_loss.item():.12f}"
|
||||||
)
|
)
|
||||||
|
|
||||||
losses = [
|
losses = [
|
||||||
|
|||||||
@ -19,7 +19,7 @@ from shared.image import matrix_to_image_cv
|
|||||||
|
|
||||||
BATCH_SIZE = 32
|
BATCH_SIZE = 32
|
||||||
EPOCHS_GINKA = 30
|
EPOCHS_GINKA = 30
|
||||||
EPOCHS_MINAMO = 10
|
EPOCHS_MINAMO = 15
|
||||||
SOCKET_PATH = "./tmp/ginka_uds"
|
SOCKET_PATH = "./tmp/ginka_uds"
|
||||||
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
@ -64,6 +64,16 @@ def send_all(sock, data):
|
|||||||
raise RuntimeError("Socket connection broken")
|
raise RuntimeError("Socket connection broken")
|
||||||
total_sent += sent
|
total_sent += sent
|
||||||
|
|
||||||
|
def recv_all(sock: socket.socket, length: int):
|
||||||
|
"""循环接收直到获得指定长度的数据"""
|
||||||
|
data = bytes()
|
||||||
|
while len(data) < length:
|
||||||
|
packet = sock.recv(length - len(data)) # 只请求剩余部分
|
||||||
|
if not packet:
|
||||||
|
raise ConnectionError("连接中断")
|
||||||
|
data += packet
|
||||||
|
return data
|
||||||
|
|
||||||
def parse_minamo_data(sock: socket.socket, maps: np.ndarray):
|
def parse_minamo_data(sock: socket.socket, maps: np.ndarray):
|
||||||
# 数据通讯 node 输出协议,单位字节:
|
# 数据通讯 node 输出协议,单位字节:
|
||||||
# 2 - Tensor count; 2 - Review count. Review is right behind train data;
|
# 2 - Tensor count; 2 - Review count. Review is right behind train data;
|
||||||
@ -75,21 +85,21 @@ def parse_minamo_data(sock: socket.socket, maps: np.ndarray):
|
|||||||
rc_buf = sock.recv(2)
|
rc_buf = sock.recv(2)
|
||||||
tc = struct.unpack('>h', tc_buf)[0]
|
tc = struct.unpack('>h', tc_buf)[0]
|
||||||
rc = struct.unpack('>h', rc_buf)[0]
|
rc = struct.unpack('>h', rc_buf)[0]
|
||||||
count_buf = sock.recv(1 * tc)
|
count_buf = recv_all(sock, 1 * tc)
|
||||||
count: list = struct.unpack(f">{tc}b", count_buf)[0]
|
count: list = struct.unpack(f">{tc}b", count_buf)
|
||||||
N = sum(count)
|
N = sum(count)
|
||||||
sim_buf = sock.recv(2 * 4 * (N + rc))
|
sim_buf = recv_all(sock, 2 * 4 * (N + rc))
|
||||||
com_buf = sock.recv(N * 1 * H * W)
|
com_buf = recv_all(sock, N * 1 * H * W)
|
||||||
review_buf = sock.recv(rc * 2 * H * W) if rc > 0 else bytes()
|
review_buf = recv_all(sock, rc * 2 * H * W) if rc > 0 else bytes()
|
||||||
|
|
||||||
sim = struct.unpack(f">{(N + rc) * 2}f", sim_buf)[0]
|
sim = struct.unpack(f">{(N + rc) * 2}f", sim_buf)
|
||||||
com = struct.unpack(f">{N * 1 * H * W}b", com_buf)[0]
|
com = struct.unpack(f">{N * 1 * H * W}b", com_buf)
|
||||||
review = struct.unpack(f">{rc * 2 * H * W}", review_buf)[0] if rc > 0 else list()
|
review = struct.unpack(f">{rc * 2 * H * W}", review_buf) if rc > 0 else list()
|
||||||
|
|
||||||
res = list()
|
res = list()
|
||||||
flatten_idx = 0
|
flatten_idx = 0
|
||||||
# 读取当前这一轮生成器的数据
|
# 读取当前这一轮生成器的数据
|
||||||
for idx in range(N):
|
for idx in range(tc):
|
||||||
com_count = count[idx]
|
com_count = count[idx]
|
||||||
for i in range(com_count):
|
for i in range(com_count):
|
||||||
com_start = flatten_idx * H * W
|
com_start = flatten_idx * H * W
|
||||||
@ -98,7 +108,7 @@ def parse_minamo_data(sock: socket.socket, maps: np.ndarray):
|
|||||||
topo_sim = sim[flatten_idx * 2 + 1]
|
topo_sim = sim[flatten_idx * 2 + 1]
|
||||||
com_data = com[com_start:com_end]
|
com_data = com[com_start:com_end]
|
||||||
flatten_idx += 1
|
flatten_idx += 1
|
||||||
com_map = np.fromiter(com_data, np.int8).view(H, W)
|
com_map = np.array(com_data, dtype=np.int8).reshape(H, W)
|
||||||
# map1, map2, vision_similarity, topo_similarity, is_review
|
# map1, map2, vision_similarity, topo_similarity, is_review
|
||||||
res.append((maps[idx], com_map, vis_sim, topo_sim, False))
|
res.append((maps[idx], com_map, vis_sim, topo_sim, False))
|
||||||
|
|
||||||
@ -107,7 +117,7 @@ def parse_minamo_data(sock: socket.socket, maps: np.ndarray):
|
|||||||
def train():
|
def train():
|
||||||
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
|
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
|
||||||
|
|
||||||
args = parse_arguments("result/ginka.pth", "ginka-dataset.json", 'ginka-eval.json')
|
args = parse_arguments()
|
||||||
|
|
||||||
ginka = GinkaModel()
|
ginka = GinkaModel()
|
||||||
ginka.to(device)
|
ginka.to(device)
|
||||||
@ -119,8 +129,8 @@ def train():
|
|||||||
# 准备数据集
|
# 准备数据集
|
||||||
ginka_dataset = GinkaDataset(args.train, device, minamo)
|
ginka_dataset = GinkaDataset(args.train, device, minamo)
|
||||||
ginka_dataset_val = GinkaDataset(args.validate, device, minamo)
|
ginka_dataset_val = GinkaDataset(args.validate, device, minamo)
|
||||||
minamo_dataset = MinamoGANDataset()
|
minamo_dataset = MinamoGANDataset("datasets/minamo-dataset-1.json")
|
||||||
minamo_dataset_val = MinamoGANDataset()
|
minamo_dataset_val = MinamoGANDataset("datasets/minamo-eval-1.json")
|
||||||
ginka_dataloader = DataLoader(ginka_dataset, batch_size=BATCH_SIZE, shuffle=True)
|
ginka_dataloader = DataLoader(ginka_dataset, batch_size=BATCH_SIZE, shuffle=True)
|
||||||
ginka_dataloader_val = DataLoader(ginka_dataset_val, batch_size=BATCH_SIZE, shuffle=True)
|
ginka_dataloader_val = DataLoader(ginka_dataset_val, batch_size=BATCH_SIZE, shuffle=True)
|
||||||
minamo_dataloader = DataLoader(minamo_dataset, batch_size=BATCH_SIZE, shuffle=True)
|
minamo_dataloader = DataLoader(minamo_dataset, batch_size=BATCH_SIZE, shuffle=True)
|
||||||
@ -132,7 +142,7 @@ def train():
|
|||||||
criterion_ginka = GinkaLoss(minamo)
|
criterion_ginka = GinkaLoss(minamo)
|
||||||
|
|
||||||
optimizer_minamo = optim.AdamW(minamo.parameters(), lr=1e-3)
|
optimizer_minamo = optim.AdamW(minamo.parameters(), lr=1e-3)
|
||||||
scheduler_minamo = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_minamo, T_0=10, T_mult=2, eta_min=1e-6)
|
scheduler_minamo = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_minamo, T_0=5, T_mult=2, eta_min=1e-6)
|
||||||
criterion_minamo = MinamoLoss()
|
criterion_minamo = MinamoLoss()
|
||||||
|
|
||||||
# 用于生成图片
|
# 用于生成图片
|
||||||
@ -142,10 +152,16 @@ def train():
|
|||||||
tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED)
|
tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED)
|
||||||
|
|
||||||
# 与 node 端通讯
|
# 与 node 端通讯
|
||||||
|
if os.path.exists(SOCKET_PATH):
|
||||||
|
os.remove(SOCKET_PATH)
|
||||||
server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||||
server.bind(SOCKET_PATH)
|
server.bind(SOCKET_PATH)
|
||||||
server.listen(1)
|
server.listen(1)
|
||||||
|
|
||||||
|
print("Waiting for client connection...")
|
||||||
|
conn, _ = server.accept()
|
||||||
|
print("Client connected.")
|
||||||
|
|
||||||
if args.resume:
|
if args.resume:
|
||||||
data = torch.load(args.from_state, map_location=device)
|
data = torch.load(args.from_state, map_location=device)
|
||||||
ginka.load_state_dict(data["model_state"], strict=False)
|
ginka.load_state_dict(data["model_state"], strict=False)
|
||||||
@ -157,11 +173,11 @@ def train():
|
|||||||
# 从头开始训练的话,初始时先把 minamo 损失值权重改为 0
|
# 从头开始训练的话,初始时先把 minamo 损失值权重改为 0
|
||||||
criterion_ginka.weight[0] = 0.0
|
criterion_ginka.weight[0] = 0.0
|
||||||
|
|
||||||
for cycle in tqdm(range(args.from_cycle, args.to_cycle)):
|
for cycle in tqdm(range(args.from_cycle, args.to_cycle), desc="Total Progress"):
|
||||||
# -------------------- 训练生成器
|
# -------------------- 训练生成器
|
||||||
gen_list: np.ndarray = np.empty(np.int8)
|
gen_list: np.ndarray = np.empty((0, 13, 13), np.int8)
|
||||||
prob_list: np.ndarray = np.empty(np.float32)
|
prob_list: np.ndarray = np.empty((0, 32, 13, 13), np.float32)
|
||||||
for epoch in tqdm(range(args.epochs), desc="Training Ginka Model"):
|
for epoch in tqdm(range(EPOCHS_GINKA), desc="Training Ginka Model"):
|
||||||
ginka.train()
|
ginka.train()
|
||||||
minamo.eval()
|
minamo.eval()
|
||||||
total_loss = 0
|
total_loss = 0
|
||||||
@ -170,7 +186,7 @@ def train():
|
|||||||
if not args.resume and epoch == 10:
|
if not args.resume and epoch == 10:
|
||||||
criterion_ginka.weight[0] = 0.5
|
criterion_ginka.weight[0] = 0.5
|
||||||
|
|
||||||
for batch in ginka_dataloader:
|
for batch in tqdm(ginka_dataloader, leave=False, desc="Epoch Progress"):
|
||||||
# 数据迁移到设备
|
# 数据迁移到设备
|
||||||
target, target_vision_feat, target_topo_feat, feat_vec = parse_ginka_batch(batch)
|
target, target_vision_feat, target_topo_feat, feat_vec = parse_ginka_batch(batch)
|
||||||
# 前向传播
|
# 前向传播
|
||||||
@ -187,14 +203,14 @@ def train():
|
|||||||
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {avg_loss:.6f} | lr: {(optimizer_ginka.param_groups[0]['lr']):.6f}")
|
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {avg_loss:.6f} | lr: {(optimizer_ginka.param_groups[0]['lr']):.6f}")
|
||||||
|
|
||||||
# 学习率调整
|
# 学习率调整
|
||||||
scheduler_ginka.step()
|
scheduler_ginka.step(epoch + 1)
|
||||||
|
|
||||||
if (epoch + 1) % 5 == 0:
|
if (epoch + 1) % 5 == 0:
|
||||||
loss_val = 0
|
loss_val = 0
|
||||||
ginka.eval()
|
ginka.eval()
|
||||||
idx = 0
|
idx = 0
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for batch in ginka_dataloader_val:
|
for batch in tqdm(ginka_dataloader_val, leave=False, desc="Validating Ginka Model"):
|
||||||
target, target_vision_feat, target_topo_feat, feat_vec = parse_ginka_batch(batch)
|
target, target_vision_feat, target_topo_feat, feat_vec = parse_ginka_batch(batch)
|
||||||
output, output_softmax = ginka(feat_vec)
|
output, output_softmax = ginka(feat_vec)
|
||||||
losses = criterion_ginka(output_softmax, target, target_vision_feat, target_topo_feat)
|
losses = criterion_ginka(output_softmax, target, target_vision_feat, target_topo_feat)
|
||||||
@ -202,9 +218,9 @@ def train():
|
|||||||
if epoch + 1 == EPOCHS_GINKA:
|
if epoch + 1 == EPOCHS_GINKA:
|
||||||
# 最后一次验证的时候顺带生成图片
|
# 最后一次验证的时候顺带生成图片
|
||||||
prob = output_softmax.cpu().numpy()
|
prob = output_softmax.cpu().numpy()
|
||||||
np.concatenate((prob_list, prob), axis=1)
|
prob_list = np.concatenate((prob_list, prob), axis=0)
|
||||||
map_matrix = torch.argmax(output, dim=1).cpu().numpy()
|
map_matrix = torch.argmax(output, dim=1).cpu().numpy()
|
||||||
gen_list = np.concatenate((gen_list, map_matrix), axis=1)
|
gen_list = np.concatenate((gen_list, map_matrix), axis=0)
|
||||||
for matrix in map_matrix:
|
for matrix in map_matrix:
|
||||||
image = matrix_to_image_cv(matrix, tile_dict)
|
image = matrix_to_image_cv(matrix, tile_dict)
|
||||||
cv2.imwrite(f"result/ginka_img/{idx}.png", image)
|
cv2.imwrite(f"result/ginka_img/{idx}.png", image)
|
||||||
@ -232,8 +248,8 @@ def train():
|
|||||||
buf.extend(struct.pack('>b', H)) # Map height
|
buf.extend(struct.pack('>b', H)) # Map height
|
||||||
buf.extend(struct.pack('>b', W)) # Map width
|
buf.extend(struct.pack('>b', W)) # Map width
|
||||||
buf.extend(gen_bytes) # Map tensor
|
buf.extend(gen_bytes) # Map tensor
|
||||||
server.sendall(buf)
|
conn.sendall(buf)
|
||||||
data = parse_minamo_data(server, prob_list)
|
data = parse_minamo_data(conn, prob_list)
|
||||||
minamo_dataset.set_data(data)
|
minamo_dataset.set_data(data)
|
||||||
|
|
||||||
# -------------------- 训练判别器
|
# -------------------- 训练判别器
|
||||||
@ -242,7 +258,7 @@ def train():
|
|||||||
minamo.train()
|
minamo.train()
|
||||||
total_loss = 0
|
total_loss = 0
|
||||||
|
|
||||||
for batch in minamo_dataloader:
|
for batch in tqdm(minamo_dataloader, leave=False, desc="Epoch Progress"):
|
||||||
map1, map2, vision_simi, topo_simi, graph1, graph2 = parse_minamo_batch(batch)
|
map1, map2, vision_simi, topo_simi, graph1, graph2 = parse_minamo_batch(batch)
|
||||||
|
|
||||||
if map1.shape[0] == 1:
|
if map1.shape[0] == 1:
|
||||||
@ -265,16 +281,16 @@ def train():
|
|||||||
total_loss += loss.item()
|
total_loss += loss.item()
|
||||||
|
|
||||||
ave_loss = total_loss / len(minamo_dataloader)
|
ave_loss = total_loss / len(minamo_dataloader)
|
||||||
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {ave_loss:.6f} | lr: {(optimizer_ginka.param_groups[0]['lr']):.6f}")
|
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {ave_loss:.6f} | lr: {(optimizer_minamo.param_groups[0]['lr']):.6f}")
|
||||||
|
|
||||||
scheduler_minamo.step()
|
scheduler_minamo.step(epoch + 1)
|
||||||
|
|
||||||
# 每十轮推理一次验证集
|
# 每十轮推理一次验证集
|
||||||
if (epoch + 1) % 5 == 0:
|
if (epoch + 1) % 5 == 0:
|
||||||
minamo.eval()
|
minamo.eval()
|
||||||
val_loss = 0
|
val_loss = 0
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for val_batch in tqdm(minamo_dataloader_val, leave=False):
|
for val_batch in tqdm(minamo_dataloader_val, leave=False, desc="Validating Minamo Model"):
|
||||||
map1_val, map2_val, vision_simi_val, topo_simi_val, graph1, graph2 = parse_minamo_batch(val_batch)
|
map1_val, map2_val, vision_simi_val, topo_simi_val, graph1, graph2 = parse_minamo_batch(val_batch)
|
||||||
|
|
||||||
vision_feat1, topo_feat1 = minamo(map1_val, graph1)
|
vision_feat1, topo_feat1 = minamo(map1_val, graph1)
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
class MinamoLoss(nn.Module):
|
class MinamoLoss(nn.Module):
|
||||||
def __init__(self, vision_weight=0.2, topo_weight=0.8):
|
def __init__(self, vision_weight=0.2, topo_weight=0.8):
|
||||||
@ -9,7 +10,7 @@ class MinamoLoss(nn.Module):
|
|||||||
|
|
||||||
def forward(self, vis_pred, topo_pred, vis_true, topo_true):
|
def forward(self, vis_pred, topo_pred, vis_true, topo_true):
|
||||||
# print(vis_pred.shape, topo_pred.shape, vis_true.shape, topo_true.shape)
|
# print(vis_pred.shape, topo_pred.shape, vis_true.shape, topo_true.shape)
|
||||||
# print(vis_pred[0].item(), topo_pred[0].item(), vis_true[0].item(), topo_true[0].item())
|
# tqdm.write(f"{vis_pred[0].item():.12f}, {vis_true[0].item():.12f}, {topo_pred[0].item():.12f}, {topo_true[0].item():.12f}")
|
||||||
vis_loss = self.loss(vis_pred, vis_true)
|
vis_loss = self.loss(vis_pred, vis_true)
|
||||||
topo_loss = self.loss(topo_pred, topo_true)
|
topo_loss = self.loss(topo_pred, topo_true)
|
||||||
# print(vis_loss.item(), topo_loss.item())
|
# print(vis_loss.item(), topo_loss.item())
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user