import jieba
from gensim import corpora
from gensim import models
from scipy.sparse import csr_matrix
from sklearn import svm
from sklearn.model_selection import train_test_split
import os
import logging
import pickle
import codecs
import glob
from collections import defaultdict
import string
import datetime
# Configure root logging once at import time: only WARNING and above are
# shown, each record prefixed with a timestamp, source file and line number.
logging.basicConfig(level=logging.WARNING,
format='%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
datefmt='%a, %d %b %Y %H:%M:%S',
)
class classfication(object):
def __init__(self, path, sample):
self.path = path
self.sample = sample
self.dictionary_path = "model/dictionary.dict"
self.tfIdfPath = "model/tfidf"
self.lsiModel = "model/fullLsi.model"
self.lsiPath = "model/lsi"
self.predictor = "model/predictor.model"
self.tag = os.listdir(self.path)
def _fullTagFile(self):
self.tagFile = {}
for tag in self.tag:
fullPath = os.path.join(self.path, tag)
fileName = glob.glob(os.path.join(fullPath, "*.txt"))
self.tagFile[tag] = fileName
return self.tagFile
def _segement(self, filepath):
words_list = []
stops_words = set([i.strip() for i in codecs.open("stop_words.txt", encoding="utf-8").readlines()])
with codecs.open(filepath, encoding="utf-8") as fp:
for line in fp:
line = line.replace("\u3000", "").replace("\n", "")
words_list.extend([i for i in jieba.cut(line, cut_all=False)
if i not in stops_words])
return words_list
def _getDictionary(self):
globa N
dictionary = corpora.Dictionary()
for tag in self.fullTagFile:
tagPath = self.fullTagFile[tag]
for i, filepath in enumerate(tagPath):
if i % self.sample == 0:
word_list = self._segement(filepath)
dictionary.add_documents([word_list])
N += 1
if N % 1000 == 0:
print('{t} *** {i} \t docs has been dealed'
.format(i=N, t=datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
small_freq_ids = [tokenid for tokenid, docfreq in dictionary.dfs.items() if docfreq < 5]
dictionary.filter_tokens(small_freq_ids)
dictionary.compactify()
dictionary.save(self.dictionary_path)
return dictionary
def _getTfIdf(self, dictionary):
global N
tagTfidf = defaultdict(list)
tfIdfModel = models.TfidfModel(dictionary=dictionary)
for tag in self.tagFile:
tagPath = self.tagFile[tag]
for i, filepath in enumerate(tagPath):
if i % self.sample == 0:
word_list = self._segement(filepath)
doc2bow = dictionary.doc2bow(word_list)
doc_tfidf = tfIdfModel[doc2bow]
tagTfidf[tag].append(doc_tfidf)
N += 1
if N % 1000 == 0:
print('{t} *** {i} \t docs has been dealed'
.format(i=N, t=datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
if not os.path.isdir(self.tfIdfPath):
os.makedirs(self.tfIdfPath)
for tag in tagTfidf:
corpora.MmCorpus.serialize(os.path.join(self.tfIdfPath, "%s.mm" % tag),
tagTfidf.get(tag), id2word=dictionary)
return tagTfidf
def _getLsi(self, dictionary, tagTfidf):
corpus = []
for value in tagTfidf.values():
corpus.extend(value)
lsi_model = models.LsiModel(corpus=corpus, id2word=dictionary, num_topics=50)
with open(self.lsiModel, mode="wb") as fp:
pickle.dump(lsi_model, fp)
corpus_lsi = {}
if not os.path.isdir(self.lsiPath):
os.makedirs(self.lsiPath)
for tag in tagTfidf:
corpus = [lsi_model[doc] for doc in tagTfidf.get(tag)]
corpus_lsi[tag] = corpus
corpora.MmCorpus.serialize(os.path.join(self.lsiPath, '%s.mm' % tag),
corpus, id2word=dictionary)
return corpus_lsi
def _getPredictor(self, corpus_lsi):
corpus_lsi_total = []
tag_list_all = []
for index, tag in enumerate(self.tag):
temp = corpus_lsi[tag]
corpus_lsi_total.extend(temp)
tag_list_all.extend([index] * len(temp))
lsi_matrix = self._csr_matrix(corpus_lsi_total)
x_train, x_test, y_train, y_test = train_test_split(lsi_matrix, tag_list_all, test_size=0.2, random_state=422)
clf = svm.LinearSVC()
clf_res = clf.fit(x_train, y_train)
x_test_pred = clf_res.predict(x_test)
accuracy = sum([1 for i, j in zip(x_test_pred, y_test) if i == j]) / len(x_test)
print('=== 分类训练完毕,分类结果如下 ===')
print('测试集正确率: {e}'.format(e=accuracy))
with open(self.predictor, "wb") as fp:
pickle.dump(clf_res, fp)
return clf_res
def _csr_matrix(self, corpus_lsi, type="train"):
data = []
rows = []
columns = []
line_count = 0
if type == "train":
for line in corpus_lsi:
for elem in line:
rows.append(line_count)
columns.append(elem[0])
data.append(elem[1])
line_count += 1
lsi_array = csr_matrix((data, (rows, columns))).toarray()
elif type == "test":
for item in corpus_lsi:
data.append(item[1])
columns.append(item[0])
rows.append(0)
lsi_array = csr_matrix((data, (rows, columns))).toarray()
return lsi_array
def train(self):
self.fullTagFile = self._fullTagFile()
if os.path.exists(self.dictionary_path):
dictionary = corpora.Dictionary.load(self.dictionary_path)
else:
dictionary = self._getDictionary()
tagTfidf = {}
if not os.path.exists(self.tfIdfPath):
tagTfidf = self._getTfIdf(dictionary)
else:
filePath = glob.glob(os.path.join(self.tfIdfPath, "*.mm"))
for file in filePath:
tag = os.path.split(file)[-1].split(".")[0]
tagTfidf[tag] = corpora.MmCorpus(file)
corpus_lsi = {}
if not os.path.exists(self.lsiPath):
corpus_lsi = self._getLsi(dictionary, tagTfidf)
else:
filePath = glob.glob(os.path.join(self.tfIdfPath, "*.mm"))
for file in filePath:
tag = os.path.split(file)[-1].split(".")[0]
corpus_lsi[tag] = corpora.MmCorpus(file)
if os.path.exists(self.lsiModel):
with open(self.lsiModel, 'rb') as fp:
lsi_model = pickle.load(fp)
else:
corpus = []
for value in tagTfidf.values():
corpus.extend(value)
lsi_model = models.LsiModel(corpus=corpus, id2word=dictionary, num_topics=50)
if not os.path.exists(self.predictor):
predictor = self._getPredictor(corpus_lsi)
else:
with open(self.predictor, 'rb') as fp:
predictor = pickle.load(fp)
return predictor, dictionary, lsi_model
def predict(self, sentences):
predictor, dictionary, lsi_model = self.train()
demo_doc = list(jieba.cut(sentences, cut_all=False))
demo_bow = dictionary.doc2bow(demo_doc)
tfidf_model = models.TfidfModel(dictionary=dictionary)
demo_tfidf = tfidf_model[demo_bow]
demo_lsi = lsi_model[demo_tfidf]
demo_matrix = self._csr_matrix(demo_lsi, type="test")
x = predictor.predict(demo_matrix)
print(self.tag[x[0]])
if __name__ == '__main__':
    # Demo entry point: train (or load cached models) on the THUCNews corpus,
    # then classify one sample sentence.
    classifier = classfication(r"E:\迅雷下载\THUCNews\THUCNews", sample=10)
    # predict() prints the predicted category and returns None, so the
    # original `test = ...` assignment was meaningless and has been dropped.
    classifier.predict(''' 股价现在怎么样了 ''')