main_1.py
"""
This script trains a model for graph-text classification and generates predictions.
Pipeline Overview:
1. Initialize necessary modules and variables.
2. Data Loading: Load graph and text datasets for training, validation, and testing.
3. Model Definition: Define the Graph-Text Fusion model architecture.
4. Define contrastive loss function (combining InfoNCE and CE).
5. Optimizer and Scheduler Setup: Setup AdamW optimizer with separate learning rates and linear scheduler with warm-up steps
6. Training Loop: Iterate over epochs and batches to train the model.
7. Validation: Evaluate the model's performance on a validation set / Save the model if validation loss improves.
8. Testing and Submission: Generate embeddings for the test dataset using the best model and create a CSV file with similarity scores.
Disclaimer : Some parts of this code were provided by the challenge organizers
"""
from sklearn.metrics import label_ranking_average_precision_score
from dataloader import GraphTextDataset, GraphDataset, TextDataset
from torch_geometric.data import DataLoader
from torch.utils.data import DataLoader as TorchDataLoader
from Model1 import Model
import numpy as np
from transformers import AutoTokenizer
import torch
from torch import optim
import time
import os
import pandas as pd
from info_nce import InfoNCE
from torch.optim.lr_scheduler import ReduceLROnPlateau
from loss import contrastive_loss
from sklearn.metrics.pairwise import cosine_similarity
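
# NOTE: `contrastive_loss` is imported from loss.py, which is not shown in this file. The function
# below is a hedged sketch of what such a graph-text contrastive objective typically looks like
# (symmetric cross-entropy over the in-batch similarity matrix plus an InfoNCE term, as the module
# docstring suggests). It is illustrative only, is never called in this script, and the name and
# the `temperature` hyper-parameter are assumptions, not the organizers' implementation.
def contrastive_loss_sketch(x_graph, x_text, temperature=0.07):
    """Illustrative graph-text contrastive loss; assumes the i-th text and i-th graph are a pair."""
    logits = (x_text @ x_graph.T) / temperature                        # (B, B) similarity matrix
    labels = torch.arange(logits.size(0), device=logits.device)       # diagonal entries are positives
    ce = (torch.nn.functional.cross_entropy(logits, labels)
          + torch.nn.functional.cross_entropy(logits.T, labels)) / 2  # symmetric cross-entropy
    nce = InfoNCE()(x_text, x_graph)                                   # other in-batch pairs act as negatives
    return ce + nce
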
# model_name = 'distilbert-base-uncased'
model_name = 'allenai/scibert_scivocab_uncased'
tokenizer = AutoTokenizer.from_pretrained(model_name)
gt = np.load("../data/token_embedding_dict.npy", allow_pickle=True)[()]
val_dataset = GraphTextDataset(root='../data/', gt=gt, split='val_scibert', tokenizer=tokenizer)
train_dataset = GraphTextDataset(root='../data/', gt=gt, split='train_scibert', tokenizer=tokenizer)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
nb_epochs = 100
batch_size = 32
learning_rate = 5e-5
bert_lr = 5e-6
val_loader = DataLoader(val_dataset, batch_size=batch_size//2, shuffle=True, drop_last=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
model = Model(model_name=model_name, num_node_features=300, nout=2048, nhid=300, graph_hidden_channels=2048, conv_type="GAT", nheads=8) # nout is changed to 2048
model.to(device)
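# Model (defined in Model1.py, not shown here) is assumed, based on its usage in this script, to expose:
#   forward(graph_batch, input_ids, attention_mask) -> (graph_embeddings, text_embeddings)
#   get_graph_encoder() / get_text_encoder() for encoder-only inference at test time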
parameter_string = 'model_name: {}, num_node_features: {}, nout: {}, nhid: {}, graph_hidden_channels: {}, conv_type: {}, nheads: {}'.format(model_name, 300, 2048, 300, 2048, "GAT", 8)  # values mirror the Model(...) call above
print(parameter_string)
print(batch_size)
print(learning_rate)
optimizer = optim.AdamW([{'params': model.graph_encoder.parameters()},
                         {'params': model.text_encoder.parameters(), 'lr': bert_lr}],  # lower learning rate for the pretrained text encoder
                        lr=learning_rate,
                        betas=(0.9, 0.999),
                        weight_decay=0.01)
lr_scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.7, patience=3, verbose=True, min_lr=1e-7)
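# The docstring also mentions a linear schedule with warm-up. A hedged sketch using the transformers
# helper is given below, kept commented out so the active ReduceLROnPlateau setup is unchanged;
# the 10% warm-up fraction is an assumption:
# from transformers import get_linear_schedule_with_warmup
# total_steps = nb_epochs * len(train_loader)
# lr_scheduler = get_linear_schedule_with_warmup(optimizer,
#                                                num_warmup_steps=int(0.1 * total_steps),
#                                                num_training_steps=total_steps)
# With this variant, lr_scheduler.step() would be called once per batch, without the val_loss argument.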
epoch = 0
loss = 0
losses = []
count_iter = 0
time1 = time.time()
printEvery = 50
best_validation_loss = 1000000
for i in range(nb_epochs):
    print('-----EPOCH{}-----'.format(i+1))
    model.train()
    for idx, batch in enumerate(train_loader):
        input_ids = batch.input_ids
        batch.pop('input_ids')
        attention_mask = batch.attention_mask
        batch.pop('attention_mask')
        graph_batch = batch  # what remains of the batch is the graph data
        x_graph, x_text = model(graph_batch.to(device),
                                input_ids.to(device),
                                attention_mask.to(device))
        current_loss = contrastive_loss(x_graph, x_text)
        # The note here originally read "accumulating 2 batches before backprop", but the
        # commented lines below duplicate the plain per-batch update that is active further down.
        # optimizer.zero_grad()
        # current_loss.backward()
        # optimizer.step()
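        # A hedged sketch of what accumulating 2 batches before backprop could look like
        # (not active here; `accum_steps` is an assumed value taken from the comment above):
        # accum_steps = 2
        # (current_loss / accum_steps).backward()   # scale so gradients average over the window
        # if (idx + 1) % accum_steps == 0:
        #     optimizer.step()
        #     optimizer.zero_grad()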
        optimizer.zero_grad()
        current_loss.backward()
        optimizer.step()
        loss += current_loss.detach().item()  # reordered: .item() returns a plain float, which has no .detach()
        count_iter += 1
        if count_iter % printEvery == 0:
            time2 = time.time()
            print("Iteration: {0}, Time: {1:.4f} s, training loss: {2:.4f}".format(count_iter,
                  time2 - time1, loss/printEvery), flush=True)
            losses.append(loss)
            loss = 0
            optimizer.zero_grad()
        del graph_batch
        del input_ids
        del attention_mask
        del batch
        torch.cuda.empty_cache()
    model.eval()
    val_loss = 0
    val_text = []
    val_graph = []
    for batch in val_loader:
        input_ids = batch.input_ids
        batch.pop('input_ids')
        attention_mask = batch.attention_mask
        batch.pop('attention_mask')
        graph_batch = batch
        x_graph, x_text = model(graph_batch.to(device),
                                input_ids.to(device),
                                attention_mask.to(device))
        val_text.append(x_text.tolist())
        val_graph.append(x_graph.tolist())
        current_loss = contrastive_loss(x_graph, x_text)
        val_loss += current_loss.item()
    lr_scheduler.step(val_loss)  # plateau scheduler reacts to the summed validation loss
    val_text = np.concatenate(val_text)
    val_graph = np.concatenate(val_graph)
    similarity = cosine_similarity(val_text, val_graph)
    # np.eye marks the i-th text and i-th graph of the validation set as the true pairs
    print('validation lrap: ', label_ranking_average_precision_score(np.eye(similarity.shape[0]), similarity), flush=True)
    best_validation_loss = min(best_validation_loss, val_loss)
    print('-----EPOCH'+str(i+1)+'----- done. Validation loss: ', str(val_loss/len(val_loader)), flush=True)
    if best_validation_loss == val_loss:
        print('validation loss improved, saving checkpoint...')
        if os.path.exists("./models/"):
            for file in os.listdir("./models/"):
                if file.startswith('model'):
                    os.remove(os.path.join("./models/", file))  # keep only the latest best checkpoint
        else:
            os.makedirs("./models/")  # torch.save below needs the directory to exist
        save_path = os.path.join('./models/', 'model'+str(i)+'.pt')
        torch.save({
            'epoch': i,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'validation_loss': val_loss,  # renamed from 'validation_accuracy': the stored value is a loss
            'loss': loss,
        }, save_path)
        print('checkpoint saved to: {}'.format(save_path))
print('loading best model...')
checkpoint = torch.load(save_path)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
graph_model = model.get_graph_encoder()
text_model = model.get_text_encoder()
test_cids_dataset = GraphDataset(root='../data/', gt=gt, split='test_cids')  # same data root as the train/val datasets above
test_text_dataset = TextDataset(file_path='../data/test_text.txt', tokenizer=tokenizer)
idx_to_cid = test_cids_dataset.get_idx_to_cid()
test_loader = DataLoader(test_cids_dataset, batch_size=batch_size, shuffle=False)
graph_embeddings = []
for batch in test_loader:
    for output in graph_model(batch.to(device)):
        graph_embeddings.append(output.tolist())
test_text_loader = TorchDataLoader(test_text_dataset, batch_size=batch_size, shuffle=False)
text_embeddings = []
for batch in test_text_loader:
    for output in text_model(batch['input_ids'].to(device),
                             attention_mask=batch['attention_mask'].to(device)):
        text_embeddings.append(output.tolist())
similarity = cosine_similarity(text_embeddings, graph_embeddings)
solution = pd.DataFrame(similarity)
solution['ID'] = solution.index
solution = solution[['ID'] + [col for col in solution.columns if col!='ID']]
solution.to_csv('submission.csv', index=False)
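# Note on the output: each row of submission.csv appears to correspond to one test text description
# (the ID column is its row index) and each remaining column to one test graph, holding their cosine
# similarity; idx_to_cid (loaded above but unused here) maps graph indices back to CIDs if needed.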