You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

583 lines
20 KiB

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net;
using System.Net.Sockets;
using System.Runtime.InteropServices;
namespace ET
{
/// <summary>
/// Packet-type codes written into the first byte of every KService UDP datagram.
/// </summary>
public static class KcpProtocalType
{
    /// <summary>Connection request; handled on the accepting side of KService.</summary>
    public const byte SYN = 1;

    /// <summary>Reply to a SYN; carries the accepting side's conn back to the connecting side.</summary>
    public const byte ACK = 2;

    /// <summary>Disconnect notification; carries both conns and an error code.</summary>
    public const byte FIN = 3;

    /// <summary>KCP payload datagram, fed into the owning channel's KCP instance.</summary>
    public const byte MSG = 4;
}
/// <summary>
/// Kind of network service a KService instance represents.
/// NOTE(review): presumably Outer = client-facing and Inner = server-to-server; confirm against callers.
/// </summary>
public enum ServiceType
{
    Outer = 0,
    Inner = 1,
}
public sealed class KService: AService
{
    // TimeHelper.ClientNow() value captured when this KService was created.
    private readonly long startTime;

    /// <summary>
    /// Time elapsed since this KService was created (TimeHelper.ClientNow() - startTime), truncated to uint.
    /// The original author marks this as thread-safe: it only reads the readonly startTime.
    /// </summary>
    public uint TimeNow
    {
        get
        {
            return (uint) (TimeHelper.ClientNow() - this.startTime);
        }
    }

    // The UDP socket shared by every channel of this service; null once the service is disposed.
    private Socket socket;

    #region Callbacks

    // Registers KcpOutput as the KCP output callback (log callback registration is disabled).
    static KService()
    {
        //Kcp.KcpSetLog(KcpLog);
        Kcp.KcpSetoutput(KcpOutput);
    }

    // Scratch buffer KcpLog copies unmanaged log text into before decoding it.
    private static readonly byte[] logBuffer = new byte[1024];

#if ENABLE_IL2CPP
    // NOTE(review): this names the KcpOutput delegate type although KcpLog returns void;
    // confirm the intended delegate type before re-enabling Kcp.KcpSetLog.
    [AOT.MonoPInvokeCallback(typeof(KcpOutput))]
#endif
    /// <summary>
    /// KCP log callback: copies <paramref name="len"/> bytes from <paramref name="bytes"/> and logs them.
    /// Text longer than logBuffer is truncated instead of failing. Exceptions are logged, never thrown.
    /// </summary>
    private static void KcpLog(IntPtr bytes, int len, IntPtr kcp, IntPtr user)
    {
        try
        {
            int count = Math.Min(len, logBuffer.Length);
            Marshal.Copy(bytes, logBuffer, 0, count);
            Log.Info(logBuffer.ToStr(0, count));
        }
        catch (Exception e)
        {
            Log.Error(e);
        }
    }

#if ENABLE_IL2CPP
    [AOT.MonoPInvokeCallback(typeof(KcpOutput))]
#endif
    /// <summary>
    /// KCP output callback: forwards an encoded segment to the KChannel registered for <paramref name="kcp"/>.
    /// </summary>
    /// <returns>0 when the kcp handle is null or unknown; otherwise <paramref name="len"/>,
    /// even when forwarding threw (the exception is logged).</returns>
    private static int KcpOutput(IntPtr bytes, int len, IntPtr kcp, IntPtr user)
    {
        try
        {
            if (kcp == IntPtr.Zero)
            {
                return 0;
            }
            if (!KChannel.KcpPtrChannels.TryGetValue(kcp, out KChannel kChannel))
            {
                return 0;
            }
            kChannel.Output(bytes, len);
        }
        catch (Exception e)
        {
            Log.Error(e);
        }
        return len;
    }

    #endregion

    /// <summary>
    /// Creates a listening service bound to <paramref name="ipEndPoint"/>.
    /// </summary>
    public KService(ThreadSynchronizationContext threadSynchronizationContext, IPEndPoint ipEndPoint, ServiceType serviceType)
    {
        this.ServiceType = serviceType;
        this.ThreadSynchronizationContext = threadSynchronizationContext;
        this.startTime = TimeHelper.ClientNow();
        this.socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp);
        // Large kernel buffers for the listening socket; not applied on macOS
        // (presumably because such sizes are rejected there — TODO confirm).
        if (!RuntimeInformation.IsOSPlatform(OSPlatform.OSX))
        {
            this.socket.SendBufferSize = Kcp.OneM * 64;
            this.socket.ReceiveBufferSize = Kcp.OneM * 64;
        }
        this.socket.Bind(ipEndPoint);
        DisableUdpConnectionReset(this.socket);
    }

    /// <summary>
    /// Creates a client-side service bound to an ephemeral port on any local address.
    /// </summary>
    public KService(ThreadSynchronizationContext threadSynchronizationContext, ServiceType serviceType)
    {
        this.ServiceType = serviceType;
        this.ThreadSynchronizationContext = threadSynchronizationContext;
        this.startTime = TimeHelper.ClientNow();
        this.socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp);
        // A client does not need larger send/receive buffers.
        this.socket.Bind(new IPEndPoint(IPAddress.Any, 0));
        DisableUdpConnectionReset(this.socket);
    }

    // On Windows, sets the Winsock SIO_UDP_CONNRESET option to false so ICMP "port unreachable"
    // replies are not surfaced as connection-reset errors on later receives. No-op elsewhere.
    private static void DisableUdpConnectionReset(Socket udpSocket)
    {
        if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
        {
            return;
        }
        const uint IOC_IN = 0x80000000;
        const uint IOC_VENDOR = 0x18000000;
        // Kept as a non-const local so the (int) cast below is an unchecked runtime conversion.
        uint SIO_UDP_CONNRESET = IOC_IN | IOC_VENDOR | 12;
        udpSocket.IOControl((int) SIO_UDP_CONNRESET, new[] { Convert.ToByte(false) }, null);
    }

    /// <summary>
    /// Points channel <paramref name="id"/> at a new remote address; ignored if the channel does not exist.
    /// </summary>
    public void ChangeAddress(long id, IPEndPoint address)
    {
        KChannel kChannel = this.Get(id);
        if (kChannel == null)
        {
            return;
        }
        Log.Info($"channel change address: {id} {address}");
        kChannel.RemoteAddress = address;
    }

    // All channels, keyed by channel id.
    private readonly Dictionary<long, KChannel> idChannels = new Dictionary<long, KChannel>();
    // All channels, keyed by their local conn.
    private readonly Dictionary<long, KChannel> localConnChannels = new Dictionary<long, KChannel>();
    // Accepted channels still waiting to connect, keyed by remote conn; removed once connected or after the connect timeout.
    private readonly Dictionary<long, KChannel> waitConnectChannels = new Dictionary<long, KChannel>();
    // Receive buffer; also reused as scratch space for outgoing ACK/FIN packets.
    private readonly byte[] cache = new byte[8192];
    // Source address of the most recent datagram, filled in by ReceiveFrom.
    private EndPoint ipEndPoint = new IPEndPoint(IPAddress.Any, 0);

    // Channels to update in the next Update call.
    private readonly HashSet<long> updateChannels = new HashSet<long>();

    // Channels scheduled for a timed update: due time -> channel ids.
    private readonly MultiMap<long, long> timeId = new MultiMap<long, long>();

    private readonly List<long> timeOutTime = new List<long>();

    // Lower bound on the earliest due time in timeId, so TimerOut can skip scanning when nothing is due.
    // long.MaxValue means "nothing scheduled".
    private long minTime = long.MaxValue;

    private readonly List<long> waitRemoveChannels = new List<long>();

    public override bool IsDispose()
    {
        return this.socket == null;
    }

    /// <summary>
    /// Removes every channel and closes the socket. Safe to call more than once.
    /// </summary>
    public override void Dispose()
    {
        if (this.socket == null)
        {
            return;
        }

        foreach (long channelId in this.idChannels.Keys.ToArray())
        {
            this.Remove(channelId);
        }

        this.socket.Close();
        this.socket = null;
    }

    // Copies the shared ipEndPoint (which ReceiveFrom overwrites) so a channel can keep its own address.
    private IPEndPoint CloneAddress()
    {
        IPEndPoint ip = (IPEndPoint) this.ipEndPoint;
        return new IPEndPoint(ip.Address, ip.Port);
    }

    /// <summary>
    /// Drains all datagrams currently queued on the socket and dispatches them by their first (flag) byte.
    /// </summary>
    private void Recv()
    {
        while (this.socket != null && this.socket.Available > 0)
        {
            int messageLength;
            try
            {
                messageLength = this.socket.ReceiveFrom(this.cache, ref this.ipEndPoint);
            }
            catch (SocketException e)
            {
                // A single bad datagram (e.g. an oversized one on Windows) must not abort the whole Update;
                // stop draining for this frame and try again on the next one.
                Log.Error($"kservice receive error: {e.SocketErrorCode}\n{e}");
                break;
            }

            // A datagram without even a flag byte is not a valid message.
            if (messageLength < 1)
            {
                continue;
            }

            byte flag = this.cache[0];

            // Original note: conns start at 100; a leading 1, 2 or 3 marks a special (control) packet.
            uint remoteConn = 0;
            uint localConn = 0;

            try
            {
                KChannel kChannel = null;
                switch (flag)
                {
#if NOT_UNITY
                    case KcpProtocalType.SYN: // accept
                    {
                        // Layout: flag(1) remoteConn(4) localConn(4) [realAddress]; shorter than 9 bytes is not a SYN.
                        if (messageLength < 9)
                        {
                            break;
                        }

                        remoteConn = BitConverter.ToUInt32(this.cache, 1);
                        localConn = BitConverter.ToUInt32(this.cache, 5);
                        string realAddress = null;
                        if (messageLength > 9)
                        {
                            realAddress = this.cache.ToStr(9, messageLength - 9);
                        }

                        this.waitConnectChannels.TryGetValue(remoteConn, out kChannel);
                        if (kChannel == null)
                        {
                            localConn = CreateRandomLocalConn();
                            // The same localConn already exists: ignore this SYN and wait for the next one.
                            if (this.localConnChannels.ContainsKey(localConn))
                            {
                                break;
                            }
                            long id = this.CreateAcceptChannelId(localConn);
                            if (this.idChannels.ContainsKey(id))
                            {
                                break;
                            }

                            kChannel = new KChannel(id, localConn, remoteConn, this.socket, this.CloneAddress(), this);
                            this.idChannels.Add(kChannel.Id, kChannel);
                            this.waitConnectChannels.Add(kChannel.RemoteConn, kChannel); // removed once connected or timed out
                            this.localConnChannels.Add(kChannel.LocalConn, kChannel);

                            kChannel.RealAddress = realAddress;

                            IPEndPoint realEndPoint = kChannel.RealAddress == null? kChannel.RemoteAddress : NetworkHelper.ToIPEndPoint(kChannel.RealAddress);
                            this.OnAccept(kChannel.Id, realEndPoint);
                        }
                        if (kChannel.RemoteConn != remoteConn)
                        {
                            break;
                        }

                        // Skip a retried SYN whose address differs from the one recorded on the first SYN.
                        if (kChannel.RealAddress != realAddress)
                        {
                            Log.Error($"kchannel syn address diff: {kChannel.Id} {kChannel.RealAddress} {realAddress}");
                            break;
                        }

                        // Reply (or re-reply) with ACK: flag(1) localConn(4) remoteConn(4).
                        try
                        {
                            byte[] buffer = this.cache;
                            buffer.WriteTo(0, KcpProtocalType.ACK);
                            buffer.WriteTo(1, kChannel.LocalConn);
                            buffer.WriteTo(5, kChannel.RemoteConn);
                            Log.Info($"kservice syn: {kChannel.Id} {remoteConn} {localConn}");
                            this.socket.SendTo(buffer, 0, 9, SocketFlags.None, kChannel.RemoteAddress);
                        }
                        catch (Exception e)
                        {
                            Log.Error(e);
                            kChannel.OnError(ErrorCore.ERR_SocketCantSend);
                        }

                        break;
                    }
#endif
                    case KcpProtocalType.ACK: // reply to our connect request
                        // Exactly 9 bytes: flag(1) remoteConn(4) localConn(4).
                        if (messageLength != 9)
                        {
                            break;
                        }

                        remoteConn = BitConverter.ToUInt32(this.cache, 1);
                        localConn = BitConverter.ToUInt32(this.cache, 5);
                        kChannel = this.GetByLocalConn(localConn);
                        if (kChannel != null)
                        {
                            Log.Info($"kservice ack: {kChannel.Id} {remoteConn} {localConn}");
                            kChannel.RemoteConn = remoteConn;
                            kChannel.HandleConnnect();
                        }

                        break;
                    case KcpProtocalType.FIN: // disconnect
                        // Exactly 13 bytes: flag(1) remoteConn(4) localConn(4) error(4).
                        if (messageLength != 13)
                        {
                            break;
                        }

                        remoteConn = BitConverter.ToUInt32(this.cache, 1);
                        localConn = BitConverter.ToUInt32(this.cache, 5);
                        int error = BitConverter.ToInt32(this.cache, 9);

                        kChannel = this.GetByLocalConn(localConn);
                        if (kChannel == null)
                        {
                            break;
                        }

                        // Verify remoteConn so a third party cannot tear down the channel.
                        if (kChannel.RemoteConn != remoteConn)
                        {
                            break;
                        }

                        Log.Info($"kservice recv fin: {kChannel.Id} {localConn} {remoteConn} {error}");
                        kChannel.OnError(ErrorCore.ERR_PeerDisconnect);

                        break;
                    case KcpProtocalType.MSG: // KCP payload
                        // At least flag(1) remoteConn(4) localConn(4).
                        if (messageLength < 9)
                        {
                            break;
                        }

                        remoteConn = BitConverter.ToUInt32(this.cache, 1);
                        localConn = BitConverter.ToUInt32(this.cache, 5);
                        kChannel = this.GetByLocalConn(localConn);
                        if (kChannel == null)
                        {
                            // Unknown channel: tell the peer to disconnect.
                            this.Disconnect(localConn, remoteConn, ErrorCore.ERR_KcpNotFoundChannel, (IPEndPoint) this.ipEndPoint, 1);
                            break;
                        }

                        // Verify remoteConn so a third party cannot inject data.
                        if (kChannel.RemoteConn != remoteConn)
                        {
                            break;
                        }

                        // Hand KCP everything after the flag and remoteConn bytes.
                        kChannel.HandleRecv(this.cache, 5, messageLength - 5);
                        break;
                }
            }
            catch (Exception e)
            {
                Log.Error($"kservice error: {flag} {remoteConn} {localConn}\n{e}");
            }
        }
    }

    /// <summary>Returns the channel with the given id, or null.</summary>
    public KChannel Get(long id)
    {
        KChannel channel;
        this.idChannels.TryGetValue(id, out channel);
        return channel;
    }

    // Returns the channel with the given local conn, or null.
    private KChannel GetByLocalConn(uint localConn)
    {
        KChannel channel;
        this.localConnChannels.TryGetValue(localConn, out channel);
        return channel;
    }

    /// <summary>
    /// Creates an outgoing channel for <paramref name="id"/> unless it already exists. Errors are logged.
    /// </summary>
    protected override void Get(long id, IPEndPoint address)
    {
        if (this.idChannels.TryGetValue(id, out KChannel kChannel))
        {
            return;
        }

        try
        {
            // The low 32 bits of the id are the localConn.
            uint localConn = (uint) ((ulong) id & uint.MaxValue);
            kChannel = new KChannel(id, localConn, this.socket, address, this);
            this.idChannels.Add(id, kChannel);
            this.localConnChannels.Add(kChannel.LocalConn, kChannel);
        }
        catch (Exception e)
        {
            Log.Error($"kservice get error: {id}\n{e}");
        }
    }

    /// <summary>
    /// Unregisters channel <paramref name="id"/> from every lookup table and disposes it. No-op for unknown ids.
    /// </summary>
    public override void Remove(long id)
    {
        if (!this.idChannels.TryGetValue(id, out KChannel kChannel))
        {
            return;
        }
        Log.Info($"kservice remove channel: {id} {kChannel.LocalConn} {kChannel.RemoteConn}");
        this.idChannels.Remove(id);
        this.localConnChannels.Remove(kChannel.LocalConn);
        // Only drop the waiting entry if it belongs to this channel.
        if (this.waitConnectChannels.TryGetValue(kChannel.RemoteConn, out KChannel waitChannel))
        {
            if (waitChannel.LocalConn == kChannel.LocalConn)
            {
                this.waitConnectChannels.Remove(kChannel.RemoteConn);
            }
        }
        kChannel.Dispose();
    }

    // Sends a FIN packet (flag, localConn, remoteConn, error) to address `times` times. Send errors are logged.
    private void Disconnect(uint localConn, uint remoteConn, int error, IPEndPoint address, int times)
    {
        try
        {
            if (this.socket == null)
            {
                return;
            }

            byte[] buffer = this.cache;
            buffer.WriteTo(0, KcpProtocalType.FIN);
            buffer.WriteTo(1, localConn);
            buffer.WriteTo(5, remoteConn);
            buffer.WriteTo(9, (uint) error);
            for (int i = 0; i < times; ++i)
            {
                this.socket.SendTo(buffer, 0, 13, SocketFlags.None, address);
            }
        }
        catch (Exception e)
        {
            Log.Error($"Disconnect error {localConn} {remoteConn} {error} {address} {e}");
        }

        Log.Info($"channel send fin: {localConn} {remoteConn} {address} {error}");
    }

    protected override void Send(long channelId, long actorId, MemoryStream stream)
    {
        KChannel channel = this.Get(channelId);
        if (channel == null)
        {
            return;
        }
        channel.Send(actorId, stream);
    }

    /// <summary>
    /// Schedules channel <paramref name="id"/> for an update: in the next Update when <paramref name="time"/> is 0,
    /// otherwise once TimeNow reaches <paramref name="time"/>.
    /// </summary>
    public void AddToUpdateNextTime(long time, long id)
    {
        if (time == 0)
        {
            this.updateChannels.Add(id);
            return;
        }
        if (time < this.minTime)
        {
            this.minTime = time;
        }
        this.timeId.Add(time, id);
    }

    /// <summary>
    /// Per-frame driver: receives datagrams, moves due timers into updateChannels, updates those channels,
    /// then prunes waitConnectChannels.
    /// </summary>
    public override void Update()
    {
        this.Recv();

        this.TimerOut();

        // NOTE(review): adding a *new* id to updateChannels from inside kChannel.Update() would invalidate this
        // enumeration; confirm KChannel never does that.
        foreach (long id in updateChannels)
        {
            KChannel kChannel = this.Get(id);
            if (kChannel == null)
            {
                continue;
            }

            if (kChannel.Id == 0)
            {
                continue;
            }

            kChannel.Update();
        }

        this.updateChannels.Clear();

        this.RemoveConnectTimeoutChannels();
    }

    // Drops waitConnectChannels entries that are connected or older than the 10 * 1000 connect timeout.
    private void RemoveConnectTimeoutChannels()
    {
        this.waitRemoveChannels.Clear();

        foreach (KeyValuePair<long, KChannel> kv in this.waitConnectChannels)
        {
            KChannel kChannel = kv.Value;
            if (kChannel == null)
            {
                Log.Error($"RemoveConnectTimeoutChannels not found kchannel: {kv.Key}");
                continue;
            }

            // Connected channels leave immediately; unconnected ones leave after the timeout.
            // Each id is queued at most once.
            if (kChannel.IsConnected || this.TimeNow > kChannel.CreateTime + 10 * 1000)
            {
                this.waitRemoveChannels.Add(kv.Key);
            }
        }

        foreach (long channelId in this.waitRemoveChannels)
        {
            this.waitConnectChannels.Remove(channelId);
        }
    }

    // Moves every channel whose scheduled time has passed into updateChannels and refreshes minTime.
    private void TimerOut()
    {
        if (this.timeId.Count == 0)
        {
            return;
        }

        uint timeNow = this.TimeNow;

        if (timeNow < this.minTime)
        {
            return;
        }

        this.timeOutTime.Clear();

        // The early break assumes timeId enumerates keys in ascending order (as the original code did).
        // If no future key remains, minTime resets to "nothing scheduled" instead of staying stale.
        long nextMinTime = long.MaxValue;
        foreach (KeyValuePair<long, List<long>> kv in this.timeId)
        {
            long k = kv.Key;
            if (k > timeNow)
            {
                nextMinTime = k;
                break;
            }

            this.timeOutTime.Add(k);
        }
        this.minTime = nextMinTime;

        foreach (long k in this.timeOutTime)
        {
            foreach (long v in this.timeId[k])
            {
                this.updateChannels.Add(v);
            }

            this.timeId.Remove(k);
        }
    }
}
}