# Patch summary (reviewer notes; everything above the "diff --git" line is
# free-form commentary and not part of any hunk).
#
# What the visible hunks change in train.py:
#   * `datasets` and `utils` are imported as top-level modules instead of
#     from the `util` package; `random` and `torch.backends.cudnn` are added,
#     and `numpy` is imported a second time.
#   * Adds set_random_seed(seed), which seeds `random`, NumPy and torch (CPU
#     and all CUDA devices) and sets cudnn.deterministic = True and
#     cudnn.benchmark = False. It is called at module level with seed = 41.
#   * Adds worker_init_fn(worker_id, seed), which seeds NumPy and `random`
#     with seed + worker_id; the DataLoader receives it wrapped in a lambda.
#   * Removes create_dataloader and the valid/test loaders; only loader_train
#     is built, inline, from a module-level dataset_train.
#   * criterion passes prediction_dict['pre_lable'] (instead of `labels`) to
#     model.gdn.dev_loss, adds beta * mean(model.suploss(...)) to the loss,
#     and returns a fourth value, loss_supc.
#   * Removes the log-directory / logger setup; per-epoch summaries are
#     written with print() instead of logger.info().
#
# NOTE(review): worker_init_fn is passed as a lambda. Lambdas cannot be
#   pickled, so this is expected to fail when DataLoader workers are started
#   with the "spawn" method and num_data_workers > 0 -- confirm the target
#   platform, or bind the seed with functools.partial instead.
# NOTE(review): both torch.Tensor(0) placeholders (loss_anomaly, loss_supc)
#   are reassigned a few lines later, before anything shown here reads them.
# NOTE(review): the key is spelled 'pre_lable'; presumably it matches the key
#   produced by the model -- verify against the producer.
# NOTE(review): in this copy the original line breaks and indentation were
#   collapsed. Line breaks below were restored from the +/- markers and
#   statement boundaries; indentation and blank context lines could not be
#   recovered, so hunk line counts may not match the headers. Two places,
#   "+ / loss.backward()" and "+ / loss_anomaly_list.append(...)", are shown
#   as a blank added line followed by a context line, which is the most
#   likely reading. Removed (-) lines are kept verbatim, including one
#   non-English comment meaning "create DataLoader", because they must match
#   the old file.
diff --git a/train.py b/train.py
index 81538d0..4aef411 100644
--- a/train.py
+++ b/train.py
@@ -1,36 +1,44 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import util.datasets as dataset
+import datasets as dataset
 import torch.utils.data
 import sklearn
 import numpy as np
 from option import args
 from model.tgat import TGAT
-from util.utils import EarlyStopMonitor, logger_config
+from utils import EarlyStopMonitor, logger_config
 from tqdm import tqdm
 import datetime, os
 import json
+import random
-crossentropyloss = nn.CrossEntropyLoss()
-def create_dataloader(dataset_type, config, collate_fn):
- datasets = dataset.DygDataset(config, dataset_type,
- [(1522728000000000000, 1522987200000000000),
- (1522987200000000000, 1523073600000000000),
- (1523073600000000000, 1523073600000000000)])
- return torch.utils.data.DataLoader(
- dataset=datasets,
- batch_size=config.batch_size,
- shuffle=False,
- num_workers=config.num_data_workers,
- pin_memory=True,
- collate_fn=collate_fn.dyg_collate_fn
- )
+import numpy as np
+import torch.backends.cudnn as cudnn
+
+seed = 41
+def set_random_seed(seed):
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ cudnn.deterministic = True
+ cudnn.benchmark = False
+
+ # If a DataLoader is being used with multi-process data loading
+ #torch.set_deterministic(True)
-# 创建 DataLoader
+# Set the random seed before training starts
+set_random_seed(seed)
+def worker_init_fn(worker_id, seed):
+ worker_seed = seed + worker_id
+ np.random.seed(worker_seed)
+ random.seed(worker_seed)
-def criterion(prediction_dict, labels, model, config):
+
+
+crossentropyloss = nn.CrossEntropyLoss()
+def criterion(prediction_dict, labels,model, config):
 for key, value in prediction_dict.items():
 if key != 'root_embedding' and key != 'dst_embedding' and key != 'group' and key != 'dev':
@@ -44,29 +52,25 @@ def criterion(prediction_dict, labels, model, config):
 loss = loss_classify.clone()
 loss_anomaly = torch.Tensor(0).to(torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))
- alpha = config.anomaly_alpha # 1e-1
-
- loss_anomaly = model.gdn.dev_loss(torch.squeeze(labels), torch.squeeze(prediction_dict['anom_score']), torch.squeeze(prediction_dict['time']))
+ loss_supc = torch.Tensor(0).to(torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))
+ alpha = config.anomaly_alpha
+ beta = config.supc_alpha
+ loss_anomaly = model.gdn.dev_loss(torch.squeeze(prediction_dict['pre_lable']), torch.squeeze(prediction_dict['anom_score']), torch.squeeze(prediction_dict['time']))
 loss_anomaly = torch.mean(loss_anomaly)
- loss += alpha * loss_anomaly
+ loss_supc = model.suploss(prediction_dict['root_embedding'],prediction_dict['dst_embedding'] ,prediction_dict['group'], prediction_dict['dev'])
+ loss_supc = loss_supc.mean()
+ loss += alpha * loss_anomaly + beta * loss_supc
- return loss, loss_classify, loss_anomaly
+ return loss, loss_classify, loss_anomaly, loss_supc
 config = args
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-# log file name set
-now_time = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
-log_base_path = f"{os.getcwd()}/train_log"
-file_list = os.listdir(log_base_path)
-max_num = [0] # [int(fl.split("_")[0]) for fl in file_list if len(fl.split("_"))>2] + [-1]
-log_base_path = f"{log_base_path}/{max(max_num)+1}_{now_time}"
-# log and path
-get_checkpoint_path = lambda epoch: f'{log_base_path}saved_checkpoints/{args.data_set}-{args.mode}-{args.module_type}-{args.mask_ratio}-{epoch}.pth'
-logger = logger_config(log_path=f'{log_base_path}/log.txt', logging_name='TFLAG')
-logger.info(config)
+
+
+dataset_train = dataset.DygDataset(config, 'train',[(1522728000000000000,1522987200000000000),(1522987200000000000,1523073600000000000),(1523073600000000000,1523073600000000000)])
 gpus = None if config.gpus == 0 else config.gpus
@@ -76,9 +80,19 @@ def criterion(prediction_dict, labels, model, config):
 model = backbone.to(device)
 optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
-loader_train = create_dataloader('train', config, collate_fn)
-loader_valid = create_dataloader('valid', config, collate_fn)
-loader_test = create_dataloader('test', config, collate_fn)
+print(len(dataset_train))
+#print(len(dataset_test))
+loader_train = torch.utils.data.DataLoader(
+ dataset=dataset_train,
+ batch_size=config.batch_size,
+ shuffle=False,
+ #shuffle=True,
+ num_workers=config.num_data_workers,
+ pin_memory=True,
+ #sampler=dataset.RandomDropSampler(dataset_train, 0.75), #for reddit
+ collate_fn=collate_fn.dyg_collate_fn,
+ worker_init_fn=lambda worker_id: worker_init_fn(worker_id, seed)
+)
 max_val_auc, max_test_auc = 0.0, 0.0
@@ -90,9 +104,7 @@ def criterion(prediction_dict, labels, model, config):
 m_loss, auc = [], []
 loss_anomaly_list = []
 loss_class_list = []
- dev_score_list = np.array([])
- dev_label_list = np.array([])
-
+ loss_supc_list = []
 with tqdm(total=len(loader_train)) as t:
 for batch_sample in loader_train:
 count_flag += 1
@@ -111,7 +123,8 @@ def criterion(prediction_dict, labels, model, config):
 )
 y = batch_sample['labels'].to(device)
- loss, loss_classify, loss_anomaly = criterion(x, y, model, config)
+ loss, loss_classify, loss_anomaly, loss_supc = criterion(x, y ,model, config)
+
 loss.backward()
 torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1, norm_type=2)
 optimizer.step()
@@ -120,17 +133,18 @@ def criterion(prediction_dict, labels, model, config):
 with torch.no_grad():
 model = model.eval()
 m_loss.append(loss.item())
- # pred_score = x['logits'].sigmoid()
 loss_class_list.append(loss_classify.detach().clone().cpu().numpy().flatten())
+
 loss_anomaly_list.append(loss_anomaly.detach().clone().cpu().numpy().flatten())
- t.set_postfix(loss=np.mean(loss_class_list), loss_anomaly=np.mean(loss_anomaly_list))
- t.update(1)
+ loss_supc_list.append(loss_supc.detach().clone().cpu().numpy().flatten())
+ t.set_postfix(loss=np.mean(loss_class_list), loss_anomaly=np.mean(loss_anomaly_list), loss_sup=np.mean(loss_supc_list))
+ t.update(1)
- logger.info('\n epoch: {}'.format(epoch))
- logger.info(f'train mean loss:{np.mean(m_loss)}, class loss: {np.mean(loss_class_list)}, anomaly loss: {np.mean(loss_anomaly_list)}')
- torch.save(model.state_dict(), "./checkpoint-{}".format(epoch))
\ No newline at end of file
+ print('\n epoch: {}'.format(epoch))
+ print(f'train mean loss:{np.mean(m_loss)}, class loss: {np.mean(loss_class_list)}, anomaly loss: {np.mean(loss_anomaly_list)}, sup loss: {np.mean(loss_supc_list)}')
+ torch.save(model.state_dict(), "./checkpoint-{}".format(epoch))