dataset.py
# Turn the SNLI dataset into a format that fits a sentence-transformer training loop.
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer
import torch
# The Stanford Natural Language Inference (SNLI) dataset is used to train our model; it
# contains 570,000 human-annotated sentence pairs. Each pair consists of a premise and a
# hypothesis, labeled as entailment, contradiction, or neutral.
# Load BERT tokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
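# Illustrative sketch (not from the original file) of what the tokenizer returns:
#   enc = tokenizer(["A man is eating."], padding=True, truncation=True, return_tensors="pt")
#   enc["input_ids"].shape      -> torch.Size([1, seq_len])
#   enc["attention_mask"].shape -> torch.Size([1, seq_len])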
# Customized multi-task dataset loader. The original dataset is meant for classification;
# to train the real-world multi-task model, the dataset needs labels for both tasks. For
# the second task, sentence-similarity detection, the three classes entailment,
# contradiction, and neutral simply become similarity targets of 1, -1, and 0 respectively.
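# Illustrative example of one record, assuming the Hugging Face "snli" release
# (raw labels: 0 = entailment, 1 = neutral, 2 = contradiction; -1 = no gold label):
#   {"premise": "A man inspects the uniform of a figure.",
#    "hypothesis": "The man is sleeping.",
#    "label": 2}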

class SNLIDataset(Dataset):
    def __init__(self, dataset):
        """
        dataset: an iterable of SNLI-like dicts
                 (each has keys: 'premise', 'hypothesis', 'label').
        """
        # Remap the raw SNLI labels for comparison against cosine similarity later:
        # entailment (0) -> 2, neutral (1) -> 1, contradiction (2) -> 0.
        self.label_map = {2: 0, 1: 1, 0: 2}
        # Drop examples labeled -1 (SNLI pairs with no gold label).
        self.data = [
            (
                example["premise"],
                example["hypothesis"],
                self.label_map[example["label"]],
            )
            for example in dataset
            if example["label"] != -1
        ]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Each item: (premise_text, hypothesis_text, mapped_label)
        return self.data[idx]

def collate_fn(batch):
    # batch is a list of tuples: [(premise, hypothesis, label), (...), ...]
    premises, hypotheses, labels = zip(*batch)  # each is a tuple of length batch_size

    # Tokenize premises.
    premise_encodings = tokenizer(
        list(premises),
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    # Tokenize hypotheses.
    hypothesis_encodings = tokenizer(
        list(hypotheses),
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    sentence_features = [
        premise_encodings,  # dict with input_ids, attention_mask, etc.
        hypothesis_encodings,
    ]
    # Convert labels to a single tensor of shape [batch_size].
    labels_tensor = torch.tensor(labels, dtype=torch.long)
    return {
        "sentence_features": sentence_features,
        "label": labels_tensor,
    }
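
# Sketch of the intended downstream use (an assumption, not part of this file): with the
# entailment/neutral/contradiction -> 2/1/0 mapping above, shifting the labels by one
# yields the cosine-similarity targets of +1 / 0 / -1 mentioned earlier:
#   similarity_targets = batch["label"].float() - 1.0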

def create_loader(dataset, batch_size=16, shuffle=True):
    """
    Wrap an SNLI-like dataset in SNLIDataset and a DataLoader using the custom collate_fn.
    """
    dataset = SNLIDataset(dataset)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=collate_fn,
    )
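
# Minimal usage sketch, assuming the Hugging Face `datasets` package is installed;
# the dataset name, split, and field names below match HF's "snli" dataset.
if __name__ == "__main__":
    from datasets import load_dataset

    snli_train = load_dataset("snli", split="train")
    loader = create_loader(snli_train, batch_size=16, shuffle=True)

    batch = next(iter(loader))
    premise_feats, hypothesis_feats = batch["sentence_features"]
    print(premise_feats["input_ids"].shape)     # [batch_size, premise_seq_len]
    print(hypothesis_feats["input_ids"].shape)  # [batch_size, hypothesis_seq_len]
    print(batch["label"][:5])                   # mapped labels in {0, 1, 2}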