mirror of
https://github.com/Anuken/Mindustry.git
synced 2024-09-22 05:47:44 +03:00
408 lines
13 KiB
Java
408 lines
13 KiB
Java
package mindustry.net;
|
|
|
|
import arc.*;
|
|
import arc.func.*;
|
|
import arc.net.*;
|
|
import arc.net.FrameworkMessage.*;
|
|
import arc.struct.*;
|
|
import arc.util.*;
|
|
import arc.util.async.*;
|
|
import arc.util.pooling.*;
|
|
import mindustry.net.Net.*;
|
|
import mindustry.net.Packets.*;
|
|
|
|
import java.io.*;
|
|
import java.net.*;
|
|
import java.nio.*;
|
|
import java.nio.channels.*;
|
|
import java.util.concurrent.*;
|
|
|
|
import static mindustry.Vars.*;
|
|
|
|
public class ArcNetProvider implements NetProvider{
|
|
final Client client;
|
|
final Prov<DatagramPacket> packetSupplier = () -> new DatagramPacket(new byte[512], 512);
|
|
|
|
final Server server;
|
|
final CopyOnWriteArrayList<ArcConnection> connections = new CopyOnWriteArrayList<>();
|
|
Thread serverThread;
|
|
|
|
public ArcNetProvider(){
|
|
ArcNet.errorHandler = e -> Log.debug(Strings.getStackTrace(e));
|
|
|
|
client = new Client(8192, 8192, new PacketSerializer());
|
|
client.setDiscoveryPacket(packetSupplier);
|
|
client.addListener(new NetListener(){
|
|
@Override
|
|
public void connected(Connection connection){
|
|
Connect c = new Connect();
|
|
c.addressTCP = connection.getRemoteAddressTCP().getAddress().getHostAddress();
|
|
if(connection.getRemoteAddressTCP() != null) c.addressTCP = connection.getRemoteAddressTCP().toString();
|
|
|
|
Core.app.post(() -> net.handleClientReceived(c));
|
|
}
|
|
|
|
@Override
|
|
public void disconnected(Connection connection, DcReason reason){
|
|
if(connection.getLastProtocolError() != null){
|
|
netClient.setQuiet();
|
|
}
|
|
|
|
Disconnect c = new Disconnect();
|
|
c.reason = reason.toString();
|
|
Core.app.post(() -> net.handleClientReceived(c));
|
|
}
|
|
|
|
@Override
|
|
public void received(Connection connection, Object object){
|
|
if(object instanceof FrameworkMessage) return;
|
|
|
|
Core.app.post(() -> {
|
|
try{
|
|
net.handleClientReceived(object);
|
|
}catch(Throwable e){
|
|
net.handleException(e);
|
|
}
|
|
});
|
|
|
|
}
|
|
});
|
|
|
|
server = new Server(32768, 8192, new PacketSerializer());
|
|
server.setMulticast(multicastGroup, multicastPort);
|
|
server.setDiscoveryHandler((address, handler) -> {
|
|
ByteBuffer buffer = NetworkIO.writeServerData();
|
|
buffer.position(0);
|
|
handler.respond(buffer);
|
|
});
|
|
|
|
server.addListener(new NetListener(){
|
|
|
|
@Override
|
|
public void connected(Connection connection){
|
|
String ip = connection.getRemoteAddressTCP().getAddress().getHostAddress();
|
|
|
|
ArcConnection kn = new ArcConnection(ip, connection);
|
|
|
|
Connect c = new Connect();
|
|
c.addressTCP = ip;
|
|
|
|
Log.debug("&bReceived connection: @", c.addressTCP);
|
|
|
|
connections.add(kn);
|
|
Core.app.post(() -> net.handleServerReceived(kn, c));
|
|
}
|
|
|
|
@Override
|
|
public void disconnected(Connection connection, DcReason reason){
|
|
ArcConnection k = getByArcID(connection.getID());
|
|
if(k == null) return;
|
|
|
|
Disconnect c = new Disconnect();
|
|
c.reason = reason.toString();
|
|
|
|
Core.app.post(() -> {
|
|
net.handleServerReceived(k, c);
|
|
connections.remove(k);
|
|
});
|
|
}
|
|
|
|
@Override
|
|
public void received(Connection connection, Object object){
|
|
ArcConnection k = getByArcID(connection.getID());
|
|
if(object instanceof FrameworkMessage || k == null) return;
|
|
|
|
Core.app.post(() -> {
|
|
try{
|
|
net.handleServerReceived(k, object);
|
|
}catch(Throwable e){
|
|
e.printStackTrace();
|
|
}
|
|
});
|
|
}
|
|
});
|
|
}
|
|
|
|
private static boolean isLocal(InetAddress addr){
|
|
if(addr.isAnyLocalAddress() || addr.isLoopbackAddress()) return true;
|
|
|
|
try{
|
|
return NetworkInterface.getByInetAddress(addr) != null;
|
|
}catch(Exception e){
|
|
return false;
|
|
}
|
|
}
|
|
|
|
@Override
|
|
public void connectClient(String ip, int port, Runnable success){
|
|
Threads.daemon(() -> {
|
|
try{
|
|
//just in case
|
|
client.stop();
|
|
|
|
Threads.daemon("Net Client", () -> {
|
|
try{
|
|
client.run();
|
|
}catch(Exception e){
|
|
if(!(e instanceof ClosedSelectorException)) net.handleException(e);
|
|
}
|
|
});
|
|
|
|
client.connect(5000, ip, port, port);
|
|
success.run();
|
|
}catch(Exception e){
|
|
net.handleException(e);
|
|
}
|
|
});
|
|
}
|
|
|
|
@Override
|
|
public void disconnectClient(){
|
|
client.close();
|
|
}
|
|
|
|
@Override
|
|
public void sendClient(Object object, SendMode mode){
|
|
try{
|
|
if(mode == SendMode.tcp){
|
|
client.sendTCP(object);
|
|
}else{
|
|
client.sendUDP(object);
|
|
}
|
|
//sending things can cause an under/overflow, catch it and disconnect instead of crashing
|
|
}catch(BufferOverflowException | BufferUnderflowException e){
|
|
net.showError(e);
|
|
}
|
|
|
|
Pools.free(object);
|
|
}
|
|
|
|
@Override
|
|
public void pingHost(String address, int port, Cons<Host> valid, Cons<Exception> invalid){
|
|
Threads.daemon(() -> {
|
|
try{
|
|
DatagramSocket socket = new DatagramSocket();
|
|
long time = Time.millis();
|
|
socket.send(new DatagramPacket(new byte[]{-2, 1}, 2, InetAddress.getByName(address), port));
|
|
socket.setSoTimeout(2000);
|
|
|
|
DatagramPacket packet = packetSupplier.get();
|
|
socket.receive(packet);
|
|
|
|
ByteBuffer buffer = ByteBuffer.wrap(packet.getData());
|
|
Host host = NetworkIO.readServerData((int)Time.timeSinceMillis(time), packet.getAddress().getHostAddress(), buffer);
|
|
|
|
Core.app.post(() -> valid.get(host));
|
|
}catch(Exception e){
|
|
Core.app.post(() -> invalid.get(e));
|
|
}
|
|
});
|
|
}
|
|
|
|
@Override
|
|
public void discoverServers(Cons<Host> callback, Runnable done){
|
|
Seq<InetAddress> foundAddresses = new Seq<>();
|
|
long time = Time.millis();
|
|
client.discoverHosts(port, multicastGroup, multicastPort, 3000, packet -> {
|
|
Core.app.post(() -> {
|
|
try{
|
|
if(foundAddresses.contains(address -> address.equals(packet.getAddress()) || (isLocal(address) && isLocal(packet.getAddress())))){
|
|
return;
|
|
}
|
|
ByteBuffer buffer = ByteBuffer.wrap(packet.getData());
|
|
Host host = NetworkIO.readServerData((int)Time.timeSinceMillis(time), packet.getAddress().getHostAddress(), buffer);
|
|
callback.get(host);
|
|
foundAddresses.add(packet.getAddress());
|
|
}catch(Exception e){
|
|
//don't crash when there's an error pinging a a server or parsing data
|
|
e.printStackTrace();
|
|
}
|
|
});
|
|
}, () -> Core.app.post(done));
|
|
}
|
|
|
|
@Override
|
|
public void dispose(){
|
|
disconnectClient();
|
|
closeServer();
|
|
try{
|
|
client.dispose();
|
|
}catch(IOException ignored){
|
|
}
|
|
}
|
|
|
|
@Override
|
|
public Iterable<ArcConnection> getConnections(){
|
|
return connections;
|
|
}
|
|
|
|
@Override
|
|
public void hostServer(int port) throws IOException{
|
|
connections.clear();
|
|
server.bind(port, port);
|
|
|
|
serverThread = new Thread(() -> {
|
|
try{
|
|
server.run();
|
|
}catch(Throwable e){
|
|
if(!(e instanceof ClosedSelectorException)) Threads.throwAppException(e);
|
|
}
|
|
}, "Net Server");
|
|
serverThread.setDaemon(true);
|
|
serverThread.start();
|
|
}
|
|
|
|
@Override
|
|
public void closeServer(){
|
|
connections.clear();
|
|
Threads.daemon(server::stop);
|
|
}
|
|
|
|
ArcConnection getByArcID(int id){
|
|
for(int i = 0; i < connections.size(); i++){
|
|
ArcConnection con = connections.get(i);
|
|
if(con.connection != null && con.connection.getID() == id){
|
|
return con;
|
|
}
|
|
}
|
|
|
|
return null;
|
|
}
|
|
|
|
class ArcConnection extends NetConnection{
|
|
public final Connection connection;
|
|
|
|
public ArcConnection(String address, Connection connection){
|
|
super(address);
|
|
this.connection = connection;
|
|
}
|
|
|
|
@Override
|
|
public boolean isConnected(){
|
|
return connection.isConnected();
|
|
}
|
|
|
|
@Override
|
|
public void sendStream(Streamable stream){
|
|
connection.addListener(new InputStreamSender(stream.stream, 512){
|
|
int id;
|
|
|
|
@Override
|
|
protected void start(){
|
|
//send an object so the receiving side knows how to handle the following chunks
|
|
StreamBegin begin = new StreamBegin();
|
|
begin.total = stream.stream.available();
|
|
begin.type = Registrator.getID(stream.getClass());
|
|
connection.sendTCP(begin);
|
|
id = begin.id;
|
|
}
|
|
|
|
@Override
|
|
protected Object next(byte[] bytes){
|
|
StreamChunk chunk = new StreamChunk();
|
|
chunk.id = id;
|
|
chunk.data = bytes;
|
|
return chunk; //wrap the byte[] with an object so the receiving side knows how to handle it.
|
|
}
|
|
});
|
|
}
|
|
|
|
@Override
|
|
public void send(Object object, SendMode mode){
|
|
try{
|
|
if(mode == SendMode.tcp){
|
|
connection.sendTCP(object);
|
|
}else{
|
|
connection.sendUDP(object);
|
|
}
|
|
}catch(Exception e){
|
|
Log.err(e);
|
|
Log.info("Error sending packet. Disconnecting invalid client!");
|
|
connection.close(DcReason.error);
|
|
|
|
ArcConnection k = getByArcID(connection.getID());
|
|
if(k != null) connections.remove(k);
|
|
}
|
|
}
|
|
|
|
@Override
|
|
public void close(){
|
|
if(connection.isConnected()) connection.close(DcReason.closed);
|
|
}
|
|
}
|
|
|
|
@SuppressWarnings("unchecked")
|
|
public static class PacketSerializer implements NetSerializer{
|
|
|
|
@Override
|
|
public Object read(ByteBuffer byteBuffer){
|
|
byte id = byteBuffer.get();
|
|
if(id == -2){
|
|
return readFramework(byteBuffer);
|
|
}else{
|
|
Packet packet = Pools.obtain((Class<Packet>)Registrator.getByID(id).type, (Prov<Packet>)Registrator.getByID(id).constructor);
|
|
packet.read(byteBuffer);
|
|
return packet;
|
|
}
|
|
}
|
|
|
|
@Override
|
|
public void write(ByteBuffer byteBuffer, Object o){
|
|
if(o instanceof FrameworkMessage){
|
|
byteBuffer.put((byte)-2); //code for framework message
|
|
writeFramework(byteBuffer, (FrameworkMessage)o);
|
|
}else{
|
|
if(!(o instanceof Packet)) throw new RuntimeException("All sent objects must implement be Packets! Class: " + o.getClass());
|
|
byte id = Registrator.getID(o.getClass());
|
|
if(id == -1) throw new RuntimeException("Unregistered class: " + o.getClass());
|
|
byteBuffer.put(id);
|
|
((Packet)o).write(byteBuffer);
|
|
}
|
|
}
|
|
|
|
public void writeFramework(ByteBuffer buffer, FrameworkMessage message){
|
|
if(message instanceof Ping p){
|
|
buffer.put((byte)0);
|
|
buffer.putInt(p.id);
|
|
buffer.put(p.isReply ? 1 : (byte)0);
|
|
}else if(message instanceof DiscoverHost){
|
|
buffer.put((byte)1);
|
|
}else if(message instanceof KeepAlive){
|
|
buffer.put((byte)2);
|
|
}else if(message instanceof RegisterUDP p){
|
|
buffer.put((byte)3);
|
|
buffer.putInt(p.connectionID);
|
|
}else if(message instanceof RegisterTCP p){
|
|
buffer.put((byte)4);
|
|
buffer.putInt(p.connectionID);
|
|
}
|
|
}
|
|
|
|
public FrameworkMessage readFramework(ByteBuffer buffer){
|
|
byte id = buffer.get();
|
|
|
|
if(id == 0){
|
|
Ping p = new Ping();
|
|
p.id = buffer.getInt();
|
|
p.isReply = buffer.get() == 1;
|
|
return p;
|
|
}else if(id == 1){
|
|
return new DiscoverHost();
|
|
}else if(id == 2){
|
|
return new KeepAlive();
|
|
}else if(id == 3){
|
|
RegisterUDP p = new RegisterUDP();
|
|
p.connectionID = buffer.getInt();
|
|
return p;
|
|
}else if(id == 4){
|
|
RegisterTCP p = new RegisterTCP();
|
|
p.connectionID = buffer.getInt();
|
|
return p;
|
|
}else{
|
|
throw new RuntimeException("Unknown framework message!");
|
|
}
|
|
}
|
|
}
|
|
|
|
}
|