# Per-batch evaluation step. Indentation was lost in this view; presumably
# this runs inside the loop over testDataLoader -- TODO confirm.
loss = loss_func(outputs, labels)
# Predicted class = index of the largest score along dim 1. argmax returns
# indices (no autograd history), so the discouraged `.data` is unnecessary.
pred = outputs.argmax(dim=1)
acc = (pred == labels).sum()
# Confusion-matrix counts, treating class 1 as positive and class 0 as
# negative. Comparing against Python scalars avoids allocating ones/zeros
# tensors, and `&` is the logical AND for bool tensors.
is_pos, is_neg = labels == 1, labels == 0
pred_pos, pred_neg = pred == 1, pred == 0
tn = (is_neg & pred_neg).sum()
tp = (is_pos & pred_pos).sum()
fp = (is_neg & pred_pos).sum()
fn = (is_pos & pred_neg).sum()
# Accumulate as plain Python numbers across batches.
test_sum_fn += fn.item()
test_sum_fp += fp.item()
test_sum_tn += tn.item()
test_sum_tp += tp.item()
sum_loss += loss.item()
sum_correct += acc.item()
# Epoch-level test metrics from the counters accumulated over all batches.
# Precision/recall fall back to 0.0 when their denominator is zero (no
# predicted positives / no actual positives) instead of raising
# ZeroDivisionError -- the same value scikit-learn reports with zero_division=0.
predicted_pos = test_sum_tp + test_sum_fp
actual_pos = test_sum_tp + test_sum_fn
test_precision = test_sum_tp / predicted_pos if predicted_pos else 0.0
test_recall = test_sum_tp / actual_pos if actual_pos else 0.0
# Mean of the per-batch losses.
test_loss = sum_loss / len(testDataLoader)
# NOTE(review): assumes every batch holds exactly batch_size samples. If the
# final batch is smaller (DataLoader drop_last=False), this underestimates
# accuracy; dividing by the true count of evaluated samples would be exact.
test_correct = sum_correct / len(testDataLoader) / batch_size
# Log the epoch's mean test loss (writer is presumably a SummaryWriter);
# the step is the 1-based epoch number. ASCII quotes replace the typographic
# quotes that made this line a SyntaxError.
writer.add_scalar("testloss", test_loss, global_step=epoch + 1)
writer.add_scalar(“testcorrect“,test_correct,gl