在go语言中模仿NumPy的方式创建矩阵的代码示例

2023-06-01 00:00:00 示例 矩阵 模仿

你已经习惯了NumPy创建矩阵的方式,你很难适应Gonum的方法。Golang中是否有任何函数可以提供类似于NumPy从输入字符串创建矩阵的方法?比如说:

  >>> import numpy as np
  >>> A = np.mat('[1 2;3 4]')
  >>> A
  matrix([[1, 2],
          [3, 4]])

解决方案:

我创建了这个函数,作为我自己使用的一个快速和肮脏的方法,可能对你也有用。

但是,请记住,这是一个原始的Golang函数,没有对输入的字符串进行错误检查。


示例代码

package main

  import (
          "fmt"
          "github.com/gonum/matrix/mat64"
          "strconv"
          "strings"
  )
 
  // 警告:此函数不对输入字符串进行错误检查。
  // 你可能需要修改这个函数,以便对诸如以下内容进行合理性检查。
  // 如只有1对[ ]的情况
  // 所有的列和行都有数字或0
  // 所有的行都是一样大的 [pynum会抛出 "ValueError: Rows not the same size."]
  func matrix(str string) *mat64.Dense {
          // remove [ and ]
          str = strings.Replace(str, "[", "", -1)
          str = strings.Replace(str, "]", "", -1)
          // calculate total number of rows
          parts := strings.SplitN(str, ";", -1)
          rows := len(parts)
          // calculate total number of columns
          colSlice := strings.Fields(parts[0])
          columns := len(colSlice)
          // replace all ; with space
          str = strings.Replace(str, ";", " ", -1)
         
          // 将str转换为slice
              // 取自于 在go语言中将字符串转换为数组/片断(array/slice)
              // https://www.zongscan.com/demo333/96243.html
          elements := strings.Fields(str)
         
          //fmt.Println("Rows : ", rows)
          //fmt.Println("Columns : ", columns)
         
          //为新矩阵填充数据(密集型)
          data := make([]float64, rows*columns)
          for i := range data {
                  floatValue, _ := strconv.ParseFloat(elements[i], 64)
                  data[i] = floatValue
          }
         
          M := mat64.NewDense(rows, columns, data)
          return M
  }
 
  func main() {
          str := "[1 2 3 4 5 6 7 8;9 10 11 12 13 14 15 16;8 7 6 5 4 3 2 1]"
          m := matrix(str)
          // 打印所有m元素
          fmt.Printf("m :\n%v\n\n", mat64.Formatted(m, mat64.Prefix(""), mat64.Excerpt(0)))
             
              //不检查非浮点或整数的情况
              //将用0替换字母表
          m1 := matrix("[1 x 3; 4 5 a]")
         
          //打印所有M1元素
          fmt.Printf("m1 :\n%v\n\n", mat64.Formatted(m1, mat64.Prefix(""), mat64.Excerpt(0)))
          m2 := matrix("[1.1 2 3.4;4 5 68.8]")
         
          // 打印所有m2元素
          fmt.Printf("m2 :\n%v\n\n", mat64.Formatted(m2, mat64.Prefix(""), mat64.Excerpt(0)))
  }



输出:

m :
「1 2 3 4 5 6 7 8」
「9 10 11 12 13 14 15 16」
「8 7 6 5 4 3 2 1」
m1 :
「1 0 3」
「4 5 0」
m2 :
「 1.1 2 3.4」
「 4 5 68.8」

相关文章