all_in_one_plot_scatter_2_distrib.py
2.92 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
"""
local-twitter
We read 2 distributions from file.
From the first one we take the top k topics, i.e. we read the file only up to index k.
Then we look up each of those n-grams in the second file and take its frequency there.
Finally we draw a scatter plot of the two frequencies.
@author: cristina muntean
@date: 28/06/16
"""
import codecs
import logging
import sys
from collections import defaultdict
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from mpltools import style
def setStyle():
    """Switch the plotting style to the 'ggplot' theme via ``mpltools.style``."""
    theme_name = 'ggplot'
    style.use(theme_name)
def loadData(filename):
    """Load a tab-separated file into a pandas DataFrame.

    Every column is read as ``str`` (``dtype=str``), so no numeric
    conversion is applied to the values.

    :param filename: path of the tab-separated file to read
    :return: the DataFrame produced by ``pandas.read_csv``
    """
    return pd.read_csv(filename, sep='\t', dtype=str)
def readFromFile(filename):
    """Read a UTF-8 ``word<TAB>count`` file into a list of tuples.

    Lines that do not consist of exactly two tab-separated fields are
    skipped. The order of the file is preserved.

    :param filename: path of the UTF-8 encoded, tab-separated file
    :return: list of ``(word, count)`` tuples, with ``count`` as an int
    :raises ValueError: if the second field of a two-field line is not
        an integer
    """
    docList = []
    # The context manager guarantees the handle is closed, even when
    # int() raises on a malformed count.
    with codecs.open(filename, "r", "utf-8") as handle:
        for line in handle:
            fields = line.rstrip("\n").split("\t")
            if len(fields) != 2:
                continue
            word, counter = fields
            docList.append((word, int(counter)))
    return docList
def scatter_plot(X, Y, labels, plotname):
    """Draw a scatter plot of X against Y, save it and display it.

    Each point is coloured by the angle ``arctan2(Y, X)``.

    :param X: x coordinates of the points
    :param Y: y coordinates of the points
    :param labels: ordered list of labels, one per point; accepted for
        the callers' sake but not drawn on the plot
    :param plotname: path the figure is saved to before it is shown
    :return: None
    """
    figure = plt.figure()
    figure.add_subplot(111)

    # One colour value per point, derived from its polar angle.
    point_angles = np.arctan2(Y, X)
    plt.scatter(X, Y, c=point_angles, s=75, alpha=0.5)

    plt.tight_layout()
    plt.savefig(plotname)
    plt.show()
if __name__ == '__main__':
    # Command-line driver: take the first k (word, count) rows of the
    # first distribution, look up each word's count in the second
    # distribution, dump the pairs to a data file and scatter-plot them.
    logger = logging.getLogger("all_in_one_plot_scatter_2_distrib.py")
    logging.basicConfig(level=logging.DEBUG, format="%(asctime)s;%(levelname)s;%(message)s")

    if len(sys.argv) != 6:
        print("You need to pass the following 5 params: <inputFile1> <inputFile2> <k> <plotname.pdf> <data-file>")
        sys.exit(-1)

    inputFile1 = sys.argv[1]
    inputFile2 = sys.argv[2]
    k = int(sys.argv[3])
    plotName = sys.argv[4]
    dataFileName = sys.argv[5]

    # Words whose count in the second file is below this are reported
    # as "local topics".
    LOCAL_TOPIC_MAX_COUNT = 10

    setStyle()

    a = readFromFile(inputFile1)
    b = readFromFile(inputFile2)
    print("{} {}".format(len(a), len(b)))

    # readFromFile already yields (word, int) pairs, so they can be
    # turned into a lookup table directly; words absent from the second
    # file count as 0.
    bDict = dict(b)

    topK = a[:k]
    labels = [word for word, _ in topK]
    X = [counter for _, counter in topK]
    Y = [bDict.get(word, 0) for word in labels]
    local_topics = [word for word, freq in zip(labels, Y)
                    if freq < LOCAL_TOPIC_MAX_COUNT]

    print(local_topics)
    print("{} {} {}".format(len(X), len(Y), len(labels)))

    # Open the output only once the data is ready, and let the context
    # manager close it. The unicode format string keeps non-ASCII labels
    # from raising UnicodeEncodeError under Python 2.
    with codecs.open(dataFileName, "w", "utf8") as dataFile:
        for x, y, label in zip(X, Y, labels):
            dataFile.write(u"{}\t{}\t{}\n".format(x, y, label))

    scatter_plot(X, Y, labels, plotName)