Skip to content

Commit

Permalink
Merge branch 'dev'
Browse files Browse the repository at this point in the history
  • Loading branch information
Wybxc committed Aug 14, 2021
2 parents cfd147e + 3b37011 commit fff9026
Show file tree
Hide file tree
Showing 15 changed files with 1,276 additions and 241 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ dmypy.json


# test cases
/test.py
/*.py

# pdoc documents
/docs
Expand Down
12 changes: 6 additions & 6 deletions mirai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
更多信息请看[文档](https://yiri-mirai.vercel.app/)。
"""
__version__ = '0.2.2'
__version__ = '0.2.3'
__author__ = '忘忧北萱草'

import logging
Expand All @@ -24,18 +24,18 @@
)
from mirai.models import (
At, AtAll, Dice, Event, Face, FriendMessage, GroupMessage, Image,
MessageChain, MessageEvent, Plain, Poke, StrangerMessage, TempMessage,
Voice, deserialize, serialize
MessageChain, MessageEvent, Plain, Poke, PokeNames, StrangerMessage,
TempMessage, Voice, deserialize, serialize
)

__all__ = [
'Mirai', 'SimpleMirai', 'MiraiRunner', 'LifeSpan', 'Startup', 'Shutdown',
'Adapter', 'Method', 'HTTPAdapter', 'WebSocketAdapter', 'WebHookAdapter',
'ComposeAdapter', 'EventBus', 'get_logger', 'Event', 'MessageEvent',
'FriendMessage', 'GroupMessage', 'TempMessage', 'StrangerMessage',
'MessageChain', 'Plain', 'At', 'AtAll', 'Dice', 'Face', 'Poke', 'Image',
'Voice', 'serialize', 'deserialize', 'ApiError', 'NetworkError',
'SkipExecution', 'StopExecution', 'StopPropagation'
'MessageChain', 'Plain', 'At', 'AtAll', 'Dice', 'Face', 'Poke',
'PokeNames', 'Image', 'Voice', 'serialize', 'deserialize', 'ApiError',
'NetworkError', 'SkipExecution', 'StopExecution', 'StopPropagation'
]

logger = logging.getLogger(__name__)
Expand Down
11 changes: 6 additions & 5 deletions mirai/adapters/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from mirai import exceptions
from mirai.api_provider import ApiProvider, Method
from mirai.bus import AbstractEventBus
from mirai.tasks import Tasks

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -112,14 +113,14 @@ def via(cls, adapter_interface: AdapterInterface) -> "Adapter":
def register_event_bus(self, *buses: AbstractEventBus):
"""注册事件总线。
`*buses: List[AbstractEventBus]` 一个或多个事件总线。
`*buses: AbstractEventBus` 一个或多个事件总线。
"""
self.buses |= set(buses)

def unregister_event_bus(self, *buses: AbstractEventBus):
"""解除注册事件总线。
`*buses: List[AbstractEventBus]` 一个或多个事件总线。
`*buses: AbstractEventBus` 一个或多个事件总线。
"""
self.buses -= set(buses)

Expand All @@ -146,7 +147,7 @@ async def call_api(self, api: str, method: Method = Method.GET, **params):
async def _background(self):
"""背景事件循环,用于接收事件。"""

def start(self):
async def start(self):
"""运行背景事件循环。"""
if not self.buses:
raise RuntimeError('事件总线未指定!')
Expand All @@ -155,10 +156,10 @@ def start(self):

self.background = asyncio.create_task(self._background())

def shutdown(self):
async def shutdown(self):
"""停止背景事件循环。"""
if self.background:
self.background.cancel()
await Tasks.cancel(self.background)

async def emit(self, event: str, *args, **kwargs):
"""向事件总线发送一个事件。"""
Expand Down
2 changes: 1 addition & 1 deletion mirai/adapters/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,4 +249,4 @@ async def _background(self):
self._tasks.create_task(self.poll_event())
await asyncio.sleep(self.poll_interval)
finally:
self._tasks.cancel_all()
await self._tasks.cancel_all()
9 changes: 4 additions & 5 deletions mirai/adapters/webhook.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
此模块提供 HTTP 回调适配器,适用于 mirai-api-http 的 webhook adapter。
"""
import logging
from typing import Optional, cast
from typing import Mapping, Optional, cast

from starlette.requests import Request
from starlette.responses import JSONResponse
Expand All @@ -28,15 +28,15 @@ class WebHookAdapter(Adapter):
"""WebHook 不需要 session,此处为机器人的 QQ 号。"""
route: str
"""适配器的路由。"""
extra_headers: dict
extra_headers: Mapping[str, str]
"""额外请求头。"""
enable_quick_response: bool
"""是否启用快速响应。"""
def __init__(
self,
verify_key: Optional[str],
route: str = '/',
extra_headers: Optional[dict] = None,
extra_headers: Optional[Mapping[str, str]] = None,
enable_quick_response: bool = True,
single_mode: bool = False
):
Expand All @@ -45,7 +45,7 @@ def __init__(
`route: str = '/'` 适配器的路由,默认在根目录上提供服务。
`extra_headers: Optional[dict] = None` 额外请求头,与 mirai-api-http 的配置一致。
`extra_headers: Optional[Mapping[str, str]] = None` 额外请求头,与 mirai-api-http 的配置一致。
`enable_quick_response: bool = True` 是否启用快速响应,当与其他适配器混合使用时,
禁用可以提高响应速度。
Expand All @@ -59,7 +59,6 @@ def __init__(

async def endpoint(request: Request):
# 鉴权(QQ 号和额外请求头)
print(request.headers)
if request.headers.get('bot') != self.session: # 验证 QQ 号
logger.debug(f"收到来自其他账号({request.headers.get('bot')})的事件。")
return
Expand Down
37 changes: 30 additions & 7 deletions mirai/adapters/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
import json
import logging
import random
import time
from collections import defaultdict, deque
from itertools import repeat
from typing import Dict, Optional, cast

from websockets.client import WebSocketClientProtocol, connect
Expand Down Expand Up @@ -40,13 +42,16 @@ class WebSocketAdapter(Adapter):
"""机器人的 QQ 号。"""
connection: WebSocketClientProtocol
"""WebSocket 客户端连接。"""
heartbeat_interval: float
"""每隔多久发送心跳包,单位:秒。"""
def __init__(
self,
verify_key: Optional[str],
host: str,
port: int,
sync_id: str = '-1',
single_mode: bool = False
single_mode: bool = False,
heartbeat_interval: float = 60.,
):
"""
`verify_key: str` mirai-api-http 配置的认证 key,关闭认证时为 None。
Expand All @@ -58,6 +63,8 @@ def __init__(
`sync_id: int` mirai-api-http 配置的同步 ID。
`single_mode: bool = False` 是否启用单例模式。
`heartbeat_interval: float = 60.` 每隔多久发送心跳包,单位:秒。
"""
super().__init__(verify_key=verify_key, single_mode=single_mode)

Expand All @@ -79,6 +86,8 @@ def __init__(
self.sync_id = sync_id # 这个神奇的 sync_id,默认值 -1,居然是个字符串
# 既然这样不如把 sync_id 全改成字符串好了

self.heartbeat_interval = heartbeat_interval

# 接收 WebSocket 数据的 Task
self._receiver_task: Optional[asyncio.Task] = None
# 用于临时保存接收到的数据,以便根据 sync_id 进行同步识别
Expand All @@ -87,6 +96,8 @@ def __init__(
self._local_sync_id = random.randint(1, 1024) * 1024
# 事件处理任务管理器
self._tasks = Tasks()
# 心跳机制(Keep-Alive):上次发送数据包的时间
self._last_send_time: float = 0.

@property
def adapter_info(self):
Expand Down Expand Up @@ -116,11 +127,11 @@ def via(cls, adapter_interface: AdapterInterface) -> "WebSocketAdapter":
@_error_handler_async_local
async def _receiver(self):
"""开始接收 websocket 数据。"""
if not self.connect:
if not self.connection:
raise exceptions.NetworkError(
f'WebSocket 通道 {self.host_name} 未连接!'
)
while self._started:
while True:
try:
# 数据格式:
# {
Expand Down Expand Up @@ -149,9 +160,10 @@ async def _receiver(self):
)
return

async def _recv(self, sync_id: str = '-1') -> dict:
async def _recv(self, sync_id: str = '-1', timeout: int = 600) -> dict:
"""接收并解析 websocket 数据。"""
for _ in range(600):
timer = range(timeout) if timeout > 0 else repeat(0)
for _ in timer:
if self._recv_dict[sync_id]:
return self._recv_dict[sync_id].popleft()
else:
Expand Down Expand Up @@ -195,7 +207,7 @@ async def logout(self, terminate: bool = True):

async def poll_event(self):
"""获取并处理事件。"""
event = await self._recv(self.sync_id)
event = await self._recv(self.sync_id, -1)

self._tasks.create_task(self.emit(event['type'], event))

Expand All @@ -217,18 +229,29 @@ async def call_api(self, api: str, method: Method = Method.GET, **params):
)

await self.connection.send(json_dumps(content))
self._last_send_time = time.time()
logger.debug(f"[WebSocket] 发送 WebSocket 数据,同步 ID:{sync_id}。")
try:
return await self._recv(sync_id)
except TimeoutError as e:
logger.debug(e)

async def _heartbeat(self):
while True:
await asyncio.sleep(self.heartbeat_interval)
if time.time() - self._last_send_time > self.heartbeat_interval:
await self.call_api('about')
self._last_send_time = time.time()
logger.debug("[WebSocket] 发送心跳包。")

async def _background(self):
"""开始接收事件。"""
logger.info('[WebSocket] 机器人开始运行。按 Ctrl + C 停止。')

try:
heartbeat = asyncio.create_task(self._heartbeat())
while True:
await self.poll_event()
finally:
self._tasks.cancel_all()
await Tasks.cancel(heartbeat)
await self._tasks.cancel_all()
9 changes: 5 additions & 4 deletions mirai/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ async def startup(self):
self.qq = (await self.call_api('sessionInfo'))['data']['qq']['id']

asyncio.create_task(self._adapter.emit("Startup", {'type': 'Startup'}))
self._adapter.start()
await self._adapter.start()

async def background(self):
"""等待背景任务完成。"""
Expand All @@ -151,7 +151,7 @@ async def shutdown(self):
self._adapter.emit("Shutdown", {'type': 'Shutdown'})
)
await self._adapter.logout()
self._adapter.shutdown()
await self._adapter.shutdown()

@property
def session(self) -> str:
Expand Down Expand Up @@ -331,14 +331,15 @@ def __getattr__(self, api: str) -> ApiModel.Proxy:
async def send(
self,
target: Union[Entity, MessageEvent],
message: Union[MessageChain, List[Union[MessageComponent, str]], MessageComponent, str],
message: Union[MessageChain, Iterable[Union[MessageComponent, str]],
MessageComponent, str],
quote: bool = False
) -> int:
"""发送消息。可以从 `Friend` `Group` 等对象,或者从 `MessageEvent` 中自动识别消息发送对象。
`target: Union[Entity, MessageEvent]` 目标对象。
`message: Union[MessageChain, List[Union[MessageComponent, str]], str]` 发送的消息。
`message: Union[MessageChain, Iterable[Union[MessageComponent, str]], str]` 发送的消息。
`quote: bool = False` 是否以回复消息的形式发送。
Expand Down
Loading

0 comments on commit fff9026

Please sign in to comment.