Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 60 additions & 46 deletions train.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,44 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import util.datasets as dataset
import datasets as dataset
import torch.utils.data
import sklearn
import numpy as np
from option import args
from model.tgat import TGAT
from util.utils import EarlyStopMonitor, logger_config
from utils import EarlyStopMonitor, logger_config
from tqdm import tqdm
import datetime, os
import json
import random

crossentropyloss = nn.CrossEntropyLoss()
def create_dataloader(dataset_type, config, collate_fn):
    """Build a non-shuffling ``DataLoader`` over one ``DygDataset`` split.

    Args:
        dataset_type: Split identifier forwarded unchanged to ``DygDataset``.
        config: Options object. ``batch_size`` and ``num_data_workers`` are
            read here, and the object itself is also passed to ``DygDataset``.
        collate_fn: Object exposing ``dyg_collate_fn``, which is used as the
            batch collate callable.

    Returns:
        A ``torch.utils.data.DataLoader`` with ``shuffle=False`` and
        ``pin_memory=True``.
    """
    # Three (start, end) integer pairs handed to DygDataset as-is.
    # NOTE(review): presumably nanosecond epoch bounds for the data splits;
    # the last pair has start == end -- confirm this is intended.
    time_windows = [
        (1522728000000000000, 1522987200000000000),
        (1522987200000000000, 1523073600000000000),
        (1523073600000000000, 1523073600000000000),
    ]
    split = dataset.DygDataset(config, dataset_type, time_windows)

    loader_options = dict(
        dataset=split,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=config.num_data_workers,
        pin_memory=True,
        collate_fn=collate_fn.dyg_collate_fn,
    )
    return torch.utils.data.DataLoader(**loader_options)
import numpy as np
import torch.backends.cudnn as cudnn

seed = 41
def set_random_seed(seed):
    """Seed every random number generator used by the script.

    Seeds Python's ``random``, NumPy's global generator, and PyTorch (the
    default generator and all CUDA devices) with ``seed``, then sets the
    cuDNN flags to ``deterministic=True`` and ``benchmark=False``.

    Args:
        seed: Integer seed applied to each generator.
    """
    # Same seeding order as before: random, numpy, torch CPU, torch CUDA.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Disable the cuDNN auto-tuner and request deterministic kernels.
    cudnn.deterministic = True
    cudnn.benchmark = False

# When a DataLoader is used with multi-process data loading
#torch.set_deterministic(True)

# Create the DataLoader
# Set the random seed before training starts
set_random_seed(seed)

def worker_init_fn(worker_id, seed):
    """Seed NumPy and Python's ``random`` with a per-worker seed.

    The seed used is ``seed + worker_id``, so each worker id gets a
    different but reproducible seed.

    Args:
        worker_id: Integer id of the worker being initialised.
        seed: Base seed shared by all workers.
    """
    derived_seed = seed + worker_id
    # NumPy first, then random -- the same order as the original statements.
    for apply_seed in (np.random.seed, random.seed):
        apply_seed(derived_seed)

def criterion(prediction_dict, labels, model, config):

crossentropyloss = nn.CrossEntropyLoss()
def criterion(prediction_dict, labels,model, config):

for key, value in prediction_dict.items():
if key != 'root_embedding' and key != 'dst_embedding' and key != 'group' and key != 'dev':
Expand All @@ -44,29 +52,25 @@ def criterion(prediction_dict, labels, model, config):

loss = loss_classify.clone()
loss_anomaly = torch.Tensor(0).to(torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))
alpha = config.anomaly_alpha # 1e-1

loss_anomaly = model.gdn.dev_loss(torch.squeeze(labels), torch.squeeze(prediction_dict['anom_score']), torch.squeeze(prediction_dict['time']))
loss_supc = torch.Tensor(0).to(torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))
alpha = config.anomaly_alpha
beta = config.supc_alpha
loss_anomaly = model.gdn.dev_loss(torch.squeeze(prediction_dict['pre_lable']), torch.squeeze(prediction_dict['anom_score']), torch.squeeze(prediction_dict['time']))
loss_anomaly = torch.mean(loss_anomaly)
loss += alpha * loss_anomaly
loss_supc = model.suploss(prediction_dict['root_embedding'],prediction_dict['dst_embedding'] ,prediction_dict['group'], prediction_dict['dev'])
loss_supc = loss_supc.mean()
loss += alpha * loss_anomaly + beta * loss_supc


return loss, loss_classify, loss_anomaly
return loss, loss_classify, loss_anomaly, loss_supc



# Run options: the `args` object imported from `option`, aliased as `config`.
config = args
# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# log file name set
# Timestamp from the naive local clock, formatted as YYYYmmdd_HHMMSS.
now_time = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
# Logs are rooted at "<current working directory>/train_log".
log_base_path = f"{os.getcwd()}/train_log"
# os.listdir raises if the directory does not exist. In the code visible here,
# `file_list` is referenced only by the commented-out expression on the next line.
file_list = os.listdir(log_base_path)
# The run index is pinned to [0], so the directory prefix built below is always "1_".
max_num = [0] # [int(fl.split("_")[0]) for fl in file_list if len(fl.split("_"))>2] + [-1]
# Per-run directory path: "<train_log>/<index>_<timestamp>".
log_base_path = f"{log_base_path}/{max(max_num)+1}_{now_time}"
# log and path
def get_checkpoint_path(epoch):
    """Return the checkpoint file path for ``epoch`` inside the run directory.

    The file is placed under ``<log_base_path>/saved_checkpoints/`` and named
    ``<data_set>-<mode>-<module_type>-<mask_ratio>-<epoch>.pth`` using the
    values found on ``args``.

    The previous f-string had no ``/`` between ``log_base_path`` and
    ``saved_checkpoints``, which glued the run directory name onto the folder
    name (``..._<timestamp>saved_checkpoints/``); the separator is restored
    here so checkpoints sit next to ``log.txt`` in the run directory.

    Args:
        epoch: Epoch identifier appended to the file name.

    Returns:
        The checkpoint path as a string. No directory is created here.
    """
    return (
        f'{log_base_path}/saved_checkpoints/'
        f'{args.data_set}-{args.mode}-{args.module_type}-{args.mask_ratio}-{epoch}.pth'
    )
# Logger named 'TFLAG' writing to "log.txt" inside the per-run directory;
# the full options object is logged once at start-up.
logger = logger_config(log_path=f'{log_base_path}/log.txt', logging_name='TFLAG')
logger.info(config)


# Training split. The third argument is a list of three (start, end) integer pairs.
# NOTE(review): presumably nanosecond epoch bounds for the train/valid/test windows;
# the last pair has start == end -- confirm this is intended.
dataset_train = dataset.DygDataset(config, 'train',[(1522728000000000000,1522987200000000000),(1522987200000000000,1523073600000000000),(1523073600000000000,1523073600000000000)])

# `gpus` is None when config.gpus is 0, otherwise it carries config.gpus unchanged.
gpus = None if config.gpus == 0 else config.gpus

Expand All @@ -76,9 +80,19 @@ def criterion(prediction_dict, labels, model, config):
# Move the network to the selected device and optimise all of its parameters with Adam.
# (`backbone` is created in a part of the file that is not visible here.)
model = backbone.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)

# Loaders built through create_dataloader for the 'train', 'valid' and 'test' splits.
# NOTE(review): loader_train is assigned again further down with an explicit
# DataLoader; this looks like removed and added diff lines shown together --
# confirm which definition is the live one.
loader_train = create_dataloader('train', config, collate_fn)
loader_valid = create_dataloader('valid', config, collate_fn)
loader_test = create_dataloader('test', config, collate_fn)
# Report the number of training samples.
print(len(dataset_train))
#print(len(dataset_test))
# Training loader over dataset_train: fixed order (shuffle=False), pinned host memory,
# batches assembled by collate_fn.dyg_collate_fn.
loader_train = torch.utils.data.DataLoader(
dataset=dataset_train,
batch_size=config.batch_size,
shuffle=False,
#shuffle=True,
num_workers=config.num_data_workers,
pin_memory=True,
#sampler=dataset.RandomDropSampler(dataset_train, 0.75), #for reddit
collate_fn=collate_fn.dyg_collate_fn,
# Each worker is seeded through worker_init_fn with the module-level `seed` plus its worker id.
# NOTE(review): a lambda cannot be pickled, so this may fail where worker processes
# are started with the "spawn" method -- confirm the target platform.
worker_init_fn=lambda worker_id: worker_init_fn(worker_id, seed)
)


max_val_auc, max_test_auc = 0.0, 0.0
Expand All @@ -90,9 +104,7 @@ def criterion(prediction_dict, labels, model, config):
m_loss, auc = [], []
loss_anomaly_list = []
loss_class_list = []
dev_score_list = np.array([])
dev_label_list = np.array([])

loss_supc_list = []
with tqdm(total=len(loader_train)) as t:
for batch_sample in loader_train:
count_flag += 1
Expand All @@ -111,7 +123,8 @@ def criterion(prediction_dict, labels, model, config):
)

y = batch_sample['labels'].to(device)
loss, loss_classify, loss_anomaly = criterion(x, y, model, config)
loss, loss_classify, loss_anomaly, loss_supc = criterion(x, y ,model, config)

loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1, norm_type=2)
optimizer.step()
Expand All @@ -120,17 +133,18 @@ def criterion(prediction_dict, labels, model, config):
with torch.no_grad():
model = model.eval()
m_loss.append(loss.item())
# pred_score = x['logits'].sigmoid()


loss_class_list.append(loss_classify.detach().clone().cpu().numpy().flatten())

loss_anomaly_list.append(loss_anomaly.detach().clone().cpu().numpy().flatten())
t.set_postfix(loss=np.mean(loss_class_list), loss_anomaly=np.mean(loss_anomaly_list))
t.update(1)
loss_supc_list.append(loss_supc.detach().clone().cpu().numpy().flatten())
t.set_postfix(loss=np.mean(loss_class_list), loss_anomaly=np.mean(loss_anomaly_list), loss_sup=np.mean(loss_supc_list))

t.update(1)


logger.info('\n epoch: {}'.format(epoch))
logger.info(f'train mean loss:{np.mean(m_loss)}, class loss: {np.mean(loss_class_list)}, anomaly loss: {np.mean(loss_anomaly_list)}')

torch.save(model.state_dict(), "./checkpoint-{}".format(epoch))
print('\n epoch: {}'.format(epoch))
print(f'train mean loss:{np.mean(m_loss)}, class loss: {np.mean(loss_class_list)}, anomaly loss: {np.mean(loss_anomaly_list)}, sup loss: {np.mean(loss_supc_list)}')
torch.save(model.state_dict(), "./checkpoint-{}".format(epoch))