mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
feat: 支持流式传输,改进数据集
This commit is contained in:
parent
8dea79a9f0
commit
325eb599c3
104
data/src/gan.ts
104
data/src/gan.ts
@ -8,7 +8,7 @@ const [refer] = process.argv.slice(2);
|
||||
|
||||
let id = 0;
|
||||
|
||||
function readMap(count: number, buffer: Buffer, h: number, w: number) {
|
||||
function readMap(count: number, arr: number[], h: number, w: number) {
|
||||
const area = w * h;
|
||||
|
||||
const maps: number[][][] = Array.from<number[][]>({
|
||||
@ -19,7 +19,7 @@ function readMap(count: number, buffer: Buffer, h: number, w: number) {
|
||||
});
|
||||
});
|
||||
|
||||
buffer.subarray(4).forEach((v, i) => {
|
||||
arr.forEach((v, i) => {
|
||||
const n = Math.floor(i / area);
|
||||
const y = Math.floor((i % area) / w);
|
||||
const x = i % w;
|
||||
@ -35,7 +35,7 @@ function generateGANData(
|
||||
map: number[][]
|
||||
) {
|
||||
const id2 = `$${id++}`;
|
||||
const toTrain = chooseFrom(keys, 4);
|
||||
const toTrain = chooseFrom(keys, 30);
|
||||
const data = toTrain.map<MinamoTrainData[]>(v => {
|
||||
const floor = refer.get(v);
|
||||
if (!floor) return [];
|
||||
@ -48,28 +48,67 @@ function generateGANData(
|
||||
return data.flat();
|
||||
}
|
||||
|
||||
/** Parsing phase of a {@link DataReceiver}. */
const enum ReceiverStatus {
    /** Still waiting for the complete 4-byte header. */
    Header,
    /** Header parsed; accumulating the map tensor payload. */
    Content
}

/**
 * Reassembles one framed message from a byte stream that may deliver it in
 * arbitrary pieces.
 *
 * Input protocol (sizes in bytes):
 * 2 - tensor count (int16, big-endian); 1 - map height (int8);
 * 1 - map width (int8); count * height * width - map tensor, one byte per cell.
 */
class DataReceiver {
    /** Receiver of the message currently in flight, if any. */
    static active?: DataReceiver;

    /** Size of the fixed header that precedes the map tensor. */
    private static readonly HEADER_SIZE = 4;

    /** Current receive state. */
    private status: ReceiverStatus = ReceiverStatus.Header;

    /** Chunks received so far for this message (header included). */
    private chunks: Buffer[] = [];
    /** Total number of bytes held in `chunks`. */
    private size: number = 0;

    private count: number = 0;
    private h: number = 0;
    private w: number = 0;

    /**
     * Feeds one chunk of the stream into the receiver.
     *
     * @param buf Next chunk read from the socket; may be empty, may end in the
     * middle of the header, and may contain bytes of the following message.
     * @returns `[maps, count, height, width]` once the whole message has
     * arrived, otherwise `null`.
     * @throws RangeError if the header announces a negative count or size.
     */
    receive(buf: Buffer): [number[][][], number, number, number] | null {
        if (buf.length > 0) {
            this.chunks.push(buf);
            this.size += buf.length;
        }

        if (this.status === ReceiverStatus.Header) {
            // The header itself can be split across chunks; wait for all of it.
            if (this.size < DataReceiver.HEADER_SIZE) return null;
            const head = this.merge();
            this.count = head.readInt16BE(0);
            this.h = head.readInt8(2);
            this.w = head.readInt8(3);
            if (this.count < 0 || this.h < 0 || this.w < 0) {
                // A negative length can never be satisfied; fail loudly
                // instead of buffering forever.
                delete DataReceiver.active;
                throw new RangeError(
                    `Invalid message header: count=${this.count}, h=${this.h}, w=${this.w}`
                );
            }
            this.status = ReceiverStatus.Content;
        }

        const end = DataReceiver.HEADER_SIZE + this.count * this.h * this.w;
        if (this.size < end) return null;

        const message = this.merge();
        delete DataReceiver.active;
        if (message.length > end) {
            // Bytes past the end belong to the next message: hand them to a
            // fresh receiver so they are parsed on the next call to `check`.
            const next = new DataReceiver();
            next.chunks.push(message.subarray(end));
            next.size = message.length - end;
            DataReceiver.active = next;
        }

        // Leave this instance in a clean state in case it is fed again.
        this.chunks = [];
        this.size = 0;
        this.status = ReceiverStatus.Header;

        const cells = Array.from(message.subarray(DataReceiver.HEADER_SIZE, end));
        return [readMap(this.count, cells, this.h, this.w), this.count, this.h, this.w];
    }

    /**
     * Routes a chunk to the active receiver, creating one when no message is
     * in flight.
     *
     * @param buf Next chunk read from the socket.
     * @returns The parsed message when `buf` completes it, otherwise `null`.
     */
    static check(buf: Buffer): [number[][][], number, number, number] | null {
        if (!this.active) {
            this.active = new DataReceiver();
        }
        return this.active.receive(buf);
    }

    /** Collapses the buffered chunks into a single buffer and returns it. */
    private merge(): Buffer {
        const merged =
            this.chunks.length === 1
                ? this.chunks[0]
                : Buffer.concat(this.chunks, this.size);
        this.chunks = [merged];
        return merged;
    }
}
|
||||
|
||||
(async () => {
|
||||
const referTower = await readOne(refer);
|
||||
const keys = [...referTower.keys()];
|
||||
|
||||
const client = createConnection(SOCKET_FILE, () => {
|
||||
console.log(`UDS IPC connected successfully.`);
|
||||
// 发送四字节数据表示连接成功
|
||||
client.write(new Uint8Array([0x00, 0x00, 0x00, 0x00]));
|
||||
});
|
||||
|
||||
client.on('data', buffer => {
|
||||
// 暂时不考虑流式传输,如果后续数据量非常大,再考虑优化
|
||||
// 数据通讯 node 输入协议,单位字节:
|
||||
// 2 - Tensor count; 1 - Map height; 1 - Map Width; N*1*H*W - Map tensor, int8 type.
|
||||
const count = buffer.readInt16BE();
|
||||
if (buffer.length - 4 !== count * 32 * 32) {
|
||||
client.write(`ERROR: byte length not match.`);
|
||||
return [];
|
||||
}
|
||||
const h = buffer.readInt8(2);
|
||||
const w = buffer.readInt8(3);
|
||||
const map = readMap(count, buffer, h, w);
|
||||
const data = DataReceiver.check(buffer);
|
||||
if (!data) return;
|
||||
|
||||
const [map, count, h, w] = data;
|
||||
const simData = map.map(v => generateGANData(keys, referTower, v));
|
||||
const rc = 0;
|
||||
const compareData = simData.flat();
|
||||
@ -83,12 +122,23 @@ function generateGANData(
|
||||
const toSend = Buffer.alloc(
|
||||
2 + // Tensor count
|
||||
2 + // Review count
|
||||
count + // Compare count
|
||||
2 * (count + rc) + // Similarity data
|
||||
1 * count + // Compare count
|
||||
2 * 4 * (compareData.length + rc) + // Similarity data
|
||||
compareData.length * 1 * h * w + // Compare map
|
||||
rc * 2 * h * w, // Review map
|
||||
0
|
||||
);
|
||||
console.log(
|
||||
2,
|
||||
2,
|
||||
count,
|
||||
2 * 4 * (compareData.length + rc),
|
||||
compareData.length * 1 * h * w,
|
||||
rc * 2 * h * w,
|
||||
compareData.length,
|
||||
rc
|
||||
);
|
||||
|
||||
let offset = 0;
|
||||
toSend.writeInt16BE(count); // Tensor count
|
||||
toSend.writeInt16BE(0, 2); // Review count
|
||||
@ -98,9 +148,11 @@ function generateGANData(
|
||||
simData.map(v => v.length),
|
||||
offset
|
||||
);
|
||||
offset += count;
|
||||
offset += 1 * count;
|
||||
// Similarity data
|
||||
compareData.forEach(v => {
|
||||
// console.log(v.visionSimilarity, v.topoSimilarity);
|
||||
|
||||
toSend.writeFloatBE(v.visionSimilarity, offset);
|
||||
offset += 4;
|
||||
toSend.writeFloatBE(v.topoSimilarity, offset);
|
||||
@ -108,15 +160,17 @@ function generateGANData(
|
||||
});
|
||||
// Compare map
|
||||
toSend.set(
|
||||
compareData.map(v => v.map1).flat(2),
|
||||
new Uint8Array(compareData.map(v => v.map1).flat(3)),
|
||||
offset // Set from Compare map
|
||||
);
|
||||
offset += compareData.length * 1 * h * w;
|
||||
// Review map
|
||||
toSend.set(
|
||||
reviewData.map(v => [v.map1, v.map2]).flat(3),
|
||||
offset // Set from last chunk
|
||||
);
|
||||
if (reviewData.length > 0) {
|
||||
// Review map
|
||||
toSend.set(
|
||||
new Uint8Array(reviewData.map(v => [v.map1, v.map2]).flat(4)),
|
||||
offset // Set from last chunk
|
||||
);
|
||||
}
|
||||
|
||||
client.write(toSend);
|
||||
});
|
||||
|
||||
@ -136,7 +136,7 @@ function generateSimilarData(id: string, map: number[][]) {
|
||||
// 生成最多两个微调地图
|
||||
const width = map[0].length;
|
||||
const height = map.length;
|
||||
const num = Math.floor(Math.random() * 1);
|
||||
const num = Math.floor(Math.random() * 2);
|
||||
const res: [id: string, data: MinamoTrainData][] = [];
|
||||
|
||||
for (let i = 0; i < num; i++) {
|
||||
|
||||
@ -51,10 +51,14 @@ class GinkaDataset(Dataset):
|
||||
class MinamoGANDataset(Dataset):
|
||||
def __init__(self, refer_data_path):
|
||||
self.refer = load_minamo_gan_data(load_data(refer_data_path))
|
||||
self.data = list().extend(self.refer)
|
||||
self.data = list()
|
||||
self.data.extend(random.sample(self.refer, 1000))
|
||||
|
||||
def set_data(self, data: list):
|
||||
self.data = data.extend(self.refer)
|
||||
self.data.clear()
|
||||
self.data.extend(data)
|
||||
k = min(len(data) / 4, len(self.refer))
|
||||
self.data.extend(random.sample(self.refer, int(k)))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
@ -63,8 +67,8 @@ class MinamoGANDataset(Dataset):
|
||||
item = self.data[idx]
|
||||
|
||||
map1, map2, vis_sim, topo_sim, review = item
|
||||
map1 = torch.ShortTensor(map1)
|
||||
map2 = torch.ShortTensor(map2)
|
||||
map1 = torch.LongTensor(map1)
|
||||
map2 = torch.LongTensor(map2)
|
||||
# 检查是否有 review 标签,没有的话说明是概率分布,不需要任何转换
|
||||
if review:
|
||||
map1 = F.one_hot(map1, num_classes=32).permute(2, 0, 1).float() # [32, H, W]
|
||||
|
||||
@ -255,7 +255,7 @@ class GinkaLoss(nn.Module):
|
||||
minamo_loss = (1.0 - minamo_sim).mean()
|
||||
|
||||
tqdm.write(
|
||||
f"{minamo_loss.item():.8f}, {class_loss.item():.8f}, {entrance_loss.item():.8f}, {count_loss.item():.8f}"
|
||||
f"{minamo_loss.item():.12f}, {class_loss.item():.12f}, {entrance_loss.item():.12f}, {count_loss.item():.12f}"
|
||||
)
|
||||
|
||||
losses = [
|
||||
|
||||
@ -19,7 +19,7 @@ from shared.image import matrix_to_image_cv
|
||||
|
||||
BATCH_SIZE = 32
|
||||
EPOCHS_GINKA = 30
|
||||
EPOCHS_MINAMO = 10
|
||||
EPOCHS_MINAMO = 15
|
||||
SOCKET_PATH = "./tmp/ginka_uds"
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
@ -64,6 +64,16 @@ def send_all(sock, data):
|
||||
raise RuntimeError("Socket connection broken")
|
||||
total_sent += sent
|
||||
|
||||
def recv_all(sock: socket.socket, length: int) -> bytes:
    """Receive exactly ``length`` bytes from ``sock``.

    ``recv`` may return fewer bytes than requested, so this keeps reading
    until the requested amount has arrived.

    Args:
        sock: Socket to read from.
        length: Number of bytes to read; a value <= 0 yields ``b""``.

    Returns:
        The received bytes, exactly ``length`` long.

    Raises:
        ConnectionError: If ``recv`` returns no data before ``length`` bytes
            have been read.
    """
    chunks = []
    remaining = length
    while remaining > 0:
        packet = sock.recv(remaining)  # request only what is still missing
        if not packet:
            raise ConnectionError("连接中断")
        chunks.append(packet)
        remaining -= len(packet)
    # Join once at the end instead of re-copying the data on every iteration.
    return b"".join(chunks)
|
||||
|
||||
def parse_minamo_data(sock: socket.socket, maps: np.ndarray):
|
||||
# 数据通讯 node 输出协议,单位字节:
|
||||
# 2 - Tensor count; 2 - Review count. Review is right behind train data;
|
||||
@ -75,21 +85,21 @@ def parse_minamo_data(sock: socket.socket, maps: np.ndarray):
|
||||
rc_buf = sock.recv(2)
|
||||
tc = struct.unpack('>h', tc_buf)[0]
|
||||
rc = struct.unpack('>h', rc_buf)[0]
|
||||
count_buf = sock.recv(1 * tc)
|
||||
count: list = struct.unpack(f">{tc}b", count_buf)[0]
|
||||
count_buf = recv_all(sock, 1 * tc)
|
||||
count: list = struct.unpack(f">{tc}b", count_buf)
|
||||
N = sum(count)
|
||||
sim_buf = sock.recv(2 * 4 * (N + rc))
|
||||
com_buf = sock.recv(N * 1 * H * W)
|
||||
review_buf = sock.recv(rc * 2 * H * W) if rc > 0 else bytes()
|
||||
sim_buf = recv_all(sock, 2 * 4 * (N + rc))
|
||||
com_buf = recv_all(sock, N * 1 * H * W)
|
||||
review_buf = recv_all(sock, rc * 2 * H * W) if rc > 0 else bytes()
|
||||
|
||||
sim = struct.unpack(f">{(N + rc) * 2}f", sim_buf)[0]
|
||||
com = struct.unpack(f">{N * 1 * H * W}b", com_buf)[0]
|
||||
review = struct.unpack(f">{rc * 2 * H * W}", review_buf)[0] if rc > 0 else list()
|
||||
sim = struct.unpack(f">{(N + rc) * 2}f", sim_buf)
|
||||
com = struct.unpack(f">{N * 1 * H * W}b", com_buf)
|
||||
review = struct.unpack(f">{rc * 2 * H * W}", review_buf) if rc > 0 else list()
|
||||
|
||||
res = list()
|
||||
flatten_idx = 0
|
||||
# 读取当前这一轮生成器的数据
|
||||
for idx in range(N):
|
||||
for idx in range(tc):
|
||||
com_count = count[idx]
|
||||
for i in range(com_count):
|
||||
com_start = flatten_idx * H * W
|
||||
@ -98,7 +108,7 @@ def parse_minamo_data(sock: socket.socket, maps: np.ndarray):
|
||||
topo_sim = sim[flatten_idx * 2 + 1]
|
||||
com_data = com[com_start:com_end]
|
||||
flatten_idx += 1
|
||||
com_map = np.fromiter(com_data, np.int8).view(H, W)
|
||||
com_map = np.array(com_data, dtype=np.int8).reshape(H, W)
|
||||
# map1, map2, vision_similarity, topo_similarity, is_review
|
||||
res.append((maps[idx], com_map, vis_sim, topo_sim, False))
|
||||
|
||||
@ -107,7 +117,7 @@ def parse_minamo_data(sock: socket.socket, maps: np.ndarray):
|
||||
def train():
|
||||
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
|
||||
|
||||
args = parse_arguments("result/ginka.pth", "ginka-dataset.json", 'ginka-eval.json')
|
||||
args = parse_arguments()
|
||||
|
||||
ginka = GinkaModel()
|
||||
ginka.to(device)
|
||||
@ -119,8 +129,8 @@ def train():
|
||||
# 准备数据集
|
||||
ginka_dataset = GinkaDataset(args.train, device, minamo)
|
||||
ginka_dataset_val = GinkaDataset(args.validate, device, minamo)
|
||||
minamo_dataset = MinamoGANDataset()
|
||||
minamo_dataset_val = MinamoGANDataset()
|
||||
minamo_dataset = MinamoGANDataset("datasets/minamo-dataset-1.json")
|
||||
minamo_dataset_val = MinamoGANDataset("datasets/minamo-eval-1.json")
|
||||
ginka_dataloader = DataLoader(ginka_dataset, batch_size=BATCH_SIZE, shuffle=True)
|
||||
ginka_dataloader_val = DataLoader(ginka_dataset_val, batch_size=BATCH_SIZE, shuffle=True)
|
||||
minamo_dataloader = DataLoader(minamo_dataset, batch_size=BATCH_SIZE, shuffle=True)
|
||||
@ -132,7 +142,7 @@ def train():
|
||||
criterion_ginka = GinkaLoss(minamo)
|
||||
|
||||
optimizer_minamo = optim.AdamW(minamo.parameters(), lr=1e-3)
|
||||
scheduler_minamo = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_minamo, T_0=10, T_mult=2, eta_min=1e-6)
|
||||
scheduler_minamo = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_minamo, T_0=5, T_mult=2, eta_min=1e-6)
|
||||
criterion_minamo = MinamoLoss()
|
||||
|
||||
# 用于生成图片
|
||||
@ -142,10 +152,16 @@ def train():
|
||||
tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED)
|
||||
|
||||
# 与 node 端通讯
|
||||
if os.path.exists(SOCKET_PATH):
|
||||
os.remove(SOCKET_PATH)
|
||||
server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
server.bind(SOCKET_PATH)
|
||||
server.listen(1)
|
||||
|
||||
print("Waiting for client connection...")
|
||||
conn, _ = server.accept()
|
||||
print("Client connected.")
|
||||
|
||||
if args.resume:
|
||||
data = torch.load(args.from_state, map_location=device)
|
||||
ginka.load_state_dict(data["model_state"], strict=False)
|
||||
@ -157,11 +173,11 @@ def train():
|
||||
# 从头开始训练的话,初始时先把 minamo 损失值权重改为 0
|
||||
criterion_ginka.weight[0] = 0.0
|
||||
|
||||
for cycle in tqdm(range(args.from_cycle, args.to_cycle)):
|
||||
for cycle in tqdm(range(args.from_cycle, args.to_cycle), desc="Total Progress"):
|
||||
# -------------------- 训练生成器
|
||||
gen_list: np.ndarray = np.empty(np.int8)
|
||||
prob_list: np.ndarray = np.empty(np.float32)
|
||||
for epoch in tqdm(range(args.epochs), desc="Training Ginka Model"):
|
||||
gen_list: np.ndarray = np.empty((0, 13, 13), np.int8)
|
||||
prob_list: np.ndarray = np.empty((0, 32, 13, 13), np.float32)
|
||||
for epoch in tqdm(range(EPOCHS_GINKA), desc="Training Ginka Model"):
|
||||
ginka.train()
|
||||
minamo.eval()
|
||||
total_loss = 0
|
||||
@ -170,7 +186,7 @@ def train():
|
||||
if not args.resume and epoch == 10:
|
||||
criterion_ginka.weight[0] = 0.5
|
||||
|
||||
for batch in ginka_dataloader:
|
||||
for batch in tqdm(ginka_dataloader, leave=False, desc="Epoch Progress"):
|
||||
# 数据迁移到设备
|
||||
target, target_vision_feat, target_topo_feat, feat_vec = parse_ginka_batch(batch)
|
||||
# 前向传播
|
||||
@ -187,14 +203,14 @@ def train():
|
||||
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {avg_loss:.6f} | lr: {(optimizer_ginka.param_groups[0]['lr']):.6f}")
|
||||
|
||||
# 学习率调整
|
||||
scheduler_ginka.step()
|
||||
scheduler_ginka.step(epoch + 1)
|
||||
|
||||
if (epoch + 1) % 5 == 0:
|
||||
loss_val = 0
|
||||
ginka.eval()
|
||||
idx = 0
|
||||
with torch.no_grad():
|
||||
for batch in ginka_dataloader_val:
|
||||
for batch in tqdm(ginka_dataloader_val, leave=False, desc="Validating Ginka Model"):
|
||||
target, target_vision_feat, target_topo_feat, feat_vec = parse_ginka_batch(batch)
|
||||
output, output_softmax = ginka(feat_vec)
|
||||
losses = criterion_ginka(output_softmax, target, target_vision_feat, target_topo_feat)
|
||||
@ -202,9 +218,9 @@ def train():
|
||||
if epoch + 1 == EPOCHS_GINKA:
|
||||
# 最后一次验证的时候顺带生成图片
|
||||
prob = output_softmax.cpu().numpy()
|
||||
np.concatenate((prob_list, prob), axis=1)
|
||||
prob_list = np.concatenate((prob_list, prob), axis=0)
|
||||
map_matrix = torch.argmax(output, dim=1).cpu().numpy()
|
||||
gen_list = np.concatenate((gen_list, map_matrix), axis=1)
|
||||
gen_list = np.concatenate((gen_list, map_matrix), axis=0)
|
||||
for matrix in map_matrix:
|
||||
image = matrix_to_image_cv(matrix, tile_dict)
|
||||
cv2.imwrite(f"result/ginka_img/{idx}.png", image)
|
||||
@ -232,8 +248,8 @@ def train():
|
||||
buf.extend(struct.pack('>b', H)) # Map height
|
||||
buf.extend(struct.pack('>b', W)) # Map width
|
||||
buf.extend(gen_bytes) # Map tensor
|
||||
server.sendall(buf)
|
||||
data = parse_minamo_data(server, prob_list)
|
||||
conn.sendall(buf)
|
||||
data = parse_minamo_data(conn, prob_list)
|
||||
minamo_dataset.set_data(data)
|
||||
|
||||
# -------------------- 训练判别器
|
||||
@ -242,7 +258,7 @@ def train():
|
||||
minamo.train()
|
||||
total_loss = 0
|
||||
|
||||
for batch in minamo_dataloader:
|
||||
for batch in tqdm(minamo_dataloader, leave=False, desc="Epoch Progress"):
|
||||
map1, map2, vision_simi, topo_simi, graph1, graph2 = parse_minamo_batch(batch)
|
||||
|
||||
if map1.shape[0] == 1:
|
||||
@ -265,16 +281,16 @@ def train():
|
||||
total_loss += loss.item()
|
||||
|
||||
ave_loss = total_loss / len(minamo_dataloader)
|
||||
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {ave_loss:.6f} | lr: {(optimizer_ginka.param_groups[0]['lr']):.6f}")
|
||||
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {ave_loss:.6f} | lr: {(optimizer_minamo.param_groups[0]['lr']):.6f}")
|
||||
|
||||
scheduler_minamo.step()
|
||||
scheduler_minamo.step(epoch + 1)
|
||||
|
||||
# 每十轮推理一次验证集
|
||||
if (epoch + 1) % 5 == 0:
|
||||
minamo.eval()
|
||||
val_loss = 0
|
||||
with torch.no_grad():
|
||||
for val_batch in tqdm(minamo_dataloader_val, leave=False):
|
||||
for val_batch in tqdm(minamo_dataloader_val, leave=False, desc="Validating Minamo Model"):
|
||||
map1_val, map2_val, vision_simi_val, topo_simi_val, graph1, graph2 = parse_minamo_batch(val_batch)
|
||||
|
||||
vision_feat1, topo_feat1 = minamo(map1_val, graph1)
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import torch.nn as nn
|
||||
from tqdm import tqdm
|
||||
|
||||
class MinamoLoss(nn.Module):
|
||||
def __init__(self, vision_weight=0.2, topo_weight=0.8):
|
||||
@ -9,7 +10,7 @@ class MinamoLoss(nn.Module):
|
||||
|
||||
def forward(self, vis_pred, topo_pred, vis_true, topo_true):
|
||||
# print(vis_pred.shape, topo_pred.shape, vis_true.shape, topo_true.shape)
|
||||
# print(vis_pred[0].item(), topo_pred[0].item(), vis_true[0].item(), topo_true[0].item())
|
||||
# tqdm.write(f"{vis_pred[0].item():.12f}, {vis_true[0].item():.12f}, {topo_pred[0].item():.12f}, {topo_true[0].item():.12f}")
|
||||
vis_loss = self.loss(vis_pred, vis_true)
|
||||
topo_loss = self.loss(topo_pred, topo_true)
|
||||
# print(vis_loss.item(), topo_loss.item())
|
||||
|
||||
Loading…
Reference in New Issue
Block a user