# Source code for dalib.modules.domain_discriminator
"""
@author: Junguang Jiang
@contact: [email protected]
"""
from typing import List, Dict
import torch.nn as nn
__all__ = ['DomainDiscriminator']


class DomainDiscriminator(nn.Sequential):
    r"""Domain discriminator model from
    `Domain-Adversarial Training of Neural Networks (ICML 2015) <https://arxiv.org/abs/1505.07818>`_

    Distinguish whether the input features come from the source domain or the target domain.
    The source domain label is 1 and the target domain label is 0.

    The network is two hidden blocks of ``Linear -> BatchNorm1d -> ReLU`` (or
    ``Linear -> ReLU -> Dropout`` when ``batch_norm`` is False), followed by
    ``Linear(hidden_size, 1) -> Sigmoid``. Because of the final sigmoid, every output
    lies in :math:`[0, 1]`.

    Args:
        in_feature (int): dimension of the input feature
        hidden_size (int): dimension of the hidden features
        batch_norm (bool): whether to use :class:`~torch.nn.BatchNorm1d` in the hidden blocks.
            If False, :class:`~torch.nn.Dropout` with ``p=0.5`` is used instead. Default: True.

    .. note::
        With ``batch_norm=True`` in training mode, :class:`~torch.nn.BatchNorm1d` needs
        more than one sample per batch.

    Shape:
        - Inputs: :math:`(minibatch, F)`, where :math:`F` is ``in_feature``
        - Outputs: :math:`(minibatch, 1)`
    """

    def __init__(self, in_feature: int, hidden_size: int, batch_norm: bool = True):
        layers = []
        # Two hidden blocks: the first maps in_feature -> hidden_size, the second
        # hidden_size -> hidden_size.
        for in_dim in (in_feature, hidden_size):
            layers.append(nn.Linear(in_dim, hidden_size))
            if batch_norm:
                layers += [nn.BatchNorm1d(hidden_size), nn.ReLU()]
            else:
                layers += [nn.ReLU(inplace=True), nn.Dropout(0.5)]
        layers += [nn.Linear(hidden_size, 1), nn.Sigmoid()]
        # The layer order is identical to the original explicit definition, so the
        # Sequential indices (and therefore state_dict keys) are unchanged and
        # existing checkpoints still load.
        super(DomainDiscriminator, self).__init__(*layers)

    def get_parameters(self, base_lr: float = 1.0) -> List[Dict]:
        """Return the parameter groups of this discriminator for a :class:`torch.optim.Optimizer`.

        Args:
            base_lr (float): value stored under ``"lr"`` for this group. Depending on how the
                optimizer and scheduler are configured, it acts either as the learning rate
                or as a multiplier of it. Default: 1.0.

        Returns:
            A single-element list ``[{"params": [...], "lr": base_lr}]``.
        """
        # ``self.parameters()`` is a one-shot generator. Materialize it so the group can be
        # inspected (e.g. to count parameters) without leaving an empty parameter list
        # for the optimizer.
        return [{"params": list(self.parameters()), "lr": base_lr}]