Change the loss function code. In these two lines:
para_loss = F.binary_cross_entropy_with_logits(para_outputs.view(-1), para_labels, reduction='mean') / args.batch_size
and
sts_loss = F.mse_loss(sts_outputs.squeeze(), sts_labels.view(-1).float(), reduction='mean') / args.batch_size
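delete the division by args.batch_size from each, so they become:
para_loss = F.binary_cross_entropy_with_logits(para_outputs.view(-1), para_labels, reduction='mean')
sts_loss = F.mse_loss(sts_outputs.squeeze(), sts_labels.view(-1).float(), reduction='mean')
With reduction='mean' the loss is already averaged over the elements in the batch, so dividing by args.batch_size again shrinks the loss (and its gradients) by a further factor of the batch size.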
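
To see why the extra division matters, here is a minimal arithmetic sketch in plain Python. It does not use PyTorch; mse_mean is a hypothetical stand-in for what F.mse_loss computes with reduction='mean', and the prediction and label values are made up for illustration.

```python
# Hypothetical toy values, not taken from any real training run.
preds = [2.0, 0.0, 3.0, 1.0]
labels = [1.0, 0.0, 1.0, 3.0]
batch_size = len(preds)


def mse_mean(preds, labels):
    """Plain-Python stand-in for MSE with reduction='mean'."""
    squared_errors = [(p - y) ** 2 for p, y in zip(preds, labels)]
    # Averaging here already normalizes by the number of examples.
    return sum(squared_errors) / len(squared_errors)


# Mean reduction alone: the per-example average loss.
correct_loss = mse_mean(preds, labels)

# Mean reduction followed by another division by the batch size:
# the loss is now too small by a factor of batch_size.
double_scaled_loss = correct_loss / batch_size

print(correct_loss)        # 2.25
print(double_scaled_loss)  # 0.5625
```

The second value is smaller by exactly the batch size, so with the extra division the effective step size changes whenever the batch size changes.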