Skip to content

Commit

Permalink
Merge pull request #5 from canstralian/canstralian-patch-1
Browse files Browse the repository at this point in the history
Create model.py
  • Loading branch information
canstralian authored Feb 2, 2025
2 parents bb4d206 + 81a87ca commit 45b2705
Showing 1 changed file with 54 additions and 0 deletions.
54 changes: 54 additions & 0 deletions model/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments
from datasets import load_dataset

def _load_raw_dataset(dataset_url, file):
    """Load the raw dataset from a dataset identifier or a CSV file.

    Args:
        dataset_url: Identifier passed straight to ``load_dataset``; takes
            precedence over ``file`` when both are given.
        file: Object exposing the CSV path as ``.name`` (what the original
            code required), or a plain path string.

    Raises:
        ValueError: If neither ``dataset_url`` nor ``file`` is provided.
    """
    if dataset_url:
        return load_dataset(dataset_url)
    if file:
        # Fall back to the value itself so a bare path string also works.
        csv_path = getattr(file, "name", file)
        return load_dataset("csv", data_files={"train": csv_path})
    raise ValueError("Please provide a dataset URL or upload a file.")


def _select_splits(dataset, holdout_fraction=0.1, seed=42):
    """Return ``(train, eval)`` splits from a tokenized dataset dict.

    Uses an existing ``"test"`` split, then ``"validation"``; when neither
    exists (always the case for the CSV path, which only defines ``"train"``)
    a seeded hold-out is carved from the training data so per-epoch
    evaluation still has something to run on.
    """
    train_split = dataset["train"]
    for candidate in ("test", "validation"):
        if candidate in dataset:
            return train_split, dataset[candidate]
    held_out = train_split.train_test_split(test_size=holdout_fraction, seed=seed)
    return held_out["train"], held_out["test"]


def fine_tune(model_name, dataset_url=None, file=None, epochs=3, batch_size=8, learning_rate=5e-5):
    """Fine-tune a two-label sequence classifier and report the outcome as text.

    Args:
        model_name: Model identifier passed to ``from_pretrained`` for both
            the model and the tokenizer.
        dataset_url: Dataset identifier for ``load_dataset``; preferred over
            ``file`` when both are supplied.
        file: CSV source, either an object with a ``.name`` path attribute or
            a path string. The data must contain a ``"text"`` column.
        epochs: Number of training epochs.
        batch_size: Per-device batch size for both training and evaluation.
        learning_rate: Optimizer learning rate.

    Returns:
        ``"Fine-tuning complete."`` on success, otherwise
        ``"An error occurred: <message>"``. Exceptions are never propagated;
        callers receive the failure as a string.
    """
    try:
        # Load dataset
        dataset = _load_raw_dataset(dataset_url, file)

        # Load model & tokenizer
        model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        def tokenize_function(examples):
            return tokenizer(examples["text"], padding="max_length", truncation=True)

        dataset = dataset.map(tokenize_function, batched=True)

        # Previously ``dataset["test"]`` was read unconditionally, which made
        # every CSV upload (train-only) fail with KeyError: 'test'.
        train_dataset, eval_dataset = _select_splits(dataset)

        # Define training arguments
        # NOTE(review): newer transformers releases rename
        # ``evaluation_strategy`` to ``eval_strategy`` — confirm against the
        # pinned version before upgrading.
        training_args = TrainingArguments(
            output_dir="./results",
            evaluation_strategy="epoch",
            save_strategy="epoch",
            logging_strategy="epoch",
            learning_rate=learning_rate,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            num_train_epochs=epochs,
            weight_decay=0.01,
            push_to_hub=False,
            report_to="all",
        )

        # Initialize Trainer
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
        )

        # Start training
        trainer.train()

        return "Fine-tuning complete."

    except Exception as e:
        # Deliberate top-level boundary: the caller expects a status string
        # rather than a raised exception.
        return f"An error occurred: {e}"

0 comments on commit 45b2705

Please sign in to comment.