Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🐛 修复测试错误 #5

Merged
merged 2 commits into from
Jun 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions examples/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,20 @@ async def handle_message_group(event: MessageGroupEvent):
message=MessageChain([
Text('你好')
]).to_dict()
),
self=None
)
), SendMessageResponse)


@bot.on(MessagePrivateEvent)
async def handle_message_private(event: MessagePrivateEvent):
print(event.message)
await bot.call(SendMessageRequest(
params=SendMessageRequestParams(
detail_type="private",
user_id=event.user_id,
message=MessageChain([
Text('你好')
]).to_dict()
)
), SendMessageResponse)

bot.run()
14 changes: 9 additions & 5 deletions mirai_onebot/adapters/reverse_websocket_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,10 @@ def stop(self):
if self.server is None:
return

tasks: List[asyncio.Task] = []
for ws in self.ws_connections:
asyncio.get_event_loop().run_until_complete(ws.close())
tasks.append(asyncio.create_task(ws.close()))
asyncio.get_event_loop().run_until_complete(asyncio.wait(tasks))

async def handler(self, websocket: websockets.WebSocketServerProtocol, path: str):
# 检测OneBot标准版本
Expand All @@ -53,16 +55,18 @@ async def handler(self, websocket: websockets.WebSocketServerProtocol, path: str

if protocol_str is None:
logger.warning('未提供Sec-WebSocket-Protocol,可能出现兼容性问题。')
logger.warning('自动将 protocol 指定为 12.undefined。')
protocol_str = '12.undefined'
logger.warning('自动将 protocol 指定为 11.undefined。')
protocol_str = '11.undefined'

protocol = protocol_str.split('.')

if int(protocol[0]) > 12:
logger.warning('不支持版本12以上的OneBot实现,可能出现兼容性问题。')
elif int(protocol[0]) < 12:
elif int(protocol[0]) < 11:
logger.warning(
f'不支持版本12以下的OneBot实现,可能出现兼容性问题,请在调用api前查询对应OneBot {protocol[0]} 的接口定义。')
f'不支持版本11以下的OneBot实现,可能出现兼容性问题,请在调用api前查询对应OneBot {protocol[0]} 的接口定义。')
elif int(protocol[0]) == 11:
logger.warning('OneBot 11 的实现暂处于开发阶段,可能存在兼容性问题。')

# 检测Access Token
query = parse_qs(urlparse(websocket.path).query)
Expand Down
2 changes: 1 addition & 1 deletion mirai_onebot/api/api_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,5 @@ async def call(self, request: Request, response_type: Type[Response]):
Args:
api (Request): API接口
"""
resp = await self._call_api(request.action, request.params.model_dump(mode='json'), request.echo)
resp = await self._call_api(request.action, request.params.model_dump(mode='json'), request.echo if request.echo is not None else secrets.token_hex(8))
return response_type.model_validate(resp)
2 changes: 1 addition & 1 deletion mirai_onebot/api/interfaces/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class Request(BaseModel):
action: str
params: RequestParams
echo: Optional[str] = Field(default_factory=lambda: secrets.token_hex(8))
self: Optional[BotSelf]
self: Optional[BotSelf] = None


class Response(BaseModel):
Expand Down
10 changes: 6 additions & 4 deletions mirai_onebot/event/bus.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,16 @@ def decorator(func: Callable) -> Callable:

return decorator

async def emit(self, event: Union[Type[EventBase], str], *args, **kwargs) -> None:
async def emit(self, event: Union[Type[EventBase], str], background: bool = True, *args, **kwargs) -> None:
"""触发事件

Args:
event (str | Type[EventBase]): 事件
background (bool, optional): 是否在后台触发事件,设置为False会等待事件完成. Defaults to True.
args/kwargs: 传递给事件处理器的参数
"""
if event in self._subscribers.keys():
[asyncio.create_task(subscriber(*args, **kwargs))
for subscriber in self._subscribers[event]]
# await asyncio.wait(tasks)
tasks = [asyncio.create_task(subscriber(*args, **kwargs))
for subscriber in self._subscribers[event]]
if not background:
await asyncio.wait(tasks)
6 changes: 3 additions & 3 deletions test/test_event_bus.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ async def handle_message_group_event():

bus.subscribe('print_message', handle_print_message)

await bus.emit(MessageGroupEvent)
await bus.emit('print_message', 'hello')
await bus.emit(MessageGroupEvent, background=False)
await bus.emit(event='print_message', background=False, message='hello')

assert run1 is True
assert run2 is True
Expand Down Expand Up @@ -60,7 +60,7 @@ async def test31():
global run2
run2 = True

await bus.emit('test3')
await bus.emit('test3', background=False)

assert run1 is True
assert run2 is True
Expand Down
14 changes: 7 additions & 7 deletions test/test_rwebsocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,21 +70,21 @@ async def subscribe(data: dict):
assert data['id'] == id

# 错误事件
await ws_client.send('hello?')
await ws_client.send('{"good": "nice"}')

# 调用api
asyncio.create_task(adapter.call_api('hello', test='test'))
asyncio.create_task(adapter._call_api('hello', {'test': 'test'}))

# 发送响应
echo = json.loads(await ws_client.recv())['echo']
await ws_client.send(json.dumps({'echo': echo, 'resp': 'test'}))

# 调用api超时
await adapter.call_api('hello', test='test')
await adapter._call_api('hello', {'test': 'test'})
await asyncio.sleep(1.5)

# 关闭连接
await ws_client.close()

# 再次调用handler函数,引发错误
await adapter.handler(adapter.ws_connections[0], '/')
def test_new_adapter():
adapter = ReverseWebsocketAdapter('hello', '0.0.0.0', 4561, 1)
adapter.start()
adapter.stop()
Loading