first commit
This commit is contained in:
commit
8498111505
7
requirements.txt
Normal file
7
requirements.txt
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
aiortc
|
||||||
|
websockets
|
||||||
|
webrtcvad
|
||||||
|
faster-whisper
|
||||||
|
git+https://github.com/suno-ai/bark
|
||||||
|
nvidia-cublas-cu11
|
||||||
|
nvidia-cudnn-cu11
|
||||||
80
rtcllm/a2a.py
Normal file
80
rtcllm/a2a.py
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
import os
|
||||||
|
from traceback import print_exc
|
||||||
|
import io
|
||||||
|
import asyncio
|
||||||
|
from av import AudioFrame
|
||||||
|
|
||||||
|
from aiortc import VideoStreamTrack, AudioStreamTrack
|
||||||
|
|
||||||
|
# 计算每个样本的字节数,对于's16'是2字节,'s32'是4字节等
|
||||||
|
def bytes_to_audio_frame(bytes_data, format='s16', layout='mono', sample_rate=16000):
|
||||||
|
"""
|
||||||
|
从字节数据构造av.AudioFrame对象。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- bytes_data: 字节数据,代表音频内容。
|
||||||
|
- format: 音频样本格式字符串,如's16'表示16位有符号整数。
|
||||||
|
- layout: 通道布局,如'mono'、'stereo'等。
|
||||||
|
- sample_rate: 采样率。
|
||||||
|
"""
|
||||||
|
# 根据给定的参数创建AudioFrame
|
||||||
|
frame = AudioFrame(format=format,
|
||||||
|
layout=layout,
|
||||||
|
samples=len(bytes_data) // 2)
|
||||||
|
|
||||||
|
# 将字节数据复制到AudioFrame中
|
||||||
|
frame.planes[0].update(bytes_data)
|
||||||
|
return frame
|
||||||
|
|
||||||
|
class LLMAudioStreamTrack(AudioStreamTrack):
|
||||||
|
def __init__(self, omni_infer):
|
||||||
|
super().__init__()
|
||||||
|
self.oi = omni_infer
|
||||||
|
self.audio_iters = []
|
||||||
|
self.cur_iters = None
|
||||||
|
self.tmp_files = []
|
||||||
|
|
||||||
|
async def recv(self):
|
||||||
|
try:
|
||||||
|
b = self.get_audio_bytes()
|
||||||
|
if b is None:
|
||||||
|
return await super().recv()
|
||||||
|
frame = bytes_to_audio_frame(b)
|
||||||
|
print('LLMAudioStreamTrack return frame ...')
|
||||||
|
return frame
|
||||||
|
except Exception as e:
|
||||||
|
print_exc()
|
||||||
|
print(f'{self.__class__.__name__} recv() exception happened')
|
||||||
|
return None
|
||||||
|
|
||||||
|
def set_cur_audio_iter(self):
|
||||||
|
if len(self.audio_iters) == 0:
|
||||||
|
return False
|
||||||
|
self.cur_iters = self.audio_iters[0]
|
||||||
|
self.audio_iters.remove(self.cur_iters)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_audio_bytes(self):
|
||||||
|
if self.cur_iters is None:
|
||||||
|
if not self.set_cur_audio_iter():
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
b = next(self.cur_iters)
|
||||||
|
return b
|
||||||
|
except StopIteration:
|
||||||
|
self.cur_iters = None
|
||||||
|
if len(self.tmp_files) > 0:
|
||||||
|
tf = self.tmp_files[0]
|
||||||
|
self.tmp_files.remove(tf)
|
||||||
|
os.remove(tf)
|
||||||
|
return self.get_audio_bytes()
|
||||||
|
|
||||||
|
def _feed(self, audio_file):
|
||||||
|
self.tmp_files.append(audio_file)
|
||||||
|
if audio_file is None:
|
||||||
|
print(f'*****{self.__class__.__name__}._feed(),{audio_file=}')
|
||||||
|
return
|
||||||
|
abiters = self.oi.run_AT_batch_stream(audio_file)
|
||||||
|
self.audio_iters.append(abiters)
|
||||||
|
return abiters
|
||||||
|
|
||||||
154
rtcllm/aav.py
Normal file
154
rtcllm/aav.py
Normal file
@ -0,0 +1,154 @@
|
|||||||
|
import random
|
||||||
|
from traceback import print_exc
|
||||||
|
from aiortc.contrib.media import MediaBlackhole, MediaPlayer, MediaRecorder, MediaRelay
|
||||||
|
from aiortc import MediaStreamTrack, VideoStreamTrack, AudioStreamTrack
|
||||||
|
from aiortc.mediastreams import MediaStreamError
|
||||||
|
|
||||||
|
class MyMediaPlayer(MediaPlayer):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class MyTrackBase(MediaStreamTrack):
|
||||||
|
def __init__(self, source):
|
||||||
|
super().__init__()
|
||||||
|
self.dumb = None
|
||||||
|
self.source = source
|
||||||
|
self.set_source_track()
|
||||||
|
print(f'{self.kind=}, {self.__class__.__name__}, {dir(self)}')
|
||||||
|
|
||||||
|
def set_source_track(self):
|
||||||
|
if self.kind == 'audio':
|
||||||
|
self.track = self.source.audio
|
||||||
|
else:
|
||||||
|
self.track = self.source.video
|
||||||
|
|
||||||
|
def source_reload(self):
|
||||||
|
print(f'reload the source----{self.source._file_path}')
|
||||||
|
self.set_source(MyMediaPlayer(self.source._file_path))
|
||||||
|
|
||||||
|
def set_source(self, source):
|
||||||
|
self.source = source
|
||||||
|
self.set_source_track()
|
||||||
|
|
||||||
|
async def recv(self):
|
||||||
|
print(f'============{self.__class__.__name__}, {self.source.duration=}, {self.source.time=}')
|
||||||
|
if self.source is None:
|
||||||
|
return await self.dumb.recv()
|
||||||
|
try:
|
||||||
|
f = await self.track.recv()
|
||||||
|
return f
|
||||||
|
except MediaStreamError:
|
||||||
|
return await self.dumb.recv()
|
||||||
|
except Exception as e:
|
||||||
|
print_exc()
|
||||||
|
print(f'{e}')
|
||||||
|
return await self.dumb.recv()
|
||||||
|
return f
|
||||||
|
|
||||||
|
class MyAudioStreamTrack(MyTrackBase):
|
||||||
|
kind = 'audio'
|
||||||
|
def __init__(self, source):
|
||||||
|
super().__init__(source)
|
||||||
|
self.dumb = AudioStreamTrack()
|
||||||
|
|
||||||
|
class MyVideoStreamTrack(MyTrackBase):
|
||||||
|
kind = 'video'
|
||||||
|
def __init__(self, source):
|
||||||
|
super().__init__(source)
|
||||||
|
self.dumb = AudioStreamTrack()
|
||||||
|
|
||||||
|
"""
|
||||||
|
class MyAudioStreamTrack(AudioStreamTrack):
|
||||||
|
def __init__(self, source=None):
|
||||||
|
super().__init__()
|
||||||
|
self.source = source
|
||||||
|
self.set_source_track()
|
||||||
|
print(f'{self.kind=}, {self.__class__.__name__}, {dir(self)}')
|
||||||
|
|
||||||
|
async def recv(self):
|
||||||
|
try:
|
||||||
|
return await self._recv()
|
||||||
|
except MediaStreamError:
|
||||||
|
print(f'{self.__class__.__name__} reach ended of the media ...')
|
||||||
|
self.source_reload()
|
||||||
|
return await self._recv()
|
||||||
|
except Exception as e:
|
||||||
|
print_exc()
|
||||||
|
raise e
|
||||||
|
|
||||||
|
async def _recv(self):
|
||||||
|
if self.source is None:
|
||||||
|
print('self._recv() return None')
|
||||||
|
return None
|
||||||
|
f = await self.track.recv()
|
||||||
|
return f
|
||||||
|
|
||||||
|
def set_source_track(self):
|
||||||
|
if self.kind == 'audio':
|
||||||
|
self.track = self.source.audio
|
||||||
|
else:
|
||||||
|
self.track = self.source.video
|
||||||
|
|
||||||
|
def source_reload(self):
|
||||||
|
print(f'reload the source----{self.source._file_path}')
|
||||||
|
self.set_source(MyMediaPlayer(self.source._file_path))
|
||||||
|
|
||||||
|
def set_source(self, source):
|
||||||
|
self.source = source
|
||||||
|
self.set_source_track()
|
||||||
|
|
||||||
|
class MyVideoStreamTrack(VideoStreamTrack):
|
||||||
|
def __init__(self, source=None):
|
||||||
|
super().__init__()
|
||||||
|
self.source = source
|
||||||
|
self.set_source_track()
|
||||||
|
print(f'{self.kind=}, {self.__class__.__name__}, {dir(self)}')
|
||||||
|
|
||||||
|
async def recv(self):
|
||||||
|
try:
|
||||||
|
return await self._recv()
|
||||||
|
except MediaStreamError:
|
||||||
|
print(f'{self.__class__.__name__} reach ended of the media ...')
|
||||||
|
self.source_reload()
|
||||||
|
return self._recv()
|
||||||
|
except Exception as e:
|
||||||
|
print_exc()
|
||||||
|
raise e
|
||||||
|
|
||||||
|
async def _recv(self):
|
||||||
|
if self.source is None:
|
||||||
|
return None
|
||||||
|
f = await self.track.recv()
|
||||||
|
return f
|
||||||
|
|
||||||
|
def set_source_track(self):
|
||||||
|
if self.kind == 'audio':
|
||||||
|
self.track = self.source.audio
|
||||||
|
else:
|
||||||
|
self.track = self.source.video
|
||||||
|
|
||||||
|
def source_reload(self):
|
||||||
|
print(f'reload the source----{self.source._file_path}')
|
||||||
|
self.set_source(MyMediaPlayer(self.source._file_path))
|
||||||
|
|
||||||
|
def set_source(self, source):
|
||||||
|
self.source = source
|
||||||
|
self.set_source_track()
|
||||||
|
|
||||||
|
|
||||||
|
class LoopingVideoTrack(VideoStreamTrack):
|
||||||
|
'''
|
||||||
|
A video track that loops a video file.
|
||||||
|
'''
|
||||||
|
def __init__(self, filename):
|
||||||
|
super().__init__()
|
||||||
|
self.player = MyMediaPlayer(filename)
|
||||||
|
print(dir(self.player))
|
||||||
|
|
||||||
|
async def recv(self):
|
||||||
|
frame = await self.player.video.recv()
|
||||||
|
if self.player.video.readyState != 'live':
|
||||||
|
self.player = MyMediaPlayer(self.player._file_path)
|
||||||
|
frame = await self.player.video.recv()
|
||||||
|
return frame
|
||||||
|
"""
|
||||||
|
|
||||||
75
rtcllm/audio_mix.py
Normal file
75
rtcllm/audio_mix.py
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
import asyncio
|
||||||
|
from aiortc import RTCPeerConnection, RTCSessionDescription, MediaStreamTrack
|
||||||
|
from av import AudioFrame
|
||||||
|
from aiortc.contrib.media import MediaPlayer
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
class MixedAudioTrack(MediaStreamTrack):
|
||||||
|
kind = "audio"
|
||||||
|
|
||||||
|
def __init__(self, tracks):
|
||||||
|
super().__init__()
|
||||||
|
self.tracks = tracks
|
||||||
|
|
||||||
|
def add_track(self, track):
|
||||||
|
if track in self.tracks:
|
||||||
|
return
|
||||||
|
self.tracks.append(track)
|
||||||
|
|
||||||
|
de del_track(self, track):
|
||||||
|
tracks = [ t for t in self.tracks if t != track ]
|
||||||
|
self.tracks = tracks
|
||||||
|
|
||||||
|
async def recv(self):
|
||||||
|
# 获取所有音频轨道的数据
|
||||||
|
audio_data = []
|
||||||
|
for track in self.tracks:
|
||||||
|
frame = await track.recv()
|
||||||
|
if frame:
|
||||||
|
audio_data.append(frame.to_ndarray())
|
||||||
|
|
||||||
|
# 检查是否有有效的音频数据
|
||||||
|
if not audio_data:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 将音频数据转换为 numpy 数组
|
||||||
|
audio_arrays = [np.frombuffer(data, dtype=np.int16) for data in audio_data]
|
||||||
|
|
||||||
|
# 确保所有音频数组长度相同
|
||||||
|
min_length = min(len(arr) for arr in audio_arrays)
|
||||||
|
audio_arrays = [arr[:min_length] for arr in audio_arrays]
|
||||||
|
|
||||||
|
# 混合音频数据
|
||||||
|
mixed_audio = np.sum(audio_arrays, axis=0, dtype=np.int16)
|
||||||
|
|
||||||
|
# 创建新的音频帧
|
||||||
|
new_frame = AudioFrame(format="s16", layout="stereo", samples=len(mixed_audio) // 2)
|
||||||
|
new_frame.planes[0].update(mixed_audio.tobytes())
|
||||||
|
new_frame.pts = self._timestamp
|
||||||
|
self._timestamp += new_frame.samples
|
||||||
|
return new_frame
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# 示例:创建两个音频轨道
|
||||||
|
track1 = MediaPlayer('audio1.wav').audio
|
||||||
|
track2 = MediaPlayer('audio2.wav').audio
|
||||||
|
|
||||||
|
# 创建混合音频轨道
|
||||||
|
mixed_track = MixedAudioTrack([track1, track2])
|
||||||
|
|
||||||
|
# 创建 RTCPeerConnection 并添加混合音频轨道
|
||||||
|
pc = RTCPeerConnection()
|
||||||
|
pc.addTrack(mixed_track)
|
||||||
|
|
||||||
|
# 以下部分是用于建立 WebRTC 连接的代码
|
||||||
|
# 你可以根据需要进行修改和扩展
|
||||||
|
async def create_offer():
|
||||||
|
offer = await pc.createOffer()
|
||||||
|
await pc.setLocalDescription(offer)
|
||||||
|
print("Local description set successfully")
|
||||||
|
# 通常在这里会发送 SDP 到对端
|
||||||
|
|
||||||
|
# 启动异步事件循环
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
loop.run_until_complete(create_offer())
|
||||||
|
loop.run_forever()
|
||||||
214
rtcllm/examples.py
Normal file
214
rtcllm/examples.py
Normal file
@ -0,0 +1,214 @@
|
|||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import ssl
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
from aiohttp import web
|
||||||
|
from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
|
||||||
|
from aiortc.contrib.media import MediaBlackhole, MediaPlayer, MediaRecorder, MediaRelay
|
||||||
|
from av import VideoFrame
|
||||||
|
|
||||||
|
ROOT = os.path.dirname(__file__)
|
||||||
|
|
||||||
|
logger = logging.getLogger("pc")
|
||||||
|
pcs = set()
|
||||||
|
relay = MediaRelay()
|
||||||
|
|
||||||
|
|
||||||
|
class VideoTransformTrack(MediaStreamTrack):
|
||||||
|
"""
|
||||||
|
A video stream track that transforms frames from an another track.
|
||||||
|
"""
|
||||||
|
|
||||||
|
kind = "video"
|
||||||
|
|
||||||
|
def __init__(self, track, transform):
|
||||||
|
super().__init__() # don't forget this!
|
||||||
|
self.track = track
|
||||||
|
self.transform = transform
|
||||||
|
|
||||||
|
async def recv(self):
|
||||||
|
frame = await self.track.recv()
|
||||||
|
|
||||||
|
if self.transform == "cartoon":
|
||||||
|
img = frame.to_ndarray(format="bgr24")
|
||||||
|
|
||||||
|
# prepare color
|
||||||
|
img_color = cv2.pyrDown(cv2.pyrDown(img))
|
||||||
|
for _ in range(6):
|
||||||
|
img_color = cv2.bilateralFilter(img_color, 9, 9, 7)
|
||||||
|
img_color = cv2.pyrUp(cv2.pyrUp(img_color))
|
||||||
|
|
||||||
|
# prepare edges
|
||||||
|
img_edges = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
|
||||||
|
img_edges = cv2.adaptiveThreshold(
|
||||||
|
cv2.medianBlur(img_edges, 7),
|
||||||
|
255,
|
||||||
|
cv2.ADAPTIVE_THRESH_MEAN_C,
|
||||||
|
cv2.THRESH_BINARY,
|
||||||
|
9,
|
||||||
|
2,
|
||||||
|
)
|
||||||
|
img_edges = cv2.cvtColor(img_edges, cv2.COLOR_GRAY2RGB)
|
||||||
|
|
||||||
|
# combine color and edges
|
||||||
|
img = cv2.bitwise_and(img_color, img_edges)
|
||||||
|
|
||||||
|
# rebuild a VideoFrame, preserving timing information
|
||||||
|
new_frame = VideoFrame.from_ndarray(img, format="bgr24")
|
||||||
|
new_frame.pts = frame.pts
|
||||||
|
new_frame.time_base = frame.time_base
|
||||||
|
return new_frame
|
||||||
|
elif self.transform == "edges":
|
||||||
|
# perform edge detection
|
||||||
|
img = frame.to_ndarray(format="bgr24")
|
||||||
|
img = cv2.cvtColor(cv2.Canny(img, 100, 200), cv2.COLOR_GRAY2BGR)
|
||||||
|
|
||||||
|
# rebuild a VideoFrame, preserving timing information
|
||||||
|
new_frame = VideoFrame.from_ndarray(img, format="bgr24")
|
||||||
|
new_frame.pts = frame.pts
|
||||||
|
new_frame.time_base = frame.time_base
|
||||||
|
return new_frame
|
||||||
|
elif self.transform == "rotate":
|
||||||
|
# rotate image
|
||||||
|
img = frame.to_ndarray(format="bgr24")
|
||||||
|
rows, cols, _ = img.shape
|
||||||
|
M = cv2.getRotationMatrix2D((cols / 2, rows / 2), frame.time * 45, 1)
|
||||||
|
img = cv2.warpAffine(img, M, (cols, rows))
|
||||||
|
|
||||||
|
# rebuild a VideoFrame, preserving timing information
|
||||||
|
new_frame = VideoFrame.from_ndarray(img, format="bgr24")
|
||||||
|
new_frame.pts = frame.pts
|
||||||
|
new_frame.time_base = frame.time_base
|
||||||
|
return new_frame
|
||||||
|
else:
|
||||||
|
return frame
|
||||||
|
|
||||||
|
|
||||||
|
async def index(request):
|
||||||
|
content = open(os.path.join(ROOT, "index.html"), "r").read()
|
||||||
|
return web.Response(content_type="text/html", text=content)
|
||||||
|
|
||||||
|
|
||||||
|
async def javascript(request):
|
||||||
|
content = open(os.path.join(ROOT, "client.js"), "r").read()
|
||||||
|
return web.Response(content_type="application/javascript", text=content)
|
||||||
|
|
||||||
|
|
||||||
|
async def offer(request):
|
||||||
|
params = await request.json()
|
||||||
|
offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
|
||||||
|
|
||||||
|
pc = RTCPeerConnection()
|
||||||
|
pc_id = "PeerConnection(%s)" % uuid.uuid4()
|
||||||
|
pcs.add(pc)
|
||||||
|
|
||||||
|
def log_info(msg, *args):
|
||||||
|
logger.info(pc_id + " " + msg, *args)
|
||||||
|
|
||||||
|
log_info("Created for %s", request.remote)
|
||||||
|
|
||||||
|
# prepare local media
|
||||||
|
player = MediaPlayer(os.path.join(ROOT, "demo-instruct.wav"))
|
||||||
|
if args.record_to:
|
||||||
|
recorder = MediaRecorder(args.record_to)
|
||||||
|
else:
|
||||||
|
recorder = MediaBlackhole()
|
||||||
|
|
||||||
|
@pc.on("datachannel")
|
||||||
|
def on_datachannel(channel):
|
||||||
|
@channel.on("message")
|
||||||
|
def on_message(message):
|
||||||
|
if isinstance(message, str) and message.startswith("ping"):
|
||||||
|
channel.send("pong" + message[4:])
|
||||||
|
|
||||||
|
@pc.on("connectionstatechange")
|
||||||
|
async def on_connectionstatechange():
|
||||||
|
log_info("Connection state is %s", pc.connectionState)
|
||||||
|
if pc.connectionState == "failed":
|
||||||
|
await pc.close()
|
||||||
|
pcs.discard(pc)
|
||||||
|
|
||||||
|
@pc.on("track")
|
||||||
|
def on_track(track):
|
||||||
|
log_info("Track %s received", track.kind)
|
||||||
|
|
||||||
|
if track.kind == "audio":
|
||||||
|
pc.addTrack(player.audio)
|
||||||
|
recorder.addTrack(track)
|
||||||
|
elif track.kind == "video":
|
||||||
|
pc.addTrack(
|
||||||
|
VideoTransformTrack(
|
||||||
|
relay.subscribe(track), transform=params["video_transform"]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if args.record_to:
|
||||||
|
recorder.addTrack(relay.subscribe(track))
|
||||||
|
|
||||||
|
@track.on("ended")
|
||||||
|
async def on_ended():
|
||||||
|
log_info("Track %s ended", track.kind)
|
||||||
|
await recorder.stop()
|
||||||
|
|
||||||
|
# handle offer
|
||||||
|
await pc.setRemoteDescription(offer)
|
||||||
|
await recorder.start()
|
||||||
|
|
||||||
|
# send answer
|
||||||
|
answer = await pc.createAnswer()
|
||||||
|
await pc.setLocalDescription(answer)
|
||||||
|
|
||||||
|
return web.Response(
|
||||||
|
content_type="application/json",
|
||||||
|
text=json.dumps(
|
||||||
|
{"sdp": pc.localDescription.sdp, "type": pc.localDescription.type}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def on_shutdown(app):
|
||||||
|
# close peer connections
|
||||||
|
coros = [pc.close() for pc in pcs]
|
||||||
|
await asyncio.gather(*coros)
|
||||||
|
pcs.clear()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="WebRTC audio / video / data-channels demo"
|
||||||
|
)
|
||||||
|
parser.add_argument("--cert-file", help="SSL certificate file (for HTTPS)")
|
||||||
|
parser.add_argument("--key-file", help="SSL key file (for HTTPS)")
|
||||||
|
parser.add_argument(
|
||||||
|
"--host", default="0.0.0.0", help="Host for HTTP server (default: 0.0.0.0)"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--port", type=int, default=8080, help="Port for HTTP server (default: 8080)"
|
||||||
|
)
|
||||||
|
parser.add_argument("--record-to", help="Write received media to a file.")
|
||||||
|
parser.add_argument("--verbose", "-v", action="count")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.verbose:
|
||||||
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
|
else:
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
|
if args.cert_file:
|
||||||
|
ssl_context = ssl.SSLContext()
|
||||||
|
ssl_context.load_cert_chain(args.cert_file, args.key_file)
|
||||||
|
else:
|
||||||
|
ssl_context = None
|
||||||
|
|
||||||
|
app = web.Application()
|
||||||
|
app.on_shutdown.append(on_shutdown)
|
||||||
|
app.router.add_get("/", index)
|
||||||
|
app.router.add_get("/client.js", javascript)
|
||||||
|
app.router.add_post("/offer", offer)
|
||||||
|
web.run_app(
|
||||||
|
app, access_log=None, host=args.host, port=args.port, ssl_context=ssl_context
|
||||||
|
)
|
||||||
70
rtcllm/rec.py
Normal file
70
rtcllm/rec.py
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
import pyaudio
|
||||||
|
import av
|
||||||
|
import numpy as np
|
||||||
|
from vad import frames_write_wave
|
||||||
|
import sys
|
||||||
|
import select
|
||||||
|
from vad import MyVad, bytes2frame, to16000_160_frames
|
||||||
|
|
||||||
|
def check_input():
|
||||||
|
if select.select([sys.stdin], [], [], 0.0)[0]:
|
||||||
|
input_str = sys.stdin.readline().strip()
|
||||||
|
return input_str
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def mic(seconds=None, frames_ms=20):
|
||||||
|
# PyAudio setup
|
||||||
|
#
|
||||||
|
p = pyaudio.PyAudio()
|
||||||
|
|
||||||
|
# Audio format
|
||||||
|
FORMAT = pyaudio.paInt16 # 16-bit resolution
|
||||||
|
CHANNELS = 1 # 1 channel for mono
|
||||||
|
RATE = 16000 # 44.1kHz sampling rate
|
||||||
|
CHUNK = int(RATE / 100) * 2 # 1024 samples per frame
|
||||||
|
# PyAudio input stream
|
||||||
|
stream = p.open(format=FORMAT,
|
||||||
|
channels=CHANNELS,
|
||||||
|
rate=RATE,
|
||||||
|
input=True,
|
||||||
|
frames_per_buffer=CHUNK)
|
||||||
|
|
||||||
|
print("Recording and streaming audio...")
|
||||||
|
if seconds is not None:
|
||||||
|
cnt = int(seconds * 1000 / frames_ms)
|
||||||
|
else:
|
||||||
|
cnt = 0
|
||||||
|
i = 0
|
||||||
|
while True:
|
||||||
|
if cnt == 0:
|
||||||
|
if check_input() == 'q':
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
if i>=cnt:
|
||||||
|
break
|
||||||
|
# Read audio frames from the microphone
|
||||||
|
data = stream.read(CHUNK)
|
||||||
|
# print(f'{data.__class__.__name__}, {len(data)=}, {CHUNK=}')
|
||||||
|
frames = to16000_160_frames(bytes2frame(data))
|
||||||
|
for frame in frames:
|
||||||
|
yield frame
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
# Close the audio stream and the container
|
||||||
|
stream.close()
|
||||||
|
p.terminate()
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
def cb(f):
|
||||||
|
print(f'{f} voice wave file')
|
||||||
|
|
||||||
|
i = 0
|
||||||
|
vad = MyVad(callback=cb)
|
||||||
|
for f in mic():
|
||||||
|
if i == 0:
|
||||||
|
print(f'{f.sample_rate=}, {f.samples=}, {f.layout=} ')
|
||||||
|
i += 1
|
||||||
|
vad.vad_check(f)
|
||||||
|
|
||||||
|
print(f'record save to {f}')
|
||||||
246
rtcllm/rtc.old.py
Normal file
246
rtcllm/rtc.old.py
Normal file
@ -0,0 +1,246 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
sys.path.append('./mini_omni')
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import random
|
||||||
|
import json
|
||||||
|
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
from appPublic.dictObject import DictObject
|
||||||
|
from appPublic.hf import hf_socks5proxy
|
||||||
|
from appPublic.worker import awaitify
|
||||||
|
from aiortc import MediaStreamTrack, VideoStreamTrack, AudioStreamTrack, RTCPeerConnection, RTCSessionDescription, RTCIceCandidate
|
||||||
|
from aiortc.sdp import candidate_from_sdp, candidate_to_sdp
|
||||||
|
from aiortc.contrib.media import MediaBlackhole, MediaPlayer, MediaRecorder, MediaRelay
|
||||||
|
|
||||||
|
# from websockets.asyncio.client import connect
|
||||||
|
# from websockets.asyncio.client import connect
|
||||||
|
from websockets.client import connect
|
||||||
|
import websockets
|
||||||
|
from stt import asr
|
||||||
|
from vad import AudioTrackVad
|
||||||
|
from a2a import LLMAudioStreamTrack
|
||||||
|
from aav import MyMediaPlayer, MyAudioStreamTrack, MyVideoStreamTrack
|
||||||
|
|
||||||
|
from mini_omni.inference import OmniInference
|
||||||
|
|
||||||
|
videos = ['./1.mp4', './2.mp4']
|
||||||
|
|
||||||
|
async def pc_get_local_candidates(pc, peer):
|
||||||
|
its = pc.__dict__.get('_RTCPeerConnection__iceTransports')
|
||||||
|
coros = map(lambda t: t.iceGatherer.gather(), its )
|
||||||
|
await asyncio.gather(*coros)
|
||||||
|
ss = pc.__dict__.get('_RTCPeerConnection__iceTransport')
|
||||||
|
|
||||||
|
if not peer.l_candidates:
|
||||||
|
peer.l_candidates = []
|
||||||
|
peer.sdp_id = 0
|
||||||
|
|
||||||
|
for t in its:
|
||||||
|
for c in t._connection.local_candidates:
|
||||||
|
if c not in peer.l_candidates:
|
||||||
|
# print(f'{c=}, {dir(c)}')
|
||||||
|
c.sdpMid = str(peer.sdp_id)
|
||||||
|
peer.sdp_id += 1
|
||||||
|
peer.l_candidates.append(c)
|
||||||
|
pc.emit('icecandidate', c)
|
||||||
|
|
||||||
|
RTCPeerConnection.get_local_candidates = pc_get_local_candidates
|
||||||
|
|
||||||
|
class RTCLLM:
|
||||||
|
def __init__(self, ws_url, iceServers):
|
||||||
|
self.ws_url = ws_url
|
||||||
|
self.omni_infer = OmniInference(ckpt_dir='/d/models/mini-omni')
|
||||||
|
self.iceServers = iceServers
|
||||||
|
self.peers = DictObject()
|
||||||
|
self.vid = 0
|
||||||
|
self.info = {
|
||||||
|
'id':'rtcllm_agent',
|
||||||
|
'name':'rtcllm_agent'
|
||||||
|
}
|
||||||
|
self.dc = None
|
||||||
|
self.loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
|
async def ws_send(self, s):
|
||||||
|
await self.ws.send(s)
|
||||||
|
|
||||||
|
def get_pc(self, data):
|
||||||
|
return self.peers[data['from'].id].pc
|
||||||
|
|
||||||
|
async def run(self):
|
||||||
|
async with connect(self.ws_url) as self.ws:
|
||||||
|
await self.login()
|
||||||
|
while True:
|
||||||
|
msg = await self.ws.recv()
|
||||||
|
data = DictObject(**json.loads(msg))
|
||||||
|
d = data.data
|
||||||
|
print(f'ws recv(): {d.type=}')
|
||||||
|
func = self.handlers.get(d.type)
|
||||||
|
if func:
|
||||||
|
f = partial(func, self)
|
||||||
|
await f(d)
|
||||||
|
self.ws.close()
|
||||||
|
|
||||||
|
async def login(self):
|
||||||
|
await self.ws_send(json.dumps({
|
||||||
|
'type':'login',
|
||||||
|
'info':self.info
|
||||||
|
}))
|
||||||
|
|
||||||
|
async def on_icecandidate(self, pc, to, candidate):
|
||||||
|
if candidate:
|
||||||
|
candi = {
|
||||||
|
'candidate':'candidate:' + candidate.to_sdp(),
|
||||||
|
'sdpMid':candidate.sdpMid,
|
||||||
|
'type': candidate.type
|
||||||
|
}
|
||||||
|
# print('***********on_icecandidate()', candi)
|
||||||
|
await self.ws_send(json.dumps({
|
||||||
|
"type":"iceCandidate",
|
||||||
|
"to":to,
|
||||||
|
"candidate":candi
|
||||||
|
}))
|
||||||
|
|
||||||
|
async def save_onlineList(self, data):
|
||||||
|
# print(f'{self}, {type(self)}')
|
||||||
|
self.onlineList = data.onlineList
|
||||||
|
|
||||||
|
async def vad_voiceend(self, peer, audio):
|
||||||
|
if audio is not None:
|
||||||
|
feed = awaitify(peer.llmtrack._feed)
|
||||||
|
ret = await feed(audio)
|
||||||
|
print(f'self.feed("{audio}") return {ret}')
|
||||||
|
# os.remove(audio)
|
||||||
|
|
||||||
|
async def auto_accept_call(self, data):
|
||||||
|
opts = DictObject(iceServers=self.iceServers)
|
||||||
|
pc = RTCPeerConnection(opts)
|
||||||
|
player = MyMediaPlayer('./1.mp4')
|
||||||
|
llmtrack = LLMAudioStreamTrack(self.omni_infer)
|
||||||
|
self.peers[data['from'].id] = DictObject(**{
|
||||||
|
'info':data['from'],
|
||||||
|
'llmtrack':llmtrack,
|
||||||
|
'player':player,
|
||||||
|
'pc':pc
|
||||||
|
})
|
||||||
|
pc.addTrack(llmtrack)
|
||||||
|
# pc.addTrack(MyAudioStreamTrack(player))
|
||||||
|
pc.addTrack(MyVideoStreamTrack(player))
|
||||||
|
# pc.addTrack(LoopingVideoTrack('./1.mp4'))
|
||||||
|
# pc.addTrack(player.video)
|
||||||
|
await self.ws_send(json.dumps({'type':'callAccepted', 'to':data['from']}))
|
||||||
|
print('auto_accept_call() end')
|
||||||
|
|
||||||
|
async def pc_track(self, peerid, track):
|
||||||
|
peer = self.peers[peerid]
|
||||||
|
pc = peer.pc
|
||||||
|
if track.kind == 'audio':
|
||||||
|
f = partial(self.vad_voiceend, peer)
|
||||||
|
vadtrack = AudioTrackVad(track, stage=3, onvoiceend=f)
|
||||||
|
peer.vadtrack = vadtrack
|
||||||
|
vadtrack.start_vad()
|
||||||
|
|
||||||
|
async def pc_connectionState_changed(self, peerid):
|
||||||
|
peer = self.peers[peerid]
|
||||||
|
pc = peer.pc
|
||||||
|
print(f'conn_state={pc.connectionState} ...........')
|
||||||
|
if pc.connectionState == 'connected':
|
||||||
|
peer.dc = pc.createDataChannel(peer.info.name)
|
||||||
|
return
|
||||||
|
if pc.connectionState == 'closed':
|
||||||
|
await pc.close()
|
||||||
|
if peer.dc:
|
||||||
|
await peer.dc.close()
|
||||||
|
await self.ws_send(json.dumps({
|
||||||
|
'type':'disconnect',
|
||||||
|
'to': peer.info
|
||||||
|
}))
|
||||||
|
peers = {
|
||||||
|
k:v for k,v in self.peers.items() if k != peerid
|
||||||
|
}
|
||||||
|
self.peers = peers
|
||||||
|
|
||||||
|
async def response_offer(self, data):
|
||||||
|
pc = self.get_pc(data)
|
||||||
|
peer = self.peers[data['from'].id]
|
||||||
|
if pc is None:
|
||||||
|
print(f'{self.peers=}, {data=}')
|
||||||
|
return
|
||||||
|
pc.on("connectionstatechange", partial(self.pc_connectionState_changed, data['from'].id))
|
||||||
|
pc.on('track', partial(self.pc_track, data['from'].id))
|
||||||
|
pc.on('icecandidate', partial(self.on_icecandidate, pc, data['from']))
|
||||||
|
|
||||||
|
offer = RTCSessionDescription(** data.offer)
|
||||||
|
await pc.setRemoteDescription(offer)
|
||||||
|
answer = await pc.createAnswer()
|
||||||
|
await pc.setLocalDescription(answer)
|
||||||
|
await self.ws_send(json.dumps({
|
||||||
|
'type':'answer',
|
||||||
|
'answer':{'type':pc.localDescription.type, 'sdp':pc.localDescription.sdp},
|
||||||
|
'to':data['from']
|
||||||
|
}))
|
||||||
|
cands = await pc.get_local_candidates(peer)
|
||||||
|
|
||||||
|
"""
|
||||||
|
offer = await pc.createOffer()
|
||||||
|
await pc.setLocalDescription(offer)
|
||||||
|
await self.ws_send(json.dumps({
|
||||||
|
'type':'offer',
|
||||||
|
'offer': {'type':pc.localDescription.type, 'sdp':pc.localDescription.sdp},
|
||||||
|
'to':data['from']
|
||||||
|
}))
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def accept_answer(self, data):
|
||||||
|
pc = self.get_pc(data)
|
||||||
|
answer = RTCSessionDescription(data.answer);
|
||||||
|
pc.setRemoteDescription(answer)
|
||||||
|
|
||||||
|
async def accept_iceCandidate(self, data):
|
||||||
|
pc = self.get_pc(data)
|
||||||
|
candidate = data.candidate
|
||||||
|
# print('accepted candidate=', candidate)
|
||||||
|
"""
|
||||||
|
rtc_candidate = RTCIceCandidate(
|
||||||
|
ip=ip,
|
||||||
|
port=port,
|
||||||
|
protocol=protocol,
|
||||||
|
priority=priority,
|
||||||
|
foundation=foundation,
|
||||||
|
component=component,
|
||||||
|
type=type,
|
||||||
|
sdpMid=candidate['sdpMid'],
|
||||||
|
sdpMLineIndex=candidate['sdpMLineIndex']
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
rtc_candidate = candidate_from_sdp(candidate['candidate'].split(":", 1)[1])
|
||||||
|
rtc_candidate.sdpMid = candidate['sdpMid']
|
||||||
|
rtc_candidate.sdpMLineIndex = candidate['sdpMLineIndex']
|
||||||
|
await pc.addIceCandidate(rtc_candidate)
|
||||||
|
# print('addIceCandidate ok')
|
||||||
|
|
||||||
|
handlers = {
|
||||||
|
'onlineList':save_onlineList,
|
||||||
|
'callRequest':auto_accept_call,
|
||||||
|
'offer':response_offer,
|
||||||
|
'answer':accept_answer,
|
||||||
|
'iceCandidate':accept_iceCandidate,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
hf_socks5proxy()
|
||||||
|
agent = RTCLLM(ws_url='wss://sage.open-computing.cn/wss/ws/rtc_signaling.ws',
|
||||||
|
iceServers=[{
|
||||||
|
'urls':'stun:stun.open-computing.cn:13478'},{
|
||||||
|
'urls':'turn:stun.open-computing.cn:13479',
|
||||||
|
'username':'turn',
|
||||||
|
'credential':'server'
|
||||||
|
}])
|
||||||
|
print('running ...')
|
||||||
|
await agent.run()
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
asyncio.run(main())
|
||||||
|
|
||||||
426
rtcllm/rtc.py
Normal file
426
rtcllm/rtc.py
Normal file
@ -0,0 +1,426 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
sys.path.append('./mini_omni')
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import random
|
||||||
|
import json
|
||||||
|
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
from appPublic.dictObject import DictObject
|
||||||
|
from appPublic.hf import hf_socks5proxy
|
||||||
|
from appPublic.worker import awaitify
|
||||||
|
from appPublic.log import debug, error, exception, critical
|
||||||
|
from aiortc import MediaStreamTrack, VideoStreamTrack, AudioStreamTrack, RTCPeerConnection, RTCSessionDescription, RTCIceCandidate
|
||||||
|
from aiortc.sdp import candidate_from_sdp, candidate_to_sdp
|
||||||
|
from aiortc.contrib.media import MediaBlackhole, MediaPlayer, MediaRecorder, MediaRelay
|
||||||
|
|
||||||
|
# from websockets.asyncio.client import connect
|
||||||
|
# from websockets.asyncio.client import connect
|
||||||
|
from websockets.client import connect
|
||||||
|
import websockets
|
||||||
|
from stt import asr
|
||||||
|
from vad import AudioTrackVad
|
||||||
|
from a2a import LLMAudioStreamTrack
|
||||||
|
from aav import MyMediaPlayer, MyAudioStreamTrack, MyVideoStreamTrack
|
||||||
|
|
||||||
|
from mini_omni.inference import OmniInference
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class RTCSignaling:
|
||||||
|
/*
|
||||||
|
{
|
||||||
|
signaling_url:
|
||||||
|
info:
|
||||||
|
connect_opts:
|
||||||
|
onclose:
|
||||||
|
onlogin:
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
def __init__(self, opts):
|
||||||
|
self.signaling_url = opts.signaling_url
|
||||||
|
self.info = opts.info
|
||||||
|
self.connect_opts = opts.connect_opts
|
||||||
|
self.peers = []
|
||||||
|
self.sessions = DictObject()
|
||||||
|
self.handlers = DictObject()
|
||||||
|
self.sessionhandlers = DictObject()
|
||||||
|
self.init_websocket()
|
||||||
|
self.hb_task = None
|
||||||
|
self.socket = None
|
||||||
|
self.heartbeat_period = opts.heartbeat_period
|
||||||
|
self.onclose = opts.onclose
|
||||||
|
self.onlogin = opts.onlogin
|
||||||
|
self.onopen = opts.onopen
|
||||||
|
if not self.heartbeat_period:
|
||||||
|
self.heartbeat_period = 0
|
||||||
|
|
||||||
|
async def run(self):
|
||||||
|
async with connect(self.ws_url) as self.socket:
|
||||||
|
await self.login()
|
||||||
|
while True:
|
||||||
|
msg = await self.socket.recv()
|
||||||
|
data = DictObject(**json.loads(msg))
|
||||||
|
d = data.data
|
||||||
|
debug(f'ws recv(): {d.type=}')
|
||||||
|
await self.signaling_recvdata(d)
|
||||||
|
await self.socket.close()
|
||||||
|
|
||||||
|
def reconnect(self):
|
||||||
|
debug('eror happened')
|
||||||
|
if self.hb_task:
|
||||||
|
self.hb_task.cancel()
|
||||||
|
self.hb_task = None
|
||||||
|
|
||||||
|
if self.onclose:
|
||||||
|
self.onclose()
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
def del_handler(self, sessionid):
|
||||||
|
handlers = { k:v \
|
||||||
|
for k,v in self.handlers.items() if k!=sessionid}
|
||||||
|
self.handlers = handlers
|
||||||
|
|
||||||
|
async def signaling_recvdata(self, data):
|
||||||
|
debug(f'ws recv data={data}')
|
||||||
|
if data.session:
|
||||||
|
sessionid = data.session.sessionid
|
||||||
|
sessiontype = data.session.sessiontype
|
||||||
|
handler = self.handlers[sessionid]
|
||||||
|
if not handler:
|
||||||
|
k = self.sessionhandlers[sessiontype]
|
||||||
|
if not k:
|
||||||
|
e = Exception('recvdata_handler() exception(' + sessiontype + ')')
|
||||||
|
raise e
|
||||||
|
|
||||||
|
h = k(self, data.session, self.connect_opts)
|
||||||
|
handler = h.recvdata_handler.bind(h)
|
||||||
|
self.add_handler(sessionid, handler)
|
||||||
|
|
||||||
|
await handler(data)
|
||||||
|
else:
|
||||||
|
await self.recvdata_handler(data)
|
||||||
|
|
||||||
|
async def add_handler(self, key, handler):
|
||||||
|
self.handlers[key] = handler
|
||||||
|
|
||||||
|
async def add_sessionhandler(self, sessiontype, handler):
|
||||||
|
self.sessionhandlers[sessiontype] = handler
|
||||||
|
|
||||||
|
async def recvdata_handler(self, data):
|
||||||
|
debug(f'recv data={data}')
|
||||||
|
if data.type == 'online'):
|
||||||
|
for p in data.online:
|
||||||
|
d = self.peers[p.id]
|
||||||
|
if (!d) d = DictOject()
|
||||||
|
d.update(p)
|
||||||
|
self.peers[p.id] = d
|
||||||
|
|
||||||
|
if self.onlogin:
|
||||||
|
self.onlogin(data.online)
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
debug(f'recv data= {data} NOT HANDLED')
|
||||||
|
|
||||||
|
async def new_session(self, sessiontype, peer):
|
||||||
|
k = self.sessionhandlers[sessiontype]
|
||||||
|
if not k:
|
||||||
|
e = Exception('new_session() exception(' + sessiontype + ')')
|
||||||
|
raise e
|
||||||
|
sessionid = getID()
|
||||||
|
session = DictObject(**{
|
||||||
|
sessiontype:sessiontype,
|
||||||
|
sessionid:sessionid
|
||||||
|
})
|
||||||
|
d = DictObject(**{
|
||||||
|
"type":'new_session',
|
||||||
|
"session": session
|
||||||
|
})
|
||||||
|
await self.send_data(d)
|
||||||
|
opts = self.connect_options.clone()
|
||||||
|
opts.peer_info = peer
|
||||||
|
h = k(self, session, opts)
|
||||||
|
self.add_handler(sessionid, h.recvdata_handler)
|
||||||
|
return h
|
||||||
|
|
||||||
|
async def login(self):
|
||||||
|
debug(f'login send {self.heartbeat_period=}')
|
||||||
|
d = DictObject(
|
||||||
|
type='login'
|
||||||
|
)
|
||||||
|
self.send_data(d)
|
||||||
|
if self.heartbeat_period > 0):
|
||||||
|
debug(f'call login again in {self.heartbeat_period} seconds')
|
||||||
|
self.hb_task = schedule_once(self.login.bind(self), self.heartbeat_period)
|
||||||
|
|
||||||
|
|
||||||
|
async def logout(self):
|
||||||
|
d = DictObject(
|
||||||
|
type='logout'
|
||||||
|
)
|
||||||
|
await self.send_data(d)
|
||||||
|
|
||||||
|
await send_data(self, d):
|
||||||
|
d.msgfrom = self.info
|
||||||
|
s = json.dumps(d)
|
||||||
|
debug(f'send_data() {s=}')
|
||||||
|
await self.socket.send(s)
|
||||||
|
|
||||||
|
async def socket_send(self, s):
|
||||||
|
await self.socket.send(s)
|
||||||
|
|
||||||
|
class RTCP2PConnect:
|
||||||
|
/*
|
||||||
|
opts:{
|
||||||
|
ice_servers:
|
||||||
|
peer_info:
|
||||||
|
auto_callaccept: true or false
|
||||||
|
media_options: { video:trur or false, audio:true or false }
|
||||||
|
data_connect: true or false
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
def __init__(self, signaling, session, opts):
|
||||||
|
self.id = bricks.uuid()
|
||||||
|
self.signaling = signaling
|
||||||
|
self.session = session
|
||||||
|
self.requester = false
|
||||||
|
self.opts = opts
|
||||||
|
self.peers = DictObject()
|
||||||
|
self.signal_handlers = DictObject()
|
||||||
|
self.local_stream = None
|
||||||
|
self.localVideo = None
|
||||||
|
self.add_handler('sessioncreated', self.h_sessioncreated)
|
||||||
|
self.add_handler('callrequest', self.h_callrequest)
|
||||||
|
self.add_handler('callaccepted', self.h_callaccepted)
|
||||||
|
self.add_handler('offer', self.h_offer)
|
||||||
|
self.add_handler('answer', self.h_answer)
|
||||||
|
self.add_handler('icecandidate', self.h_icecandidate)
|
||||||
|
self.add_handler('sessionquit', self.h_sessionquit)
|
||||||
|
|
||||||
|
def add_handler(self, typ, f):
|
||||||
|
self.signal_handlers[typ] = f
|
||||||
|
|
||||||
|
def get_handler(self, typ):
|
||||||
|
return self.signal_handlers[typ]
|
||||||
|
|
||||||
|
|
||||||
|
async def p2pconnect(self, peer):
|
||||||
|
await self.getLocalStream()
|
||||||
|
p = self.peers.get(peer.id)
|
||||||
|
if not p:
|
||||||
|
await self.createPeerConnection(peer)
|
||||||
|
else:
|
||||||
|
debug(f'{peer=}, connect exists {self=}')
|
||||||
|
|
||||||
|
debug(f'p2pconnect() called {self=} {peer=}')
|
||||||
|
|
||||||
|
async def h_sessioncreated(self, data):
|
||||||
|
await self.p2pconnect(self.opts.peer_info, 'requester')
|
||||||
|
if self.opts.peer_info:
|
||||||
|
d = DictObject(**{
|
||||||
|
"type":'callrequest',
|
||||||
|
"msgto":self.opts.peer_info
|
||||||
|
})
|
||||||
|
await self.signaling_send(d)
|
||||||
|
self.requester = true
|
||||||
|
|
||||||
|
async def h_callrequest(self, data):
|
||||||
|
if self.opts.auto_callaccept or true):
|
||||||
|
await self.p2pconnect(data.msgfrom, 'responser')
|
||||||
|
d = DictObject(**{
|
||||||
|
"type":'callaccepted',
|
||||||
|
"msgto":data.msgfrom
|
||||||
|
})
|
||||||
|
await self.signaling_send(d)
|
||||||
|
return
|
||||||
|
|
||||||
|
async def h_callaccepted(self, data):
|
||||||
|
self.createDataChannel(data.msgfrom)
|
||||||
|
await self.send_offer(data.msgfrom, true)
|
||||||
|
|
||||||
|
async def h_offer(self, data):
|
||||||
|
debug(f'h_offer(), {self=}, peer={data.msgfrom}')
|
||||||
|
pc = self.peers[data.msgfrom.id].pc
|
||||||
|
offer = RTCSessionDescription(data.offer)
|
||||||
|
await pc.setRemoteDescription(offer)
|
||||||
|
answer = await pc.createAnswer()
|
||||||
|
await pc.setLocalDescription(answer)
|
||||||
|
d = DictObject(**{
|
||||||
|
type:'answer',
|
||||||
|
answer:pc.localDescription,
|
||||||
|
msgto:data.msgfrom
|
||||||
|
})
|
||||||
|
await self.signaling_send(d)
|
||||||
|
if self.peers[data.msgfrom.id].role != 'requester':
|
||||||
|
await self.send_offer(data.msgfrom)
|
||||||
|
|
||||||
|
async def h_answer(self, data):
|
||||||
|
desc = RTCSessionDescription(data.answer)
|
||||||
|
pc = self.peers[data.msgfrom.id].pc
|
||||||
|
await pc.setRemoteDescription(desc)
|
||||||
|
|
||||||
|
async def h_icecandidate(self, data):
|
||||||
|
candidate = RTCIceCandidate(data.candidate)
|
||||||
|
pc = self.peers[data.msgfrom.id].pc
|
||||||
|
await pc.addIceCandidate(candidate)
|
||||||
|
|
||||||
|
async def h_sessionquit(self, data):
|
||||||
|
pc = self.peers[data.msgfrom.id].pc
|
||||||
|
self.peer_close(data.msgfrom)
|
||||||
|
|
||||||
|
async send_offer(self, peer, initial){
|
||||||
|
debug(f'send_offer(), peers={self.peers}, {peer=}')
|
||||||
|
pc = self.peers[peer.id].pc
|
||||||
|
if initial:
|
||||||
|
self.peers[peer.id].role = 'requester'
|
||||||
|
else:
|
||||||
|
self.peers[peer.id].role = 'responser'
|
||||||
|
|
||||||
|
offer = await pc.createOffer()
|
||||||
|
await pc.setLocalDescription(offer)
|
||||||
|
d = DictObject(**{
|
||||||
|
"type":'offer',
|
||||||
|
"offer":pc.localDescription,
|
||||||
|
"msgto":peer
|
||||||
|
})
|
||||||
|
await self.signaling_send(d)
|
||||||
|
|
||||||
|
async def send_candidate(self, peer, event):
|
||||||
|
debug(f'send_candidate called, {peer=},{event=}')
|
||||||
|
if (event.candidate) {
|
||||||
|
candidate = event.candidate
|
||||||
|
self.signaling_send({
|
||||||
|
type: 'icecandidate',
|
||||||
|
msgto:peer,
|
||||||
|
candidate: candidate
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
signaling_send(d){
|
||||||
|
d.session = self.session
|
||||||
|
self.signaling.send_data(d)
|
||||||
|
}
|
||||||
|
|
||||||
|
async recvdata_handler(data){
|
||||||
|
debug('recvdata=', data, self.signal_handlers)
|
||||||
|
f = self.get_handler(data.type)
|
||||||
|
if (f) {
|
||||||
|
await f(data)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
debug('recvdata=', data, 'NOT HANDLED')
|
||||||
|
}
|
||||||
|
|
||||||
|
async ice_statechange(peer, event){
|
||||||
|
pc = self.peers[peer.id].pc
|
||||||
|
debug(`oniceconnectionstatechange, pc.iceConnectionState is ${pc.iceConnectionState}.`)
|
||||||
|
}
|
||||||
|
|
||||||
|
async connection_statechange(peer, event){
|
||||||
|
pc = self.peers[peer.id].pc
|
||||||
|
debug(`${peer.id} state changed. new state=${pc.connectionState}`)
|
||||||
|
debug('state=', pc.connectionState, typeof(pc.connectionState))
|
||||||
|
if (pc.connectionState == 'disconnected'){
|
||||||
|
self.peer_close(peer)
|
||||||
|
if (self.opts.on_pc_disconnected){
|
||||||
|
self.opts.on_pc_disconnected(peer)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if (pc.connectionState == 'connected'){
|
||||||
|
debug('state is connected, data_connect=',
|
||||||
|
self.opts.data_connect)
|
||||||
|
if(self.opts.on_pc_connected){
|
||||||
|
self.opts.on_pc_connected(peer)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async dc_accepted(peer, event){
|
||||||
|
debug('accept datachannel ....')
|
||||||
|
self.peers[peer.id].dc = event.channel
|
||||||
|
await self.dc_created(peer, self.peers[peer.id].dc)
|
||||||
|
}
|
||||||
|
async dc_created(self, peer, dc):
|
||||||
|
debug('dc_created.....', dc)
|
||||||
|
dc.onmessage = self.datachannel_message.bind(peer)
|
||||||
|
dc.onopen = self.datachannel_open(peer)
|
||||||
|
dc.onclose = self.datachannel_close(peer)
|
||||||
|
|
||||||
|
async datachannel_message(self, peer, event):
|
||||||
|
debug('datachannel_message():', self, arguments)
|
||||||
|
dc = self.peers[peer.id].dc
|
||||||
|
if self.opts.on_dc_messaage:
|
||||||
|
await self.opts.on_dc_message(dc, event.data)
|
||||||
|
|
||||||
|
async def datachannel_open(self, peer):
|
||||||
|
debug('datachannel_open():', self, arguments)
|
||||||
|
dc = self.peers[peer.id].dc
|
||||||
|
if self.opts.on_dc_open:
|
||||||
|
await self.opts.on_dc_open(dc)
|
||||||
|
|
||||||
|
async def datachannel_close(self, peer):
|
||||||
|
debug('datachannel_close():', self, arguments)
|
||||||
|
dc = self.peers[peer.id].dc
|
||||||
|
if self.opts.on_dc_close:
|
||||||
|
await self.opts.on_dc_close(dc)
|
||||||
|
|
||||||
|
async def createDataChannel(self, peer):
|
||||||
|
pc = self.peers[peer.id].pc
|
||||||
|
self.peers[peer.id].dc = pc.createDataChannel('chat', {ordered:true})
|
||||||
|
dc = self.peers[peer.id].dc
|
||||||
|
await self.dc_created(peer, self.peers[peer.id].dc)
|
||||||
|
debug('dc created', self.peers[peer.id].dc)
|
||||||
|
|
||||||
|
async def createPeerConnection(self, peer):
|
||||||
|
configuration = DictObject(**{
|
||||||
|
"iceServers":self.opts.ice_servers
|
||||||
|
})
|
||||||
|
debug(f'RTCPC {configuration=}')
|
||||||
|
pc = RTCPeerConnection(configuration)
|
||||||
|
if self.local_stream:
|
||||||
|
for track in self.local_stream.getTracks():
|
||||||
|
pc.addTrack(track, self.local_stream)
|
||||||
|
self.peers[peer.id] = peer
|
||||||
|
self.peers[peer.id].pc = pc
|
||||||
|
self.peers[peer.id].role = ''
|
||||||
|
pc.onicecandidate = partial(self.send_candidate, peer)
|
||||||
|
pc.oniceconnectionstatechange = \
|
||||||
|
partial(self.ice_statechange, peer)
|
||||||
|
pc.onconnectionstatechange = \
|
||||||
|
partial(self.connection_statechange, peer)
|
||||||
|
pc.ondatachanel = partial(self.dc_accepted, peer)
|
||||||
|
pc.ontrack = partial(self.onpctrack, peer)
|
||||||
|
|
||||||
|
def peer_close(self, peer){
|
||||||
|
pc = self.peers[peer.id].pc
|
||||||
|
video = self.peers[peer.id].video
|
||||||
|
for r in pc.getReceivers():
|
||||||
|
r.track.stop()
|
||||||
|
|
||||||
|
for s in pc.getSenders():
|
||||||
|
s.track.stop()
|
||||||
|
|
||||||
|
dc = self.peers[peer.id].dc
|
||||||
|
if (dc){
|
||||||
|
dc.close()
|
||||||
|
}
|
||||||
|
pc.close()
|
||||||
|
self.peers = {k:v for k,v in self.peers.items() if k!=peer.id}
|
||||||
|
cnt = len(self.peers.keys())
|
||||||
|
if (keys.length == 0){
|
||||||
|
self.localVideo.get_stream()
|
||||||
|
.getTracks().forEach(track => track.stop())
|
||||||
|
self.local_stream.getTracks().forEach(track => track.stop())
|
||||||
|
self.local_screan.getTracks().forEach(track => track.stop())
|
||||||
|
self.signaling.del_handler(self.session.sessionid)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
65
rtcllm/stt.py
Normal file
65
rtcllm/stt.py
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
import aiohttp
|
||||||
|
from appPublic.dictObject import DictObject
|
||||||
|
from appPublic.oauth_client import OAuthClient
|
||||||
|
|
||||||
|
desc = {
|
||||||
|
"path":"/api/generate.dspy",
|
||||||
|
"method":"POST",
|
||||||
|
"headers":[
|
||||||
|
{
|
||||||
|
"name":"Content-Type",
|
||||||
|
"value":"application/json"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"data":[{
|
||||||
|
"name":"audio_file",
|
||||||
|
"value":"${audio_file}"
|
||||||
|
},{
|
||||||
|
"name":"model",
|
||||||
|
"value":"whisper"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"resp":[
|
||||||
|
{
|
||||||
|
"name":"content",
|
||||||
|
"value":"content"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
opts = {
|
||||||
|
"data":{
|
||||||
|
"oops":0
|
||||||
|
},
|
||||||
|
"asr":desc
|
||||||
|
}
|
||||||
|
|
||||||
|
async def asr(a_file):
|
||||||
|
if not a_file:
|
||||||
|
return
|
||||||
|
"""
|
||||||
|
r = None
|
||||||
|
with open(a_file, 'rb') as f:
|
||||||
|
oc = OAuthClient(DictObject(**opts))
|
||||||
|
r = await oc('https://sage.open-computing.cn/asr', 'asr', {'audio_file':f})
|
||||||
|
print(f'{r=}')
|
||||||
|
return r
|
||||||
|
"""
|
||||||
|
with open(a_file, 'rb') as f:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post("https://sage.open-computing.cn/asr/api/generate.dspy", data={
|
||||||
|
"model":"whisper",
|
||||||
|
"audio_file":f
|
||||||
|
}) as response:
|
||||||
|
r = await response.json()
|
||||||
|
print(f'{r=}')
|
||||||
|
return r
|
||||||
|
"""
|
||||||
|
segments, info = model.transcribe(a_file, beam_size=5)
|
||||||
|
txt = ''
|
||||||
|
for s in segments:
|
||||||
|
txt += s.text
|
||||||
|
return {
|
||||||
|
'content': txt,
|
||||||
|
'language': info.language
|
||||||
|
}
|
||||||
|
"""
|
||||||
181
rtcllm/vad.py
Normal file
181
rtcllm/vad.py
Normal file
@ -0,0 +1,181 @@
|
|||||||
|
import base64
|
||||||
|
from inspect import isfunction, iscoroutinefunction
|
||||||
|
from traceback import print_exc
|
||||||
|
import asyncio
|
||||||
|
import collections
|
||||||
|
import contextlib
|
||||||
|
from appPublic.folderUtils import temp_file
|
||||||
|
from appPublic.worker import awaitify
|
||||||
|
from aiortc import MediaStreamTrack
|
||||||
|
from aiortc.contrib.media import MediaBlackhole, MediaPlayer, MediaRecorder, MediaRelay
|
||||||
|
import webrtcvad
|
||||||
|
import wave
|
||||||
|
from scipy.io.wavfile import write
|
||||||
|
import numpy as np
|
||||||
|
import av
|
||||||
|
from av import AudioLayout, AudioResampler, AudioFrame, AudioFormat
|
||||||
|
|
||||||
|
def frames_write_wave(frames):
|
||||||
|
path = temp_file(suffix='.wav')
|
||||||
|
output_container = av.open(path, 'w')
|
||||||
|
out_stream = output_container.add_stream('pcm_s16le')
|
||||||
|
for frame in frames:
|
||||||
|
for packet in out_stream.encode(frame):
|
||||||
|
output_container.mux(packet)
|
||||||
|
for packet in out_stream.encode(None):
|
||||||
|
output_container.mux(packet)
|
||||||
|
output_container.close()
|
||||||
|
return path
|
||||||
|
|
||||||
|
def bytes2frame(byts, channels=1, sample_rate=16000):
|
||||||
|
audio_data = np.frombuffer(byts, np.int16)
|
||||||
|
audio_data = audio_data.reshape((channels, -1))
|
||||||
|
layout = 'mono'
|
||||||
|
if channels == 2:
|
||||||
|
layout = 'stereo'
|
||||||
|
# Create an AV frame from the audio data
|
||||||
|
frame = av.AudioFrame.from_ndarray(audio_data, format='s16', layout='mono')
|
||||||
|
frame.sample_rate = sample_rate
|
||||||
|
return frame
|
||||||
|
|
||||||
|
def frame2bytes(frame):
|
||||||
|
audio_array = frame.to_ndarray()
|
||||||
|
dtype = audio_array.dtype
|
||||||
|
audio_bytes = audio_array.tobytes()
|
||||||
|
return audio_bytes
|
||||||
|
|
||||||
|
def resample(frame, sample_rate=None):
|
||||||
|
if sample_rate is None:
|
||||||
|
sample_rate = frame.rate
|
||||||
|
r = AudioResampler(format='s16', layout='mono', rate=sample_rate)
|
||||||
|
frame = r.resample(frame)
|
||||||
|
return frame
|
||||||
|
|
||||||
|
class MyVad(webrtcvad.Vad):
|
||||||
|
def __init__(self, callback=None):
|
||||||
|
super().__init__(3)
|
||||||
|
self.voiced_frames = []
|
||||||
|
self.num_padding_frames = 40
|
||||||
|
self.ring_buffer = collections.deque(maxlen=self.num_padding_frames)
|
||||||
|
self.onvoiceend = callback
|
||||||
|
self.triggered = False
|
||||||
|
self.cnt = 0
|
||||||
|
|
||||||
|
def voice_duration(self):
|
||||||
|
duration = 0
|
||||||
|
for f in self.voiced_frames:
|
||||||
|
duration = f.samples * 1000 / f.sample_rate + duration
|
||||||
|
return duration
|
||||||
|
|
||||||
|
async def vad_check(self, inframe):
|
||||||
|
"""
|
||||||
|
ONLY SUPPORT frame with sample_rate = 16000 samples = 160
|
||||||
|
"""
|
||||||
|
frame = inframe
|
||||||
|
byts = frame2bytes(frame)
|
||||||
|
if self.cnt == 0:
|
||||||
|
f = frame
|
||||||
|
print(f'{f.sample_rate=}, {f.samples=},{f.layout=}, {len(byts)=}')
|
||||||
|
if not webrtcvad.valid_rate_and_frame_length(frame.sample_rate, frame.samples):
|
||||||
|
print('ftcygvhbunjiokmpl,mknjbhvgc')
|
||||||
|
is_speech = self.is_speech(byts, frame.sample_rate, length=frame.samples)
|
||||||
|
if not self.triggered:
|
||||||
|
self.ring_buffer.append((inframe, is_speech))
|
||||||
|
num_voiced = len([f for f, speech in self.ring_buffer if speech])
|
||||||
|
# If we're NOTTRIGGERED and more than 90% of the frames in
|
||||||
|
# the ring buffer are voiced frames, then enter the
|
||||||
|
# TRIGGERED state.
|
||||||
|
if num_voiced > 0.9 * self.ring_buffer.maxlen:
|
||||||
|
self.triggered = True
|
||||||
|
# We want to yield all the audio we see from now until
|
||||||
|
# we are NOTTRIGGERED, but we have to start with the
|
||||||
|
# audio that's already in the ring buffer.
|
||||||
|
for f, s in self.ring_buffer:
|
||||||
|
self.voiced_frames.append(f)
|
||||||
|
self.ring_buffer.clear()
|
||||||
|
# print('start voice .....', len(self.voiced_frames))
|
||||||
|
else:
|
||||||
|
# We're in the TRIGGERED state, so collect the audio data
|
||||||
|
# and add it to the ring buffer.
|
||||||
|
self.voiced_frames.append(inframe)
|
||||||
|
self.ring_buffer.append((frame, is_speech))
|
||||||
|
num_unvoiced = len([f for f, speech in self.ring_buffer if not speech])
|
||||||
|
# If more than 90% of the frames in the ring buffer are
|
||||||
|
# unvoiced, then enter NOTTRIGGERED and yield whatever
|
||||||
|
# audio we've collected.
|
||||||
|
if num_unvoiced > 0.9 * self.ring_buffer.maxlen:
|
||||||
|
self.triggered = False
|
||||||
|
duration = self.voice_duration()
|
||||||
|
if duration > 500:
|
||||||
|
ret = frames_write_wave(self.voiced_frames)
|
||||||
|
if self.onvoiceend:
|
||||||
|
if iscoroutinefunction(self.onvoiceend):
|
||||||
|
await self.onvoiceend(ret)
|
||||||
|
else:
|
||||||
|
self.onvoiceend(ret)
|
||||||
|
else:
|
||||||
|
print('-----short voice------')
|
||||||
|
|
||||||
|
|
||||||
|
self.ring_buffer.clear()
|
||||||
|
self.voiced_frames = []
|
||||||
|
self.cnt += 1
|
||||||
|
|
||||||
|
class AudioTrackVad(MediaStreamTrack):
|
||||||
|
def __init__(self, track, stage=3, onvoiceend=None):
|
||||||
|
super().__init__()
|
||||||
|
self.track = track
|
||||||
|
self.vad = MyVad(callback=onvoiceend)
|
||||||
|
# self.sample_rate = self.track.getSettings().sampleRate
|
||||||
|
# frameSize = self.track.getSettings().frameSize
|
||||||
|
# self.frame_duration_ms = (1000 * frameSize) / self.sample_rate
|
||||||
|
self.frame_duration_ms = 0.02
|
||||||
|
self.remind_byts = b''
|
||||||
|
self.loop = asyncio.get_event_loop()
|
||||||
|
self.task = None
|
||||||
|
self.debug = True
|
||||||
|
self.running = False
|
||||||
|
|
||||||
|
def start_vad(self):
|
||||||
|
self.running = True
|
||||||
|
self.task = self.loop.call_later(self.frame_duration_ms, self._recv)
|
||||||
|
|
||||||
|
def _recv(self):
|
||||||
|
asyncio.create_task(self.recv())
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
self.running = False
|
||||||
|
|
||||||
|
def to16000_160_frames(self, frame):
|
||||||
|
frames = resample(frame, sample_rate=16000)
|
||||||
|
ret_frames = []
|
||||||
|
for f in frames:
|
||||||
|
if f.samples == 160:
|
||||||
|
return frames
|
||||||
|
for f in frames:
|
||||||
|
b1 = self.remind_byts + frame2bytes(f)
|
||||||
|
while len(b1) >= 320:
|
||||||
|
b = b1[:320]
|
||||||
|
b1 = b1[320:]
|
||||||
|
ret_frames.append(bytes2frame(b))
|
||||||
|
self.remind_byts = b1
|
||||||
|
return ret_frames
|
||||||
|
|
||||||
|
async def recv(self):
|
||||||
|
frame = await self.track.recv()
|
||||||
|
self.sample_rate = frame.sample_rate
|
||||||
|
duration = (frame.samples * 1000) / frame.sample_rate
|
||||||
|
# print(f'{self.__class__.__name__}.recv(): {duration=}, {frame.samples=}, {frame.sample_rate=}')
|
||||||
|
try:
|
||||||
|
frames = self.to16000_160_frames(frame)
|
||||||
|
for frame in frames:
|
||||||
|
await self.vad.vad_check(frame)
|
||||||
|
except Exception as e:
|
||||||
|
print(f'{e=}')
|
||||||
|
print_exc()
|
||||||
|
return
|
||||||
|
if self.task:
|
||||||
|
self.task.cancel()
|
||||||
|
if self.running:
|
||||||
|
self.task = self.loop.call_later(self.frame_duration_ms, self._recv)
|
||||||
|
return frame
|
||||||
142
rtcllm/wss.py
Normal file
142
rtcllm/wss.py
Normal file
@ -0,0 +1,142 @@
|
|||||||
|
|
||||||
|
from websockets.client import connect
|
||||||
|
import websockets
|
||||||
|
from appPublic.background import Background
|
||||||
|
from appPublic.uniqueID import getID
|
||||||
|
from appPublic.dictObject import DictObject
|
||||||
|
|
||||||
|
class Wss:
|
||||||
|
"""
|
||||||
|
new session
|
||||||
|
session_type: p2p, aba, meeting
|
||||||
|
|
||||||
|
signaling data format
|
||||||
|
create a new session
|
||||||
|
{
|
||||||
|
type:new_session,
|
||||||
|
session:{
|
||||||
|
sessiontype:
|
||||||
|
sessionid
|
||||||
|
}
|
||||||
|
from:
|
||||||
|
to:
|
||||||
|
}
|
||||||
|
(will boastcast to all peer online)login:
|
||||||
|
{
|
||||||
|
type:login,
|
||||||
|
from:
|
||||||
|
}
|
||||||
|
(will boastcast to all peer online)logout:
|
||||||
|
{
|
||||||
|
type:logout
|
||||||
|
from:
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
def __init__(self, app, wss_url, info=None, opts={}):
|
||||||
|
self.wss_url = wss_url
|
||||||
|
if info is None:
|
||||||
|
info = {
|
||||||
|
'id':getID()
|
||||||
|
}
|
||||||
|
self.opts = opts
|
||||||
|
self.info = info
|
||||||
|
self.ws = None
|
||||||
|
self.handler_name = 'sessionid'
|
||||||
|
self.app = app
|
||||||
|
self.running = False
|
||||||
|
self.app.wss = self
|
||||||
|
self.handlers = {}
|
||||||
|
self.sessionhandlers = {}
|
||||||
|
self.peers = []
|
||||||
|
|
||||||
|
def add_handler(self, key, handler):
|
||||||
|
sef.handlers[key] = handler
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
f = asyncio.run_coroutine_threadsafe(self._start, loop)
|
||||||
|
return f.result()
|
||||||
|
|
||||||
|
def add_sessionhandler(self, sessiontype, handler):
|
||||||
|
self.sessionhandlers[sessiontype] = handler
|
||||||
|
|
||||||
|
async def _start(self):
|
||||||
|
self.ws = await connect(self.wss_url)
|
||||||
|
self.running = True
|
||||||
|
while self.running:
|
||||||
|
msg = await self.ws.recv()
|
||||||
|
data = DictObject(**json.loads(msg))
|
||||||
|
name = data.get(self.handler_name)
|
||||||
|
handler = self.handlers.get(name)
|
||||||
|
if handler is None:
|
||||||
|
handler = self.recvdata_handler
|
||||||
|
await handler(data)
|
||||||
|
|
||||||
|
def add_peer(self, info):
|
||||||
|
self.peers = [ i for i in self.peers if i.id != info.id ]
|
||||||
|
self.peers.append(info)
|
||||||
|
|
||||||
|
def delete_peer(self, info):
|
||||||
|
self.peers = [ i for i in self.peers if i.id != info.id ]
|
||||||
|
|
||||||
|
async def recvdata_handler(self, data):
|
||||||
|
if data.type == 'login':
|
||||||
|
self.add_peer(data.msgfrom)
|
||||||
|
return
|
||||||
|
if data.type == 'logout':
|
||||||
|
self.delete_peer(data.msgfrom)
|
||||||
|
return
|
||||||
|
if data.type == 'new_session'
|
||||||
|
h = self.sessionhandlers.get(data.session.sessiontype)
|
||||||
|
i = h(data.from)
|
||||||
|
k = data.session.sessionid
|
||||||
|
self.add_handler(k, i.recvdata_handler)
|
||||||
|
return
|
||||||
|
|
||||||
|
async def stop(self):
|
||||||
|
self.running = False
|
||||||
|
await self.ws.close()
|
||||||
|
|
||||||
|
async def restart(self):
|
||||||
|
try:
|
||||||
|
await self.stop()
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
await self.start()
|
||||||
|
|
||||||
|
async def send(self, dic):
|
||||||
|
dic.msgfrom = self.info
|
||||||
|
txt = json.dumps(dic)
|
||||||
|
try:
|
||||||
|
return await self.ws.send(txt)
|
||||||
|
except:
|
||||||
|
await self.restart()
|
||||||
|
return self.send(dic)
|
||||||
|
|
||||||
|
async def new_session(self, sessiontype, opts):
|
||||||
|
k = self.sessionhandlers.get(sessiontype)
|
||||||
|
if k is None:
|
||||||
|
e = Exception(f'Sessiontype({sessiontype}) not registed')
|
||||||
|
exception(f'new_session() exception {e}')
|
||||||
|
raise e
|
||||||
|
d = DictObject()
|
||||||
|
d.type = 'new_session'
|
||||||
|
d.msgfrom = self.info
|
||||||
|
sessionid = getID()
|
||||||
|
d.session = DictObject(sessiontype=sessiontype, sessionid=sessionid)
|
||||||
|
await self.send(d)
|
||||||
|
h = k(self, sessionid, opts)
|
||||||
|
self.add_sessionhandler(sessionid, h.recvdata_handler)
|
||||||
|
return h
|
||||||
|
|
||||||
|
async def login(self):
|
||||||
|
d = DictObject()
|
||||||
|
d.msgfrom = self.info
|
||||||
|
d.type = 'login'
|
||||||
|
await self.send(d)
|
||||||
|
|
||||||
|
async def logout(self):
|
||||||
|
d = DictObject()
|
||||||
|
d.msgfrom = self.info
|
||||||
|
d.type = 'logout'
|
||||||
|
await self.send(d)
|
||||||
Loading…
x
Reference in New Issue
Block a user