"""
Another version of main.py used to fine-tune a pretrained model
Look at finetune combined
"""
from dataloader import GraphTextDataset, GraphDataset, TextDataset
from torch_geometric.data import DataLoader  # moved to torch_geometric.loader in newer PyG releases
from torch.utils.data import DataLoader as TorchDataLoader
from Model4 import Model
import numpy as np
from transformers import AutoTokenizer
import torch
from torch import optim
import time
import os
import pandas as pd
from info_nce import InfoNCE
from torch.optim.lr_scheduler import ReduceLROnPlateau
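
# Contrastive objective: within a batch, the i-th graph embedding and the i-th text
# embedding form the positive pair (the diagonal of the similarity matrix); every
# other pair acts as a negative. The loss mixes a symmetric cross-entropy over the
# logits with a symmetric InfoNCE term, weighted by beta.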
CE = torch.nn.CrossEntropyLoss()
INCE = InfoNCE()  # Contrastive Predictive Coding; van den Oord et al., 2018


def contrastive_loss(v1, v2, beta=0.1):
    logits = torch.matmul(v1, torch.transpose(v2, 0, 1))
    labels = torch.arange(logits.shape[0], device=v1.device)
    return beta * (CE(logits, labels) + CE(torch.transpose(logits, 0, 1), labels)) \
        + (1 - beta) * (INCE(v1, v2) + INCE(v2, v1))
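

# Data: DistilBERT tokenizer and a precomputed token-embedding dictionary (gt),
# used to build the train/validation graph-text datasets.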
model_name = 'distilbert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_name)
gt = np.load("../data/token_embedding_dict.npy", allow_pickle=True)[()]
val_dataset = GraphTextDataset(root='../data/', gt=gt, split='val', tokenizer=tokenizer)
train_dataset = GraphTextDataset(root='../data/', gt=gt, split='train', tokenizer=tokenizer)
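
# Fine-tuning hyperparameters: small learning rates, with an even smaller one
# (bert_lr) reserved for the pretrained text encoder.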
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
nb_epochs = 100
batch_size = 32
learning_rate = 2e-6
bert_lr = 2e-7
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
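
# Model: GAT-based graph encoder paired with a DistilBERT text encoder; both are
# projected into a shared 2048-dimensional embedding space (nout).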
model = Model(model_name=model_name, num_node_features=300, nout=2048, nhid=300, graph_hidden_channels=2048, conv_type="GAT", nheads=10) # nout is changed to 2048
model.to(device)
parameter_string = 'model_name: {}, num_node_features: {}, nout: {}, nhid: {}, graph_hidden_channels: {}, conv_type: {}, nheads: {}'.format(model_name, 300, 2048, 300, 2048, "GAT", 10)  # matches the Model(...) arguments above; nout adjusted to 2048
print(parameter_string)
print(batch_size)
print(learning_rate)
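
# AdamW with two parameter groups: the graph encoder uses learning_rate, while the
# pretrained text encoder uses the smaller bert_lr.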
optimizer = optim.AdamW([{'params': model.graph_encoder.parameters()},
                         {'params': model.text_encoder.parameters(), 'lr': bert_lr}],
                        lr=learning_rate,
                        betas=(0.9, 0.999),
                        weight_decay=0.01)
lr_scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.7, patience=3, verbose=True, min_lr=1e-7)
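
# Training bookkeeping (loss accumulators, iteration counter, timing, best validation loss).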
loss = 0
losses = []
count_iter = 0
time1 = time.time()
printEvery = 50
best_validation_loss = 1000000
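
# Resume from a previously saved checkpoint in ./models/, if one exists. Note that
# os.listdir() returns files in arbitrary order, so "last seen" is simply the last
# matching filename encountered, not necessarily the most recent epoch.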
last_seen_model = None
if os.path.exists("./models/"):
    for file in os.listdir("./models/"):
        if file.startswith('model'):
            last_seen_model = os.path.join("./models/", file)

if last_seen_model is not None:
    checkpoint = torch.load(last_seen_model)
    model.load_state_dict(checkpoint['model_state_dict'])
    print('loaded model from: {}'.format(last_seen_model))
    # Recover the epoch number from a filename of the form "model<epoch>.<ext>".
    epoch = int(last_seen_model.split('/')[-1].split('.')[0].split('model')[-1])
    best_validation_loss = checkpoint['validation_accuracy']
    lr_scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.7, patience=3, verbose=True, min_lr=1e-7)
    del checkpoint  # only defined when a checkpoint was actually loaded
else:
    epoch = 0
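
# Inference on the test split: encode all test graphs and all test texts with the
# two encoders separately.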
model.eval()
graph_model = model.get_graph_encoder()
text_model = model.get_text_encoder()
test_cids_dataset = GraphDataset(root='../data/', gt=gt, split='test_cids')
test_text_dataset = TextDataset(file_path='../data/test_text.txt', tokenizer=tokenizer)
idx_to_cid = test_cids_dataset.get_idx_to_cid()
test_loader = DataLoader(test_cids_dataset, batch_size=batch_size, shuffle=False)
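
# Embed every test graph.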
graph_embeddings = []
for batch in test_loader:
    for output in graph_model(batch.to(device)):
        graph_embeddings.append(output.tolist())
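
# Embed every test text description.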
test_text_loader = TorchDataLoader(test_text_dataset, batch_size=batch_size, shuffle=False)
text_embeddings = []
for batch in test_text_loader:
    for output in text_model(batch['input_ids'].to(device),
                             attention_mask=batch['attention_mask'].to(device)):
        text_embeddings.append(output.tolist())
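
# Score every (text, graph) pair by cosine similarity and write the submission file:
# one row per test text, one column per test graph, plus an ID column.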
from sklearn.metrics.pairwise import cosine_similarity
similarity = cosine_similarity(text_embeddings, graph_embeddings)
solution = pd.DataFrame(similarity)
solution['ID'] = solution.index
solution = solution[['ID'] + [col for col in solution.columns if col!='ID']]
solution.to_csv('submission.csv', index=False)