-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathstreamlit_app.py
176 lines (143 loc) · 6.54 KB
/
streamlit_app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import streamlit as st
import pandas as pd
from datetime import datetime, timedelta
import pickle
import os
from main import NursePayPipeline
def load_pipeline():
    """Construct a NursePayPipeline and ask it to load its trained models.

    Returns:
        The pipeline instance when ``load_models()`` returns a truthy value;
        otherwise ``None``, after rendering an error message on the page.
    """
    candidate = NursePayPipeline()
    models_ready = candidate.load_models()
    if models_ready:
        return candidate
    st.error("Error loading models. Please ensure models are trained first.")
    return None
# CSV backing the "Market Analysis" and "Historical Trends" tabs; read relative
# to the working directory the app is launched from.
_DATASET_PATH = "Synthetic_Nurse_Pay_Data.csv"

# Months classed as flu season (December is checked first and is "holiday").
_FLU_SEASON_MONTHS = frozenset({10, 11, 1, 2, 3, 4, 5})

# Multiplier applied to the base hourly rate for each season label.
_SEASONAL_MULTIPLIERS = {"normal": 1.0, "flu": 1.2, "holiday": 1.3}

# How many candidate hospital names are offered per location.
_HOSPITAL_CHOICES = 5


def _season_for_month(month):
    """Return the season label ("holiday", "flu" or "normal") for a month 1-12."""
    if month == 12:
        return "holiday"
    if month in _FLU_SEASON_MONTHS:
        return "flu"
    return "normal"


def _hospital_options(pipeline, location):
    """Return a stable list of generated hospital names for *location*.

    ``generate_hospital_name`` is called several times to build distinct
    options, which presumably means it is non-deterministic. Streamlit re-runs
    the whole script on every interaction (including the "Calculate Pay Rate"
    click), so regenerating the list each run would change the selectbox
    options and silently replace the user's pick. The list is therefore
    generated once per location and kept in ``st.session_state``.
    """
    cache_key = "hospital_options_by_location"
    if cache_key not in st.session_state:
        st.session_state[cache_key] = {}
    cache = st.session_state[cache_key]
    if location not in cache:
        cache[location] = [
            pipeline.generate_hospital_name(location)
            for _ in range(_HOSPITAL_CHOICES)
        ]
    return cache[location]


def _read_dataset():
    """Read the synthetic pay dataset, or return None if it cannot be read."""
    try:
        return pd.read_csv(_DATASET_PATH)
    except (OSError, UnicodeDecodeError,
            pd.errors.EmptyDataError, pd.errors.ParserError):
        # Missing/unreadable file: each tab reports this with its own message.
        return None


def _render_prediction_tab(pipeline):
    """Render the contract inputs and, on request, the pay-rate breakdown."""
    st.header("Pay Rate Prediction")
    inputs_left, inputs_right = st.columns(2)

    with inputs_left:
        job_title = st.selectbox(
            "Select Job Title",
            options=pipeline.job_titles,
            help="Choose the healthcare position",
        )
        location = st.selectbox(
            "Select Location",
            options=pipeline.locations,
            help="Choose the work location",
        )
        hospital_name = st.selectbox(
            "Select Hospital",
            options=_hospital_options(pipeline, location),
            help="Select the healthcare facility",
        )

    with inputs_right:
        # NOTE(review): max_value caps the start date at today, so a future
        # contract cannot be chosen; confirm this is intended (min_value?).
        start_date = st.date_input(
            "Contract Start Date",
            max_value=datetime.now(),
            help="When does the contract begin?",
        )
        duration_weeks = st.slider(
            "Contract Duration (weeks)",
            min_value=1,
            max_value=13,
            help="How long is the contract?",
        )
        end_date = start_date + timedelta(weeks=duration_weeks)
        st.write(f"Contract End Date: {end_date}")

    if not st.button("Calculate Pay Rate", type="primary"):
        return

    pred_data = pd.DataFrame({
        'Job_Title': [job_title],
        'Location': [location],
        'Hospital_Name': [hospital_name],
        'Contract_Start': [start_date],
        'Contract_End': [end_date],
        'Season': ['normal'],  # Will be updated in preprocessing
        'Hourly_Pay': [0],  # Placeholder
    })
    # NOTE(review): the preprocessed frame is not used below and no model is
    # invoked; the figures shown come only from base_rates, the seasonal
    # multiplier and desirability_scores. Kept so preprocessing still runs.
    pred_data = pipeline.preprocess_data(pred_data)

    with st.spinner("Calculating predictions..."):
        season = _season_for_month(start_date.month)
        base_rate = pipeline.base_rates[job_title]
        seasonal_adjustment = _SEASONAL_MULTIPLIERS[season]
        adjusted_rate = base_rate * seasonal_adjustment

        # desirability_scores is keyed by the part of the location before the
        # first comma; show "N/A" rather than crash when it has no entry.
        city = location.split(",")[0]
        try:
            location_score = f"{pipeline.desirability_scores[city]}/100"
        except KeyError:
            location_score = "N/A"

        st.success("Prediction Complete!")
        base_col, seasonal_col, location_col = st.columns(3)
        with base_col:
            st.metric(
                "Base Rate",
                f"${base_rate:.2f}/hr",
                help="Standard base rate for this position",
            )
        with seasonal_col:
            st.metric(
                "Seasonal Adjusted Rate",
                f"${adjusted_rate:.2f}/hr",
                f"{((seasonal_adjustment - 1) * 100):.0f}% seasonal adjustment",
                help=f"Rate adjusted for {season} season",
            )
        with location_col:
            st.metric(
                "Location Score",
                location_score,
                help="Location desirability rating",
            )


def _render_market_tab(pipeline, dataset):
    """Render pay statistics by location and mean pay by season."""
    st.header("Market Analysis")
    if dataset is None:
        st.error("Error loading market analysis data. Please ensure the dataset exists.")
        return
    try:
        # Copy: whether preprocess_data mutates its input is not visible here,
        # and the same frame is reused unprocessed by the history tab.
        data = pipeline.preprocess_data(dataset.copy())
        avg_pay_by_location = (
            data.groupby('Location')['Hourly_Pay']
            .agg(['mean', 'min', 'max'])
            .round(2)
        )
        seasonal_avg = data.groupby('Season')['Hourly_Pay'].mean().round(2)
    except Exception as exc:  # UI boundary: report the real cause, keep the page alive.
        st.error(f"Error building market analysis: {exc}")
        return

    st.subheader("Pay Rates by Location")
    st.dataframe(avg_pay_by_location, use_container_width=True)
    st.subheader("Seasonal Trends")
    st.bar_chart(seasonal_avg)


def _render_history_tab(pipeline, dataset):
    """Render mean hourly pay over contract start date for chosen filters."""
    st.header("Historical Trends")
    st.write("View historical pay rate trends for different positions and locations.")
    selected_location = st.multiselect(
        "Select Locations",
        options=pipeline.locations,
        default=[pipeline.locations[0]],
    )
    selected_job = st.multiselect(
        "Select Job Titles",
        options=pipeline.job_titles,
        default=[pipeline.job_titles[0]],
    )
    if dataset is None:
        st.error("Error loading historical data. Please ensure the dataset exists.")
        return
    try:
        filtered_data = dataset[
            dataset['Location'].isin(selected_location)
            & dataset['Job_Title'].isin(selected_job)
        ]
        if filtered_data.empty:
            st.warning("No data available for the selected filters.")
            return
        # read_csv leaves Contract_Start as text; parse it so the x axis is
        # chronological rather than string-ordered. Unparseable values become
        # NaT, and groupby drops NaT keys.
        start_dates = pd.to_datetime(filtered_data['Contract_Start'], errors="coerce")
        trend = filtered_data.groupby(start_dates)['Hourly_Pay'].mean()
    except Exception as exc:  # UI boundary: report the real cause, keep the page alive.
        st.error(f"Error building historical trends: {exc}")
        return

    if trend.empty:
        st.warning("No valid contract start dates for the selected filters.")
        return
    st.line_chart(trend, use_container_width=True)


def main():
    """Streamlit entry point: page setup, model loading, and the three tabs."""
    st.set_page_config(page_title="Nurse Pay Rate Predictor", layout="wide")
    st.title("🏥 Nurse Pay Rate Prediction System")
    st.write("This application predicts hourly pay rates for healthcare professionals based on various factors.")

    pipeline = load_pipeline()
    if pipeline is None:
        return

    prediction_tab, market_tab, history_tab = st.tabs(
        ["Pay Prediction", "Market Analysis", "Historical Trends"]
    )
    # Every tab body executes on each rerun, so read the CSV once and share it.
    dataset = _read_dataset()

    with prediction_tab:
        _render_prediction_tab(pipeline)
    with market_tab:
        _render_market_tab(pipeline, dataset)
    with history_tab:
        _render_history_tab(pipeline, dataset)


if __name__ == "__main__":
    main()