是否有可能找到Spacy POS标签的不确定性?

2022-05-15 00:00:00 python nlp spacy spell-checking

问题描述

我正在尝试构建一个非英语拼写检查器,它依赖于按拼写对句子进行分类,这允许我的算法然后使用词性标签和单个标记的语法依赖来确定拼写错误(在我的情况下,更具体地说:荷兰语复合词的错误拆分)。

然而,如果句子包含语法错误,例如将名词归类为动词,即使分类的单词看起来甚至不像动词,Spacy似乎也会错误地对句子进行分类。

正因为如此,我想知道是否有可能获得Spacy分类的不确定性,从而有可能判断Spacy是否正在为句子而苦苦挣扎。毕竟,如果Spacy正在为分类而苦苦挣扎,这将使我的拼写检查器更有信心地确定句子中包含错误。

有没有办法知道Spacy认为一个句子在语法上是正确的(而不必指定我的语言中所有正确句子结构的模式),或者获得分类的确定性?


根据@Sergey Bushmanov评论中的建议进行编辑:

我找到了https://spacy.io/api/tagger#predict,这可能有助于获取标记的概率。然而,我真的不确定我看到的是什么,我也没有真正理解文档中关于输出的含义。我使用以下代码:

import spacy

nlp = spacy.load('en_core_web_sm')
text = "This is an example sentence for the Spacy tagger."
doc = nlp(text)

docs = nlp(text, disable=['tagger'])
scores, tensors = nlp.tagger.predict([docs])

print(scores)
probs = tensors[0]
for p in probs:
    print(p, max(p), p.tolist().index(max(p)))

我猜这打印的是预测的一些整数表示(考虑到‘整数’和‘表示’得到相同的分数),然后句子中每个单词的96个浮点数的数组。它还列出了最高分和最高分的位置,但似乎对于大多数单词来说,p数组中有多个项获得了相似的值。现在我想知道这些数组的含义,以及如何从中提取每个分类的概率。


问题是:我如何解释此输出以获取Spacy的标记器找到的特定标记的特定概率?或者,提出同样的问题的另一种方式是:上述代码生成的输出意味着什么?


解决方案

>>> nlp = spacy.load("en_core_web_sm")
>>> tagger = nlp.get_pipe("tagger")
>>> doc = nlp("Turn left")
>>> tagger.model.predict([doc])[0][1]
array([2.4706091e-07, 9.5889463e-06, 7.8214543e-07, 1.0063847e-06,
       1.4711081e-07, 8.9995199e-05, 1.3229882e-05, 1.7524673e-07,
       1.8464769e-05, 2.4248957e-06, 1.2176755e-06, 3.3774859e-07,
       1.3199920e-06, 1.2011193e-06, 9.4455345e-06, 2.1991875e-05,
       1.6732251e-02, 1.3964747e-07, 2.0764594e-07, 7.0467541e-07,
       1.4303426e-07, 3.7962508e-07, 1.2130551e-03, 3.1479198e-07,
       4.8646534e-08, 6.1310317e-07, 1.0607551e-05, 3.7493783e-06,
       2.7809198e-08, 1.2118652e-05, 9.9081490e-03, 1.8219554e-06,
       4.7322575e-07, 1.8754436e-05, 6.2416703e-08, 9.5453437e-08,
       1.8937490e-05, 6.3916352e-03, 3.7999314e-01, 1.5741379e-03,
       5.8360571e-01, 9.6441705e-05, 1.7456010e-04, 5.1820080e-06,
       1.2672864e-06, 9.7453121e-06, 2.4000105e-05, 5.1192428e-06,
       2.4821245e-05], dtype=float32)
>>> r = [*enumerate(tagger.model.predict([doc])[0][1])]
>>> r.sort(key=lambda x: x[1])
>>> r
[(28, 2.7809198e-08), (24, 4.8646534e-08), (34, 6.24167e-08), (35, 9.545344e-08), (17, 1.3964747e-07), (20, 1.4303426e-07), (4, 1.4711081e-07), (7, 1.7524673e-07), (18, 2.0764594e-07), (0, 2.470609e-07), (23, 3.1479198e-07), (11, 3.377486e-07), (21, 3.7962508e-07), (32, 4.7322575e-07), (25, 6.1310317e-07), (19, 7.046754e-07), (2, 7.8214543e-07), (3, 1.0063847e-06), (13, 1.2011193e-06), (10, 1.2176755e-06), (44, 1.2672864e-06), (12, 1.319992e-06), (31, 1.8219554e-06), (9, 2.4248957e-06), (27, 3.7493783e-06), (47, 5.119243e-06), (43, 5.182008e-06), (14, 9.4455345e-06), (1, 9.588946e-06), (45, 9.745312e-06), (26, 1.0607551e-05), (29, 1.2118652e-05), (6, 1.3229882e-05), (8, 1.8464769e-05), (33, 1.8754436e-05), (36, 1.893749e-05), (15, 2.1991875e-05), (46, 2.4000105e-05), (48, 2.4821245e-05), (5, 8.99952e-05), (41, 9.6441705e-05), (42, 0.0001745601), (22, 0.0012130551), (39, 0.001574138), (37, 0.006391635), (30, 0.009908149), (16, 0.016732251), (38, 0.37999314), (40, 0.5836057)]

您在此处看到的前2个匹配项(在列表末尾)(38,0.37999314)、(40,0.5836057)置信度不高(~50%),因此您有一些不明确的迹象。

>>> tagger.labels
('$', "''", ',', '-LRB-', '-RRB-', '.', ':', 'ADD', 'AFX', 'CC', 'CD', 'DT', 'EX', 'FW', 'HYPH', 'IN', 'JJ', 'JJR', 'JJS', 'LS', 'MD', 'NFP', 'NN', 'NNP', 'NNPS', 'NNS', 'PDT', 'POS', 'PRP', 'PRP$', 'RB', 'RBR', 'RBS', 'RP', 'SYM', 'TO', 'UH', 'VB', 'VBD', 'VBG', 'VBN', 'VBP', 'VBZ', 'WDT', 'WP', 'WP$', 'WRB', 'XX', '``')
>>> tagger.labels[40]
'VBN'
>>> tagger.labels[38]
'VBD'

看起来有一些语言特定的标记,需要一些映射才能获得通用的POS标记。

相关文章