|
|
|
@ -43,9 +43,9 @@ class ErnieLinear(nn.Layer):
|
|
|
|
|
num_classes, int
|
|
|
|
|
) and num_classes > 0, 'Argument `num_classes` must be an integer.'
|
|
|
|
|
self.ernie = ErnieForTokenClassification.from_pretrained(
|
|
|
|
|
pretrained_token, num_classes=num_classes, **kwargs)
|
|
|
|
|
pretrained_token, num_labels=num_classes, **kwargs)
|
|
|
|
|
|
|
|
|
|
self.num_classes = self.ernie.num_classes
|
|
|
|
|
self.num_classes = self.ernie.num_labels
|
|
|
|
|
self.softmax = nn.Softmax()
|
|
|
|
|
|
|
|
|
|
def forward(self,
|
|
|
|
|