diff --git a/ginka/generator/loss.py b/ginka/generator/loss.py index 83561f4..f4397b2 100644 --- a/ginka/generator/loss.py +++ b/ginka/generator/loss.py @@ -408,6 +408,10 @@ class RNNGinkaLoss: pass def rnn_loss(self, fake, target): + """ + fake: [B, C, H, W] + target: [B, H, W] + """ weight = torch.ones(self.num_classes) weight[0] = 0.3 weight[1] = 0.5