java socket bio 改造为 netty nio

打印 上一主题 下一主题

主题 1812|帖子 1812|积分 5438

马上注册,结交更多好友,享用更多功能,让你轻松玩转社区。

您需要 登录 才可以下载或查看,没有账号?立即注册

x
        公司早些时候接入一款康健监测设备,由于业务原因克日把端口袒露在公网后,每当被恶意连接时系统会创建大量线程,在排盘问题是发现是利用了厂家提供的服务端demo代码,在代码中利用的是java 原生socket,在发现连接后利用独立线程处理后续通讯,占用系统资源造成了服务宕机,因此必要进行改造。
        厂家提供的demo代码如下:
  1. import java.io.IOException;
  2. import java.net.ServerSocket;
  3. import java.net.Socket;
  4. import java.util.ArrayList;
  5. import java.util.List;
  6. public class Demo {
  7.     public static void main(String[] args) {
  8.         int port = 8003;
  9.         if (args.length == 1) {
  10.             port = Integer.parseInt(args[0]);
  11.         }
  12.         ServerSocket ss;
  13.         try {
  14.             ss = new ServerSocket(port);
  15.         }
  16.         catch (Exception e) {
  17.             System.out.println("服务端socket失败 port = " + port);
  18.             return;
  19.         }
  20.         System.out.println("启动socket监听 端口:" + port);
  21.         List<Socket> socketList = new ArrayList<>();
  22.         while (true) {
  23.             try {
  24.                 Socket socket = ss.accept();
  25.                 if (socket == null || socket.isClosed()) {
  26.                     socketList.remove(socket);
  27.                     continue;
  28.                 }
  29.                 if (socketList.contains(socket)) {
  30.                     continue;
  31.                 }
  32.                 socketList.add(socket);
  33.                 System.out.println("socket连接 address = " + socket.getInetAddress().toString() + " port = " + socket.getPort());
  34.                 new Thread(new HealthReadThread(socket)).start();
  35.             }
  36.             catch (IOException e) {
  37.                 System.out.println(e.getMessage());
  38.             }
  39.         }
  40.     }
  41. }
复制代码
  1. import java.io.*;
  2. import java.net.Socket;
  3. import java.util.ArrayList;
  4. import java.util.HashMap;
  5. import java.util.List;
  6. public class HealthReadThread implements Runnable {
  7.     private Socket socket;
  8.     HealthReadThread(Socket socket) {
  9.         this.socket = socket;
  10.     }
  11.     private static String message = "";
  12.     @Override
  13.     public void run() {
  14.         try {
  15.             //输入
  16.             InputStream inPutStream = socket.getInputStream();
  17.             BufferedInputStream bis = new BufferedInputStream(inPutStream);
  18. //            BufferedReader br = new BufferedReader(new InputStreamReader(inPutStream));
  19.             //输出
  20.             OutputStream outputStream = socket.getOutputStream();
  21.             BufferedOutputStream bw = new BufferedOutputStream(outputStream);
  22.             String ip = socket.getInetAddress().getHostAddress();
  23.             int port = socket.getPort();
  24.             String readStr = "";
  25. //            char[] buf;
  26.             byte[] buf;
  27.             int readLen = 0;
  28.             while (true) {
  29.                 if (socket.isClosed()) {
  30.                     break;
  31.                 }
  32.                 buf = new byte[1024];
  33.                 try {
  34.                     readLen = bis.read(buf);
  35.                     if (readLen <= 0) {
  36. //                        System.out.println(Thread.currentThread().getId() + "线程: " + "ip地址:" + ip + " 端口地址:" + port + "暂无接收数据");
  37.                         continue;
  38.                     }
  39.                     System.out.println(Thread.currentThread().getId() + "线程: " + "ip地址:" + ip + " 端口地址:" + port + " 接收到原始命令长度:" + readLen);
  40.                     readStr = StringUtils.byteToHexString(buf, readLen);
  41. //                    readStr = new String(buf ,0 , readLen);
  42.                 } catch (IOException e) {
  43.                     System.out.println(e.getMessage());
  44.                     socket.close();
  45. //                    continue;
  46.                 }
  47.                 if (readStr == null || "".equals(readStr)) {
  48.                     continue;
  49.                 }
  50.                 // 省略业务代码
  51.             }
  52.         }
  53.         catch (Exception e) {
  54.             System.out.println(e.getMessage());
  55.         }
  56.     }
  57. }
复制代码
利用netty进行改造:
  1. import io.netty.bootstrap.ServerBootstrap;
  2. import io.netty.channel.ChannelFuture;
  3. import io.netty.channel.ChannelInitializer;
  4. import io.netty.channel.EventLoopGroup;
  5. import io.netty.channel.nio.NioEventLoopGroup;
  6. import io.netty.channel.socket.SocketChannel;
  7. import io.netty.channel.socket.nio.NioServerSocketChannel;
  8. import lombok.extern.slf4j.Slf4j;
  9. import org.springframework.boot.ApplicationArguments;
  10. import org.springframework.boot.ApplicationRunner;
  11. import org.springframework.stereotype.Component;
  12. @Slf4j
  13. @Component
  14. public class DeviceNettyServer implements ApplicationRunner {
  15.     @Override
  16.     public void run(ApplicationArguments args) throws Exception {
  17.         start();
  18.     }
  19.     public void start() {
  20.         Thread thread = new Thread(() -> {
  21.             // 配置服务端的NIO线程组
  22.             EventLoopGroup bossGroup = new NioEventLoopGroup(1);
  23.             EventLoopGroup workerGroup = new NioEventLoopGroup(4);
  24.             ServerBootstrap b = new ServerBootstrap();
  25.             b.group(bossGroup, workerGroup)
  26.                     // 使用 NIO 方式进行网络通信
  27.                     .channel(NioServerSocketChannel.class)
  28.                     .childHandler(new ChannelInitializer<SocketChannel>() {
  29.                         @Override
  30.                         public void initChannel(SocketChannel ch) throws Exception {
  31.                             // 添加自己的处理器
  32.                             ch.pipeline().addLast(new DeviceMsgHandler());
  33.                         }
  34.                     });
  35.             try {
  36.                 int port1 = 8081;
  37.                 int port2 = 8082;
  38.                 // 绑定一个端口并且同步,生成一个ChannelFuture对象
  39.                 ChannelFuture f1 = b.bind(port1).sync();
  40.                 ChannelFuture f2 = b.bind(port2).sync();
  41.                 log.info("启动监听, 端口:" + port1 + "、" + port2);
  42.                 // 对关闭通道进行监听
  43.                 f1.channel().closeFuture().sync();
  44.                 f2.channel().closeFuture().sync();
  45.             } catch (Exception e) {
  46.                 log.error("启动监听失败", e);
  47.             } finally {
  48.                 workerGroup.shutdownGracefully();
  49.                 bossGroup.shutdownGracefully();
  50.             }
  51.         });
  52.         thread.setName("DeviceNettyServer");
  53.         thread.start();
  54.     }
  55. }
复制代码
  1. import com.alibaba.fastjson.JSON;
  2. import com.alibaba.fastjson.JSONObject;
  3. import io.netty.buffer.ByteBuf;
  4. import io.netty.buffer.Unpooled;
  5. import io.netty.channel.Channel;
  6. import io.netty.channel.ChannelHandlerContext;
  7. import io.netty.channel.SimpleChannelInboundHandler;
  8. import lombok.extern.slf4j.Slf4j;
  9. import org.apache.commons.collections.CollectionUtils;
  10. import org.apache.commons.lang3.StringUtils;
  11. import java.util.*;
  12. import java.util.concurrent.ConcurrentHashMap;
  13. @Slf4j
  14. public class DeviceMsgHandler extends SimpleChannelInboundHandler<ByteBuf> {
  15.     /**
  16.      * 已连接的设备
  17.      */
  18.     private static final ConcurrentHashMap<Channel, DeviceDTO> CONNECTION_DEVICE_MAP = new ConcurrentHashMap<>(8);
  19.     /**
  20.      * 一旦连接,第一个被执行
  21.      */
  22.     @Override
  23.     public void handlerAdded(ChannelHandlerContext ctx) {
  24.         String remoteAddress = ctx.channel().remoteAddress().toString();
  25.         log.info("发现连接, remoteAddress: " + remoteAddress);
  26.         // 发送查询设备信息指令
  27.         sendQuery(ctx.channel());
  28.     }
  29.     /**
  30.      * 读取数据
  31.      */
  32.     @Override
  33.     protected void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) {
  34.         byte[] bytes = new byte[msg.readableBytes()];
  35.         msg.readBytes(bytes);
  36.         // 忽略业务处理代码
  37.         // 传递给下一个处理器
  38.         ctx.fireChannelRead(msg);
  39.     }
  40.     /**
  41.      * 连接断开
  42.      *
  43.      * @param ctx
  44.      */
  45.     @Override
  46.     public void handlerRemoved(ChannelHandlerContext ctx) {
  47.         log.info("连接断开, remoteAddress: " + ctx.channel().remoteAddress());
  48.         CONNECTION_DEVICE_MAP.remove(ctx.channel());
  49.     }
  50.     /**
  51.      * 连接异常
  52.      *
  53.      * @param ctx
  54.      * @param cause
  55.      */
  56.     @Override
  57.     public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
  58.         log.info("连接异常, remoteAddress: " + ctx.channel().remoteAddress());
  59.         CONNECTION_DEVICE_MAP.remove(ctx.channel());
  60.       
  61.     }
复制代码
经过改造后利用了4个worker线程进行读写,消除了原先恶意连接造成线程数无线扩大的问题,利用nio也极大的提高了系统资源利用率。

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

tsx81428

论坛元老
这个人很懒什么都没写!
快速回复 返回顶部 返回列表