# tp.py
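# Streamlit app entry point; launch with:  streamlit run tp.py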
# model1.py
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch
import streamlit as st
import nltk
from nltk.probability import FreqDist

# nltk.word_tokenize relies on the 'punkt' tokenizer data; fetch it if missing
nltk.download('punkt', quiet=True)


@st.cache_resource  # keep the GPT-2 weights loaded across Streamlit reruns
def load_model1():
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    model = GPT2LMHeadModel.from_pretrained('gpt2')
    return tokenizer, model


def predict_model1(text, tokenizer, model):
    encoded_input = tokenizer.encode(text, add_special_tokens=False, return_tensors='pt')
    input_ids = encoded_input[0]  # 1-D tensor; GPT-2 treats it as a single-example batch
    with torch.no_grad():
        outputs = model(input_ids)
        logits = outputs.logits  # shape (1, seq_len, vocab_size)
    return logits
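
# The helpers below (originally app_new.py) mirror load_model1/predict_model1,
# only swapping in the larger 'gpt2-medium' checkpoint as the second model.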
# app_new.py
@st.cache_resource  # cache the larger gpt2-medium checkpoint as well
def load_model2():
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
    model = GPT2LMHeadModel.from_pretrained('gpt2-medium')
    return tokenizer, model


def predict_model2(text, tokenizer, model):
    encoded_input = tokenizer.encode(text, add_special_tokens=False, return_tensors='pt')
    input_ids = encoded_input[0]
    with torch.no_grad():
        outputs = model(input_ids)
        logits = outputs.logits
    return logits
# Function to calculate perplexity
def calculate_perplexity(logits, input_ids):
    # Note: the logits at position i are scored against the token at position i
    # (no one-token shift), so values run much higher than standard causal-LM
    # perplexity; the decision threshold used later assumes this formulation.
    cross_entropy_loss = torch.nn.CrossEntropyLoss()
    loss = cross_entropy_loss(logits.view(-1, logits.size(-1)), input_ids.view(-1))
    perplexity = torch.exp(loss)
    return perplexity.item()
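
# For reference: standard causal-LM perplexity shifts the sequence so that the
# logits at position i are scored against token i+1. A sketch of that variant
# (not used by this app, whose threshold assumes the unshifted version above):
#
#   shift_logits = logits[..., :-1, :].contiguous()
#   shift_labels = input_ids.view(-1)[1:].contiguous()
#   loss = torch.nn.functional.cross_entropy(
#       shift_logits.view(-1, shift_logits.size(-1)), shift_labels)
#   perplexity = torch.exp(loss).item()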
# Function to calculate burstiness: the number of word types that repeat,
# divided by the total token count
def calculate_burstiness(text):
    tokens = nltk.word_tokenize(text.lower())
    if not tokens:  # guard against whitespace-only input
        return 0.0
    word_freq = FreqDist(tokens)
    repeated_count = sum(count > 1 for count in word_freq.values())
    burstiness_score = repeated_count / len(tokens)
    return burstiness_score
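
# Example: calculate_burstiness("the cat and the dog") tokenizes to five words
# with one repeated type ('the'), giving a burstiness score of 1 / 5 = 0.2.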
# Page setup (st.set_page_config must run before any other Streamlit element)
st.set_page_config(layout="wide")
st.title("GPT Shield: AI Plagiarism Detector")

# Load both models (cached, so the weights are only loaded once per session)
tokenizer1, model1 = load_model1()
tokenizer2, model2 = load_model2()
# User input
text = st.text_area("Enter text", "")

if st.button("Analyze"):
    if text:
        # Tokenize the input text once per tokenizer; these (1, seq_len) tensors
        # are the targets used in the perplexity calculation
        input_ids1 = tokenizer1.encode(text, add_special_tokens=False, return_tensors='pt')
        input_ids2 = tokenizer2.encode(text, add_special_tokens=False, return_tensors='pt')

        # Get predictions
        logits1 = predict_model1(text, tokenizer1, model1)
        logits2 = predict_model2(text, tokenizer2, model2)

        # Calculate perplexity for both models
        perplexity1 = calculate_perplexity(logits1, input_ids1)
        perplexity2 = calculate_perplexity(logits2, input_ids2)

        # Calculate burstiness
        burstiness_score = calculate_burstiness(text)
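
        # With the default equal weights the normalization below is a no-op, but it
        # keeps the weighted average valid if weight1/weight2 are ever changed.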
        # Assign weights (equal here; adjust based on validation performance if available)
        weight1 = 0.5
        weight2 = 0.5

        # Normalize so the weights sum to 1
        total_weight = weight1 + weight2
        weight1 /= total_weight
        weight2 /= total_weight

        # Combine the perplexity scores using a weighted average
        combined_perplexity = weight1 * perplexity1 + weight2 * perplexity2

        # Determine whether the text looks AI-generated: perplexity above the cutoff
        # combined with low burstiness; the same cutoff is applied per model and to
        # the combined score
        PERPLEXITY_THRESHOLD = 7524.77197265625
        is_ai_generated = combined_perplexity > PERPLEXITY_THRESHOLD and burstiness_score < 0.2
        model1_ai_generated = perplexity1 > PERPLEXITY_THRESHOLD and burstiness_score < 0.2
        model2_ai_generated = perplexity2 > PERPLEXITY_THRESHOLD and burstiness_score < 0.2

        st.write("*Individual Model Results:*")
        st.write(f"Model 1: {'AI generated content' if model1_ai_generated else 'Likely not generated by AI'}")
        st.write(f"Model 2: {'AI generated content' if model2_ai_generated else 'Likely not generated by AI'}")

        # Output the combined result
        if is_ai_generated:
            st.markdown("*Text Analysis Result: AI generated content*")
        else:
            st.markdown("*Text Analysis Result: Likely not generated by AI*")

        # Display detailed scores
        st.write(f"Perplexity (Model 1): {perplexity1}")
        st.write(f"Perplexity (Model 2): {perplexity2}")
        st.write(f"Combined Perplexity: {combined_perplexity}")
        st.write(f"Burstiness Score: {burstiness_score}")
    else:
        st.error("Please enter some text to analyze.")