Skip to content

Commit

Permalink
u
Browse files Browse the repository at this point in the history
  • Loading branch information
jagadeshchilla committed Oct 17, 2024
1 parent bf5d89c commit 8072e66
Showing 1 changed file with 21 additions and 28 deletions.
49 changes: 21 additions & 28 deletions pages/Image_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import gdown
import tempfile

# Define the model links
# Define the model links and their target sizes
model_links = {
'CNN': {
'url': 'https://drive.google.com/uc?id=1mtDtPtM-E7y20LlFlEPn1UI20fykFL2P',
Expand All @@ -37,7 +37,7 @@
},
}

# Initialize session state
# Initialize session state for saved predictions and uploaded images
if 'saved_predictions' not in st.session_state:
st.session_state.saved_predictions = []
if 'predictions' not in st.session_state:
Expand All @@ -58,15 +58,14 @@ def load_existing_predictions():
st.session_state.saved_predictions = load_existing_predictions()

def save_predictions_to_history(uploaded_files, predictions, model_name):
"""Save predictions to a JSON file."""
prediction_data = []
for i, uploaded_file in enumerate(uploaded_files):
actual = 'Cancer' if predictions[i][0] == 0 else 'Non Cancer'
prediction_data.append({
"""Save predictions to history in a JSON file."""
prediction_data = [
{
'file_name': uploaded_file.name,
'model_used': model_name,
'prediction': actual
})
'prediction': 'Cancer' if predictions[i][0] == 0 else 'Non Cancer'
} for i, uploaded_file in enumerate(uploaded_files)
]

st.session_state.saved_predictions.extend(prediction_data)

Expand All @@ -83,7 +82,7 @@ def save_predictions_to_history(uploaded_files, predictions, model_name):
]

def download_and_load_model(model_url):
"""Downloads and loads the model from the provided Google Drive URL."""
"""Download and load the model from the provided Google Drive URL."""
if st.session_state.model_temp_file is None:
with tempfile.NamedTemporaryFile(suffix='.keras', delete=False) as tmp:
st.session_state.model_temp_file = tmp.name
Expand All @@ -92,26 +91,20 @@ def download_and_load_model(model_url):
gdown.download(model_url, st.session_state.model_temp_file, quiet=False)
st.toast("✅ Model download completed!")

model = load_model(st.session_state.model_temp_file)
return model
return load_model(st.session_state.model_temp_file)

def load_uploaded_images(uploaded_files, target_size):
"""Load uploaded images and resize them to the target size."""
images = []
for uploaded_file in uploaded_files:
image = load_img(uploaded_file, target_size=target_size)
image_array = img_to_array(image)
images.append(image_array)
"""Load and preprocess uploaded images."""
images = [img_to_array(load_img(uploaded_file, target_size=target_size)) for uploaded_file in uploaded_files]
return np.array(images)

def evaluate_model(model, images):
"""Evaluate the model on uploaded images and return predictions."""
"""Evaluate the model on the uploaded images."""
predictions = model.predict(images)
predicted_classes = (predictions > 0.5).astype(int)
return predicted_classes
return (predictions > 0.5).astype(int)

def show_image_prediction():
# Streamlit UI
"""Main function to show the image prediction interface."""
st.title('Oral Cancer Detection Model Evaluation')

# Model selection
Expand All @@ -128,6 +121,7 @@ def show_image_prediction():
if st.button('Predict'):
st.info("Downloading and loading the model. This may take a few moments...")

# Download and load the model
model_url = model_links[model_selection]['url']
with st.spinner("Loading model..."):
model_to_use = download_and_load_model(model_url)
Expand Down Expand Up @@ -156,17 +150,16 @@ def show_image_prediction():
if st.button('Clear'):
st.session_state.predictions = []
st.session_state.uploaded_images = []
st.session_state.model_temp_file = None
st.session_state.model_temp_file = None # Reset the temp file
st.success("🗑️ Cleared all predictions and uploaded images.")

with col2:
if len(st.session_state.predictions) > 0 and len(st.session_state.uploaded_images) > 0:
if st.session_state.predictions and st.session_state.uploaded_images:
if st.button('Save Predictions'):
save_predictions_to_history(
st.session_state.uploaded_images, st.session_state.predictions, model_selection)
save_predictions_to_history(st.session_state.uploaded_images, st.session_state.predictions, model_selection)

# Download predictions functionality
if len(st.session_state.predictions) > 0 and len(st.session_state.uploaded_images) > 0:
if st.session_state.predictions and st.session_state.uploaded_images:
prediction_images = []
for i, uploaded_file in enumerate(st.session_state.uploaded_images):
actual = 'Cancer' if st.session_state.predictions[i][0] == 0 else 'Non Cancer'
Expand Down Expand Up @@ -195,8 +188,8 @@ def show_image_prediction():
mime='application/zip'
)

# Utility function to convert an image to base64 for display
def image_to_base64(image: Image.Image) -> str:
"""Convert an image to base64 for display."""
buffered = io.BytesIO()
image.save(buffered, format="PNG")
return base64.b64encode(buffered.getvalue()).decode()
Expand Down

0 comments on commit 8072e66

Please sign in to comment.