在 Go 语言中使用 Goroutine + Channel 实现端口扫描器的代码示例
端口扫描器，顾名思义，就是用来批量扫描服务器端口的工具。
使用 Goroutine + Channel 的方式实现，示例代码如下：
package main
import (
"fmt"
"net"
"sync"
"time"
"unsafe"
)
// main scans the full TCP port range (1-65535) of the loopback address
// using the WaitGroup-based worker pool.
func main() {
	const (
		host     = "127.0.0.1"
		lowPort  = 1
		highPort = 65535
	)
	tcpScanByGoroutineWithChannel(host, lowPort, highPort)
}
// handleWorker is one scanning worker. It receives port numbers from ports
// until that channel is closed. For each port it attempts a TCP connection to
// ip:port, prints whether the port is open or closed, and calls wg.Done
// exactly once.
func handleWorker(ip string, ports chan int, wg *sync.WaitGroup) {
	// Bound each connection attempt: without a timeout, a filtered port can
	// block this worker for the operating system's connect timeout.
	const dialTimeout = 3 * time.Second
	for p := range ports {
		// JoinHostPort brackets IPv6 literals ("[::1]:80"). Plain "%s:%d"
		// formatting would produce an address net.Dial cannot parse.
		address := net.JoinHostPort(ip, fmt.Sprint(p))
		conn, err := net.DialTimeout("tcp", address, dialTimeout)
		if err != nil {
			fmt.Printf("[info] %s Close \n", address)
		} else {
			conn.Close()
			fmt.Printf("[info] %s Open \n", address)
		}
		// One Done per received port, matching the caller's wg.Add(1) per send.
		wg.Done()
	}
}
// tcpScanByGoroutineWithChannel scans TCP ports portStart..portEnd (inclusive)
// on ip. A fixed pool of handleWorker goroutines is fed through a buffered
// channel; each worker prints its port's state. The function finishes by
// printing the total elapsed time. If the parameters fail validation it
// prints "[Exit]" and returns without scanning.
func tcpScanByGoroutineWithChannel(ip string, portStart int, portEnd int) {
	start := time.Now()
	// Parameter validation: stop here rather than probing invalid ports.
	if !verifyParam(ip, portStart, portEnd) {
		fmt.Printf("[Exit]\n")
		return
	}

	// Worker count; also used as the channel buffer size.
	const workers = 100
	ports := make(chan int, workers)
	var wg sync.WaitGroup
	for i := 0; i < workers; i++ {
		go handleWorker(ip, ports, &wg)
	}

	for p := portStart; p <= portEnd; p++ {
		// Add before sending so Wait never sees a zero counter while a port
		// is still queued or being probed.
		wg.Add(1)
		ports <- p
	}
	// This loop is the only sender; closing lets the workers' range loops end.
	close(ports)
	wg.Wait()

	cost := time.Since(start)
	fmt.Printf("[tcpScanByGoroutineWithChannel] cost %s second \n", cost)
}
// verifyParam reports whether ip parses as an IP address and
// portStart..portEnd is a non-empty range within 1-65535. On failure it
// prints an "[Error]" line for the first failed check. On success it prints
// "[Info]" lines describing the accepted parameters.
func verifyParam(ip string, portStart int, portEnd int) bool {
	addr := net.ParseIP(ip)
	if addr == nil {
		fmt.Println("[Error] ip type is must net.ip")
		return false
	}
	// unsafe.Sizeof reports the size of the net.IP slice header, not the
	// length of the address bytes.
	fmt.Printf("[Info] ip=%s | ip type is: %T | ip size is: %d \n", addr, addr, unsafe.Sizeof(addr))
	if portStart < 1 || portEnd > 65535 {
		fmt.Println("[Error] port is must in the range of 1~65535")
		return false
	}
	// Combined with the bounds above, this also rejects portStart > 65535
	// and portEnd < 1, which previously slipped through.
	if portStart > portEnd {
		fmt.Println("[Error] port start must not be greater than port end")
		return false
	}
	fmt.Printf("[Info] port start:%d end:%d \n", portStart, portEnd)
	return true
}
进一步优化：引入两个 Channel（一个用于分发待扫描的端口，一个用于收集扫描结果），并对结果排序后输出，示例代码如下：
package main
import (
"fmt"
"net"
"sort"
"time"
"unsafe"
)
// main scans the full TCP port range (1-65535) of the loopback address
// using the two-channel pipeline and prints the open ports in order.
func main() {
	const (
		host     = "127.0.0.1"
		lowPort  = 1
		highPort = 65535
	)
	tcpScanByGoroutineWithChannelAndSort(host, lowPort, highPort)
}
// handleWorker is one scanning worker. It receives port numbers from ports
// until that channel is closed and, for each port p, attempts a TCP
// connection to ip:p. It sends p on results if the connection succeeds and
// -p if it fails, so the caller distinguishes open from closed ports by
// sign. This relies on p >= 1, since port 0 would be ambiguous.
func handleWorker(ip string, ports chan int, results chan int) {
	// Bound each connection attempt: without a timeout, a filtered port can
	// block this worker for the operating system's connect timeout.
	const dialTimeout = 3 * time.Second
	for p := range ports {
		// JoinHostPort brackets IPv6 literals ("[::1]:80"). Plain "%s:%d"
		// formatting would produce an address net.Dial cannot parse.
		address := net.JoinHostPort(ip, fmt.Sprint(p))
		conn, err := net.DialTimeout("tcp", address, dialTimeout)
		if err != nil {
			// fmt.Printf("[debug] ip %s Close \n", address)
			results <- -p
			continue
		}
		// fmt.Printf("[debug] ip %s Open \n", address)
		conn.Close()
		results <- p
	}
}
// tcpScanByGoroutineWithChannelAndSort scans TCP ports portStart..portEnd
// (inclusive) on ip using three roles connected by two channels:
//   - a producer goroutine that feeds port numbers into ports,
//   - a fixed pool of handleWorker goroutines that probe each port and report
//     +port (open) or -port (closed) on results,
//   - this goroutine, which collects exactly one result per port.
//
// Open ports are printed in ascending order, followed by the elapsed time.
// If the parameters fail validation it prints "[Exit]" and returns without
// scanning.
func tcpScanByGoroutineWithChannelAndSort(ip string, portStart int, portEnd int) {
	start := time.Now()
	// Parameter validation: stop here rather than probing invalid ports.
	if !verifyParam(ip, portStart, portEnd) {
		fmt.Printf("[Exit]\n")
		return
	}

	// Worker count; also used as the buffer size of the task channel.
	const workers = 50
	ports := make(chan int, workers)
	results := make(chan int)
	var openSlice []int
	var closeSlice []int

	// Task producer: the sole sender on ports, so it closes the channel once
	// every port is queued, letting the workers' range loops terminate.
	go func(first, last int) {
		defer close(ports)
		for p := first; p <= last; p++ {
			ports <- p
		}
	}(portStart, portEnd)

	// Task consumers / result producers: a fixed pool of workers (not one
	// goroutine per port), each writing one result per port to results.
	for i := 0; i < workers; i++ {
		go handleWorker(ip, ports, results)
	}

	// Result consumer: every port yields exactly one result, so receive
	// exactly that many.
	for i := portStart; i <= portEnd; i++ {
		if res := <-results; res > 0 {
			openSlice = append(openSlice, res)
		} else {
			closeSlice = append(closeSlice, -res)
		}
	}
	// results is deliberately not closed here: the workers are its senders,
	// and nothing receives from it after this point.

	sort.Ints(openSlice)
	sort.Ints(closeSlice)

	for _, p := range openSlice {
		fmt.Printf("[info] %s:%-8d Open\n", ip, p)
	}
	// for _, p := range closeSlice {
	// 	fmt.Printf("[info] %s:%-8d Close\n", ip, p)
	// }
	cost := time.Since(start)
	fmt.Printf("[tcpScanByGoroutineWithChannelAndSort] cost %s second \n", cost)
}
// verifyParam reports whether ip parses as an IP address and
// portStart..portEnd is a non-empty range within 1-65535. On failure it
// prints an "[Error]" line for the first failed check. On success it prints
// "[Info]" lines describing the accepted parameters.
func verifyParam(ip string, portStart int, portEnd int) bool {
	addr := net.ParseIP(ip)
	if addr == nil {
		fmt.Println("[Error] ip type is must net.ip")
		return false
	}
	// unsafe.Sizeof reports the size of the net.IP slice header, not the
	// length of the address bytes.
	fmt.Printf("[Info] ip=%s | ip type is: %T | ip size is: %d \n", addr, addr, unsafe.Sizeof(addr))
	if portStart < 1 || portEnd > 65535 {
		fmt.Println("[Error] port is must in the range of 1~65535")
		return false
	}
	// Combined with the bounds above, this also rejects portStart > 65535
	// and portEnd < 1, which previously slipped through.
	if portStart > portEnd {
		fmt.Println("[Error] port start must not be greater than port end")
		return false
	}
	fmt.Printf("[Info] port start:%d end:%d \n", portStart, portEnd)
	return true
}
相关文章