
Commit

update
jagadeshchilla committed Oct 17, 2024
1 parent 13244eb commit 6d867fa
Showing 1 changed file with 29 additions and 10 deletions.
39 changes: 29 additions & 10 deletions pages/Image_prediction.py
@@ -66,7 +66,9 @@ def save_predictions_to_history(uploaded_files, predictions, model_name):
'prediction': actual
})

st.session_state.saved_predictions.extend(prediction_data)
# Extend saved_predictions only when it is a list
if isinstance(st.session_state.saved_predictions, list):
st.session_state.saved_predictions.extend(prediction_data)

with open('prediction_history.json', 'w') as f:
json.dump(st.session_state.saved_predictions, f, indent=4)
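
For reference, a minimal sketch (not part of this commit) of how the prediction_history.json written above could be read back, for example on a history page; the tolerant handling of a missing or corrupt file is an assumption:

import json

def load_prediction_history(path='prediction_history.json'):
    # Return previously saved prediction records, or an empty list if none exist yet
    try:
        with open(path) as f:
            return json.load(f)
    except (FileNotFoundError, json.JSONDecodeError):
        return []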
@@ -85,21 +87,29 @@ def download_and_load_model(model_url):
if st.session_state.model_temp_file is None:
with tempfile.NamedTemporaryFile(suffix='.keras', delete=False) as tmp:
st.session_state.model_temp_file = tmp.name
# Show toast message for downloading model
st.toast("📥 Downloading model... Please wait.")
# Spinner for downloading the model
with st.spinner("Downloading the model..."):
gdown.download(model_url, st.session_state.model_temp_file, quiet=False)

# Show toast message for download completion
st.toast("✅ Model download completed!")

# Load the model from the temp file
model = load_model(st.session_state.model_temp_file)
return model

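The hunk above caches the downloaded .keras file path in st.session_state so that Streamlit reruns skip the gdown download; a stripped-down sketch of that caching pattern (the function name and the tensorflow.keras import are assumptions, not taken from this commit):

import tempfile

import gdown
import streamlit as st
from tensorflow.keras.models import load_model

def get_cached_model(model_url):
    # Download the model file once per session, then reuse it on reruns
    if st.session_state.get('model_temp_file') is None:
        with tempfile.NamedTemporaryFile(suffix='.keras', delete=False) as tmp:
            st.session_state.model_temp_file = tmp.name
        gdown.download(model_url, st.session_state.model_temp_file, quiet=False)
    return load_model(st.session_state.model_temp_file)
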
def show_image_prediction():
# Streamlit UI
st.title('Oral Cancer Detection Model Evaluation')

# Model selection
model_selection = st.selectbox("Select a model", list(model_links.keys()))

uploaded_files = st.file_uploader("Upload images", type=['jpg', 'jpeg', 'png'], accept_multiple_files=True)
# Upload images
uploaded_files = st.file_uploader(
"Upload images", type=['jpg', 'jpeg', 'png'], accept_multiple_files=True)

if uploaded_files:
target_size = model_links[model_selection]['target_size']
@@ -114,24 +124,30 @@ def load_uploaded_images(uploaded_files, target_size):

X_test = load_uploaded_images(uploaded_files, target_size)

# Function to evaluate the model on uploaded images
def evaluate_model(model, images):
predictions = model.predict(images)
predicted_classes = (predictions > 0.5).astype(int)
return predicted_classes
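
evaluate_model thresholds the raw outputs at 0.5; assuming the model emits one sigmoid probability per image (which the predictions[i][0] indexing below suggests), the thresholding behaves as in this small illustration:

import numpy as np

# Illustrative sigmoid outputs for three uploaded images
predictions = np.array([[0.12], [0.91], [0.50]])
predicted_classes = (predictions > 0.5).astype(int)
# -> [[0], [1], [0]]; downstream, class 0 is shown as 'Cancer' and 1 as 'Non Cancer'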

# Add a button to trigger predictions
if st.button('Predict'):
st.info("Downloading and loading the model. This may take a few moments...")

# Download and load the model
model_url = model_links[model_selection]['url']
with st.spinner("Loading model..."):
model_to_use = download_and_load_model(model_url)

# Evaluate the model
with st.spinner("Evaluating images..."):
st.session_state.predictions = evaluate_model(model_to_use, X_test)
st.session_state.uploaded_images = uploaded_files

# Show toast message for image prediction
st.toast("✨ Images predicted successfully!")

# Display predictions
st.subheader('Predictions:')
for i, uploaded_file in enumerate(uploaded_files):
actual = 'Cancer' if st.session_state.predictions[i][0] == 0 else 'Non Cancer'
@@ -148,20 +164,17 @@ def evaluate_model(model, images):
if st.button('Clear'):
st.session_state.predictions = []
st.session_state.uploaded_images = []
st.session_state.model_temp_file = None
st.session_state.model_temp_file = None # Reset the temp file
st.success("🗑️ Cleared all predictions and uploaded images.")

with col2:
# Ensure that both predictions and uploaded_images are initialized before accessing them
if hasattr(st.session_state, 'predictions') and len(st.session_state.predictions) > 0 and \
hasattr(st.session_state, 'uploaded_images') and len(st.session_state.uploaded_images) > 0:
if len(st.session_state.predictions) > 0 and len(st.session_state.uploaded_images) > 0:
if st.button('Save Predictions'):
save_predictions_to_history(
st.session_state.uploaded_images, st.session_state.predictions, model_selection)

# Download predictions functionality
if hasattr(st.session_state, 'predictions') and len(st.session_state.predictions) > 0 and \
hasattr(st.session_state, 'uploaded_images') and len(st.session_state.uploaded_images) > 0:
if len(st.session_state.predictions) > 0 and len(st.session_state.uploaded_images) > 0:
prediction_images = []
for i, uploaded_file in enumerate(st.session_state.uploaded_images):
actual = 'Cancer' if st.session_state.predictions[i][0] == 0 else 'Non Cancer'
@@ -176,6 +189,7 @@ def evaluate_model(model, images):
buf.seek(0)
prediction_images.append((buf, f'prediction_{i + 1}.png'))

# Create zip file for download
zip_buffer = io.BytesIO()
with zipfile.ZipFile(zip_buffer, 'w') as zf:
for image_buf, filename in prediction_images:
@@ -189,22 +203,27 @@ def evaluate_model(model, images):
mime='application/zip'
)

# Utility function to convert an image to base64 for display
def image_to_base64(image: Image.Image) -> str:
buffered = io.BytesIO()
image.save(buffered, format="PNG")
return base64.b64encode(buffered.getvalue()).decode()

# Display logo
logo_path = "./assets/logo.png" # Update with your logo file path
logo_image = Image.open(logo_path)

# Convert the logo image to base64
logo_base64 = image_to_base64(logo_image)

# Display the logo with custom CSS styles
st.sidebar.markdown(
f"""
<img src="data:image/png;base64,{logo_base64}"
style="border-radius: 30px; box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2); width: 90%; height: auto;" />
""", unsafe_allow_html=True
)

# Call the function to show image prediction
if __name__ == "__main__":
# Streamlit application runner
if __name__ == '__main__':
show_image_prediction()
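
Assuming this file lives in the pages/ directory of a Streamlit multipage app, the page can also be launched on its own for local testing with: streamlit run pages/Image_prediction.py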
