-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
152 lines (125 loc) · 4.72 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import base64
import mimetypes
import os
import random
import time

import cv2
import matplotlib.pyplot as plt
import streamlit as st
import torch
import torch.nn as nn
from PIL import Image, ImageFont, ImageDraw
from torchvision import transforms, models
def autoplay_audio(file_path: str, duration: float = 1):
    """
    Play an audio file once in the browser, then remove the player.

    The file is embedded as a base64 ``data:`` URI inside an autoplaying HTML
    ``<audio>`` element rendered into a temporary ``st.empty()`` placeholder.
    After ``duration`` seconds the placeholder is cleared, which removes the
    player from the page (so sounds longer than ``duration`` are likely cut
    off).

    The MIME type is derived from the file extension, so ``.wav`` files are
    no longer mislabelled as ``audio/mp3``.

    :param file_path: path of the audio file to play
    :param duration: seconds to keep the player on the page (default 1)
    """
    # code from https://github.com/streamlit/streamlit/issues/2446
    # https://discuss.streamlit.io/t/remove-a-markdown/4053/2
    # Fall back to MP3 when the extension is not recognised.
    mime_type = mimetypes.guess_type(file_path)[0] or "audio/mpeg"
    with open(file_path, "rb") as f:
        b64 = base64.b64encode(f.read()).decode()
    md = (
        '<audio controls autoplay="true">\n'
        f'<source src="data:{mime_type};base64,{b64}" type="{mime_type}">\n'
        "</audio>"
    )
    placeholder = st.empty()
    placeholder.markdown(
        md,
        unsafe_allow_html=True,
    )
    time.sleep(duration)
    placeholder.empty()
def autoplay_audio_loop(file_path: str):
    """
    Embed a looping, autoplaying audio player for a file on the page.

    The file is embedded as a base64 ``data:`` URI inside an HTML
    ``<audio ... loop>`` element rendered with ``st.markdown``. Unlike
    ``autoplay_audio``, the player is never removed.

    The MIME type is derived from the file extension instead of the
    non-standard hard-coded ``audio/mp3``.

    :param file_path: path of the audio file to loop
    """
    # code from https://github.com/streamlit/streamlit/issues/2446
    # https://discuss.streamlit.io/t/remove-a-markdown/4053/2
    # Fall back to MP3 when the extension is not recognised.
    mime_type = mimetypes.guess_type(file_path)[0] or "audio/mpeg"
    with open(file_path, "rb") as f:
        b64 = base64.b64encode(f.read()).decode()
    md = (
        '<audio controls autoplay="true" loop="true">\n'
        f'<source src="data:{mime_type};base64,{b64}" type="{mime_type}">\n'
        "</audio>"
    )
    st.markdown(
        md,
        unsafe_allow_html=True,
    )
# Dataset folders and model weights, resolved against the current working
# directory (the same base the relative 'Assets/...' paths use).
KUN_DIR = os.path.join(os.getcwd(), 'Dataset', 'kun')
CHICKEN_DIR = os.path.join(os.getcwd(), 'Dataset', 'zhiyin')
# NOTE: despite the name, this is the path of the trained weights *file*.
WEIGHT_DIR = os.path.join(os.getcwd(), "kun_weight.pt")
# Kun Classifier
# Run on the GPU when one is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
class Kun_Classifier:
    """
    Binary Classification Class: "kun" vs. "chicken" (只因) images.

    Wraps a ResNet-18 whose final layer is replaced by a single-logit
    ``nn.Linear`` head. The trained weights are read from WEIGHT_DIR on the
    first call to ``inference`` and the model is reused for later calls on the
    same instance, instead of being rebuilt and reloaded every time.
    """

    def __init__(self):
        """
        Init: build the preprocessing pipeline; model loading is deferred.
        """
        # Resize to 224x224, convert to a tensor and normalize with the
        # ImageNet channel mean/std.
        self._preprocess = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])
        self._model = None  # built on first use by _get_model()

    def _get_model(self):
        """Return the eval-mode model, loading it from WEIGHT_DIR on first use."""
        if self._model is None:
            # No pretrained ImageNet weights are requested: load_state_dict
            # (strict by default) overwrites every parameter and buffer anyway.
            model = models.resnet18()
            model.fc = nn.Linear(model.fc.in_features, 1)
            model = model.to(device)
            # NOTE: torch.load unpickles the file; only ship trusted weights.
            model.load_state_dict(
                torch.load(WEIGHT_DIR, map_location=torch.device(device))
            )
            model.eval()
            self._model = model
        return self._model

    def _chicken_probability(self, img_path):
        """Return sigmoid(logit) for the image; >= 0.5 is treated as "chicken"."""
        # Context manager closes the file PIL opened; convert() loads the
        # pixel data before the file is released.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')
        # (3, 224, 224) -> (1, 3, 224, 224) batch of one.
        img_tensor = self._preprocess(img).unsqueeze(0).to(device)
        with torch.no_grad():  # inference only: skip autograd bookkeeping
            logit = self._get_model()(img_tensor)
        return torch.sigmoid(logit).item()

    def inference(self, img_path):
        """
        Classify one image and play the matching sound effect.

        :params: img_path: path of the image for inference, or a file-like
            object accepted by PIL.Image.open (e.g. the st.camera_input value)
        :returns: the inference result in terms of strings, such as
            "KUN: <pct>% | 含坤量为:<pct>%" or "CHICKEN: <pct>% | 含只因量为:<pct>%"
        """
        kun = "KUN: {a}% | 含坤量为:{a}%"
        chicken = "CHICKEN: {a}% | 含只因量为:{a}%"
        chicken_prob = self._chicken_probability(img_path)
        if chicken_prob < 0.5:
            autoplay_audio('Assets/biebie.wav')
            return kun.format(a=round(100 * (1 - chicken_prob), 2))
        autoplay_audio('Assets/zhiyin.wav')
        return chicken.format(a=round(100 * chicken_prob, 2))
# One classifier instance, used by the tabs further down.
kuner = Kun_Classifier()
# Welcome Page
st.title("KUN-er Classifier")
autoplay_audio_loop('Assets/ji.mp3')
st.caption('Welcome! | 欢迎各位小黑子们前来体验二元坤类器! | https://github.com/zslrmhb/Kun_Classifier')
_bgm_requested = st.button("Don't Click! | 小黑子勿按!")
if _bgm_requested:
    # Start the looping background track on demand.
    autoplay_audio_loop("Assets/background.mp3")
_TAB_LABELS = [
    "Data Visualization | 让我康康",
    "Classification | 二元坤类器",
    "Let me try try | 让我试试",
]
tab1, tab2, tab3 = st.tabs(_TAB_LABELS)
# Data Visualization
with tab1:
    st.subheader('Data Visualization | 让我康康')

    def show_and_classify(user_choice, classify=False):
        """
        Display a random dataset image and optionally caption its prediction.

        :param user_choice: slider label; 'kun | 坤' or 'chicken | 只因' pick that
            folder, any other value (e.g. 'Random | 随便') picks one of the two
            folders with equal probability
        :param classify: when True, also show kuner.inference() for the image
        """
        if user_choice == 'kun | 坤':
            image_dir = KUN_DIR
        elif user_choice == 'chicken | 只因':
            image_dir = CHICKEN_DIR
        else:
            image_dir = random.choice([KUN_DIR, CHICKEN_DIR])
        file_names = os.listdir(image_dir)
        if not file_names:
            st.warning(f"No images found in {image_dir}")
            return
        result = os.path.join(image_dir, random.choice(file_names))
        img = cv2.imread(result)
        if img is None:
            # cv2.imread returns None (instead of raising) for unreadable or
            # non-image files, which would otherwise crash cvtColor.
            st.error(f"Could not read image: {result}")
            return
        # OpenCV loads BGR; st.image expects RGB.
        st.image(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        if classify:
            st.caption(kuner.inference(result))

    desired_image = st.select_slider('', options=['kun | 坤', 'chicken | 只因'])
    if st.button("Click for image 👆 | 点我看图 👆"):
        show_and_classify(desired_image)
with tab2:
    # Classification: choose kun / random / chicken, then show a dataset
    # image together with the classifier's verdict.
    st.subheader('Classification | 二元坤类器')
    _classify_options = ['kun | 坤', 'Random | 随便', 'chicken | 只因']
    desired_classify = st.select_slider('', options=_classify_options)
    _classify_clicked = st.button("Click for classification 👆 | 点我看玄只因 👆")
    if _classify_clicked:
        show_and_classify(desired_classify, classify=True)
with tab3:
    # User Input: classify a photo taken with the browser camera.
    st.subheader('Let me try try | 让我试试')

    def classify_user_input(picture):
        """
        Caption the classifier's verdict for a camera snapshot.

        :param picture: the value returned by st.camera_input; nothing is
            shown when it is falsy (no photo taken yet)
        """
        if not picture:
            return
        verdict = kuner.inference(picture)
        st.caption(verdict)

    picture = st.camera_input("")
    classify_user_input(picture)