-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathlimerick_gpt_2.py
216 lines (186 loc) · 5.98 KB
/
limerick_gpt_2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
# -*- coding: utf-8 -*-
"""Limerick GPT-2
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1Rr4F4XSNZhC1jOVnUWHa0a3e9cQhEVvx

Pipeline: tag a limerick dataset with rhyme hints, finetune GPT-2 on it with
gpt-2-simple, then generate samples and score how limerick-like they are.
This is a Colab export: lines starting with `!` are IPython shell escapes and
only run inside a notebook, not as plain Python.
# Primary Imports and Setup
"""
# Commented out IPython magic to ensure Python compatibility.
#@title Basic imports
# %tensorflow_version 1.x
# IPython shell escape (Colab/Jupyter only). The version is pinned; the gpt2.*
# calls below presumably target this release — confirm before upgrading.
!pip install -q gpt-2-simple==0.7.2
import gpt_2_simple as gpt2
# NOTE(review): `datetime` and `files` are not used anywhere else in this file.
from datetime import datetime
from google.colab import files
#@title Choose and download model
##124M,355M,774M,1558M
# Colab form field: which pretrained GPT-2 size to download (sizes listed above).
model_name = "774M" #@param {type:"string"}
gpt2.download_gpt2(model_name=model_name)
#@title Mount Gdrive
gpt2.mount_gdrive()
#@title Import checkpoint to workspace (optional)
# Copy the saved 'limericks_rhymified' run from Drive; the finetune step later in
# this file trains that same run with restore_from='latest'.
gpt2.copy_checkpoint_from_gdrive('limericks_rhymified')
"""# Data Pre-processing"""
#@title Functions for Testing Limericks
# IPython shell escape (Colab/Jupyter only).
!pip install pyphen
import pyphen
# English hyphenation dictionary. wordcounts() treats the hyphenation points it
# inserts into each word as a rough syllable count.
dic = pyphen.Pyphen(lang='en')
# Checks whether the poem follows the limerick AABBA end-rhyme scheme.
def is_rhyming(poem: str, lexicon=None):
    """Return True if the poem's line-ending words follow an AABBA rhyme scheme.

    Lines 1, 2 and 5 (indices 0, 1, 4) form the "A" group; every other line
    forms the "B" group. Two end words count as rhyming when the final
    phoneme of their first pronunciation in *lexicon* is identical.

    NOTE(review): comparing only the final phoneme is a coarse heuristic
    (e.g. "cat" and "hit" both end in T). A stricter test would compare
    everything from the last stressed vowel onward.

    Args:
        poem: Newline-separated lines of the poem.
        lexicon: Mapping from lowercase word to a list of pronunciations, each
            a list of phoneme strings. Defaults to the module-level ``arpabet``
            CMU dictionary.

    Returns:
        True if each group shares a single final phoneme. Returns False
        otherwise, including when a line is blank or an end word is not in
        the lexicon.
    """
    if lexicon is None:
        lexicon = arpabet
    group_a = []
    group_b = []
    for i, line in enumerate(poem.split('\n')):
        # Split on any whitespace so trailing spaces don't yield an empty end word.
        words = line.split()
        if not words:
            return False
        word = re.sub(r'[^\w\s]', '', words[-1]).lower()
        try:
            final_phoneme = lexicon[word][0][-1]
        except (KeyError, IndexError):
            # Unknown word (or empty pronunciation list): rhyme can't be verified.
            return False
        (group_a if i in (0, 1, 4) else group_b).append(final_phoneme)
    return len(set(group_a)) == 1 and len(set(group_b)) == 1
# Checks whether the poem has exactly five lines, as a limerick should.
def is_lining(poem: str):
    """Return True when *poem* splits into exactly five newline-separated lines.

    A trailing newline counts as an extra (empty) line.
    """
    line_count = poem.count('\n') + 1
    return line_count == 5
# Outputs the approximate number of syllables in each line.
def wordcounts(poem: str):
    """Return a per-line syllable estimate for *poem*.

    Each whitespace-separated word is hyphenated with the module-level pyphen
    dictionary ``dic``. The number of hyphen-separated pieces is taken as that
    word's syllable count. Despite the function name, the result counts
    syllables, not words. There is one entry per newline-separated line, and
    blank lines give 0.
    """
    return [
        sum(len(dic.inserted(token).split('-')) for token in line.split())
        for line in poem.split('\n')
    ]
# Checks the limerick meter: lines 1, 2 and 5 must each have more syllables
# than both lines 3 and 4.
def is_wordcounting(poem: str):
    """Return True if every long line (1, 2, 5) out-counts every short line (3, 4).

    Syllable counts come from wordcounts(). Lines 0-4 are indexed directly,
    so a poem with fewer than five lines raises IndexError. limerick_score()
    treats that as "no point".
    """
    counts = wordcounts(poem)
    shortest_long_line = min(counts[0], counts[1], counts[4])
    longest_short_line = max(counts[2], counts[3])
    return shortest_long_line > longest_short_line
def limerick_score(poem: str):
    """Score how limerick-like *poem* is, from 0 to 3.

    One point is awarded for each check that passes:
      * exactly five lines (is_lining),
      * long/short/long syllable pattern (is_wordcounting),
      * AABBA end rhyme (is_rhyming).

    is_wordcounting() indexes lines 0-4 directly and raises IndexError on
    shorter poems. Such poems simply earn no point for that check, and the
    rhyme check is guarded the same way against lookup errors. Other
    exceptions now propagate instead of being silently scored as 0. One
    example is a NameError because the pronunciation/hyphenation
    dictionaries were never loaded. Previously a bare ``except:`` hid them.
    """
    score = 0
    if is_lining(poem):
        score += 1
    try:
        if is_wordcounting(poem):
            score += 1
    except (IndexError, KeyError):
        # Too few lines for the syllable comparison.
        pass
    try:
        if is_rhyming(poem):
            score += 1
    except (IndexError, KeyError):
        pass
    return score
#@title Tagging and Untagging Functions
# IPython shell escape (Colab/Jupyter only): installs the `pronouncing` rhyme library.
!pip install pronouncing
import pandas as pd
import nltk
import re
import pronouncing
from random import sample
# Fetch the CMU Pronouncing Dictionary corpus.
nltk.download('cmudict')
# Word -> list of pronunciations, each a list of phoneme strings
# (is_rhyming/phoneticize read it as arpabet[word][0]).
# NOTE: `re` and `arpabet` are also used by is_rhyming(), which is defined in an
# earlier cell, so this cell must run before any poem is scored.
arpabet = nltk.corpus.cmudict.dict()
# Tags each word in a poem with its phonetic representation from the CMU
# lexicon, in the form word[phonemes], e.g. "cat" -> "cat[K AE1 T]".
def phoneticize(poem: str, lexicon=None):
    """Append the first CMU pronunciation to every known word of *poem*.

    Words are split on single spaces and re-joined the same way, so spacing
    is preserved. Lookup uses the word lowercased and stripped of
    punctuation, but the original spelling is kept in the output. Words not
    found in the lexicon are left untagged.

    Args:
        poem: Newline-separated lines.
        lexicon: Mapping from lowercase word to a list of pronunciations, each
            a list of phoneme strings. Defaults to the module-level ``arpabet``.

    Returns:
        The tagged poem. dephoneticize() removes the bracketed tags again.
    """
    if lexicon is None:
        lexicon = arpabet
    tagged_lines = []
    for line in poem.split('\n'):
        tagged_words = []
        for word in line.split(' '):
            key = re.sub(r'[^\w\s]', '', word).lower()
            try:
                phones = lexicon[key][0]
            except (KeyError, IndexError):
                # Unknown word (or empty pronunciation list): leave it untagged.
                tagged_words.append(word)
            else:
                tagged_words.append(word + '[' + ' '.join(phones) + ']')
        tagged_lines.append(' '.join(tagged_words))
    return '\n'.join(tagged_lines)
# Strips every bracketed tag added by phoneticize()/rhymify(). Written as a
# dephoneticizer, it works as a generic "remove [...]" untagger.
def dephoneticize(poem: str):
    """Return *poem* with all ``[...]`` spans (brackets included) removed.

    Characters from a '[' up to and including the next ']' are dropped. A
    stray ']' that has no opening bracket is also dropped, and an unclosed
    '[' swallows the rest of the text.
    """
    kept = []
    inside_tag = False
    for ch in poem:
        if ch == '[':
            inside_tag = True
        elif ch == ']':
            inside_tag = False
            continue
        if not inside_tag:
            kept.append(ch)
    return ''.join(kept)
# Tags the end of each line with candidate rhymes: a few from the `pronouncing`
# rhyming dictionary plus the end words of the poem's own rhyme group, e.g.
# "this is a sentence['penance', 'admittance', 'forbiddance', 'presence', 'incense']"
def rhymify(poem: str, rhymes=None):
    """Append rhyme hints to the end of each line of a limerick.

    Lines are grouped by *position* following the AABBA scheme. The first,
    second and last lines form group A; the third and fourth form group B.
      * Each A line gets up to 2 randomly sampled dictionary rhymes of its
        end word, followed by all A-group end words.
      * Each B line gets up to 3 sampled rhymes, followed by both B-group
        end words.
    When the dictionary has enough rhymes, every tagged line therefore
    carries 5 hints. The hints are appended as the ``str()`` of a Python
    list, which dephoneticize() strips again.

    Grouping by position (instead of by end-word membership) fixes two bugs.
    First, a repeated end word used to be tagged twice, the second time
    sampling from the already-extended list. Second, middle lines of poems
    longer than five lines used to be replaced by empty strings.

    Args:
        poem: Newline-separated lines.
        rhymes: Callable mapping a word to a list of rhyming words. Defaults
            to ``pronouncing.rhymes``.

    Returns:
        The tagged poem, newline-joined. Lines outside both groups are
        returned unchanged.
    """
    if rhymes is None:
        rhymes = pronouncing.rhymes

    def end_word(line):
        # Last whitespace-separated token, lowercased and stripped of punctuation.
        words = line.split()
        return re.sub(r'[^\w\s]', '', words[-1]).lower() if words else ''

    lines = poem.split('\n')
    ends = [end_word(line) for line in lines]
    group_a = ends[0:2] + ends[-1:]
    group_b = ends[2:4]
    last = len(lines) - 1

    tagged = []
    for i, (line, word) in enumerate(zip(lines, ends)):
        if i in (0, 1, last):
            group, n_sampled = group_a, 2
        elif i in (2, 3):
            group, n_sampled = group_b, 3
        else:
            tagged.append(line)
            continue
        candidates = list(rhymes(word))
        if len(candidates) >= n_sampled:
            candidates = sample(candidates, n_sampled)
        tagged.append(line + str(candidates + group))
    return '\n'.join(tagged)
#@title Read Limerick dataset
import pandas as pd
limerick_df = pd.read_csv('limerick_dataset.csv', encoding='utf-8')
limerick_df.head()
#@title Rhymify + Add start/end tokens + Stringify and Export as txt
# Rhymify: append rhyme hints to every line (phoneticize() is not applied here).
limerick_df['limerick'] = limerick_df['limerick'].apply(rhymify)
# Stringify and add GPT-2 delimiters. Every limerick is wrapped as
# <|startoftext|>...<|endoftext|>, with a newline between consecutive ones.
limerick_string = "<|endoftext|>\n<|startoftext|>".join(limerick_df['limerick'])
limerick_string = "<|startoftext|>" + limerick_string + "<|endoftext|>"
# Export as txt. Write UTF-8 explicitly, matching how the CSV was read, rather
# than relying on the platform's default encoding.
with open("limerick_dataset.txt", "w", encoding="utf-8") as text_file:
    text_file.write(limerick_string)
"""# Finetuning Model"""
#@title Start Sess
# Create the session first. The optional load below used to reference `sess`
# before it was assigned.
sess = gpt2.start_tf_sess()
#@title Load existing model (optional)
# Tick to load the previously finetuned 'limericks_774' run into `sess`.
# NOTE(review): gpt-2-simple builds the model graph in both load_gpt2 and
# finetune. Loading and then finetuning in the same session likely needs a
# session reset in between, so confirm before enabling this together with
# the finetune below.
load_existing_model = False #@param {type:"boolean"}
if load_existing_model:
    gpt2.load_gpt2(sess, run_name='limericks_774')
#@title Finetune
# Training corpus written by the pre-processing step (previously `file_name`
# was never defined).
file_name = "limerick_dataset.txt"
gpt2.finetune(sess,
              dataset=file_name,
              # Use the size downloaded in the setup step instead of a
              # duplicated hard-coded '774M'.
              model_name=model_name,
              steps=5000,
              # Resume the 'limericks_rhymified' run copied from Drive, if present.
              restore_from='latest',
              run_name='limericks_rhymified',
              print_every=100,
              sample_every=100,
              save_every=500,
              learning_rate=3e-7,
              only_train_transformer_layers=False)
#@title Generate and score
from collections import Counter
# Draw 100 samples. Each one starts after the start token and is cut at the end
# token, mirroring the delimiters written into limerick_dataset.txt.
# NOTE(review): this reads run 'limericks_774', while the finetune step in this
# file trains 'limericks_rhymified'. Confirm which run is intended.
answers = gpt2.generate(
    sess,
    prefix="<|startoftext|>",
    run_name="limericks_774",
    nsamples=100,
    temperature=0.1,
    length=300,
    truncate="<|endoftext|>",
    return_as_list=True,
)
# Strip the bracketed tags and the start token once, then score each poem 0-3.
cleaned = [dephoneticize(text).replace("<|startoftext|>", '') for text in answers]
scores = [limerick_score(text) for text in cleaned]
print(Counter(scores))
#@title Find indices with perfect limerick score
indices = [idx for idx, value in enumerate(scores) if value == 3]
#@title Print Perfect Limericks
for idx in indices:
    print(cleaned[idx], end='\n\n')
#!kill -9 -1