simple_plot_scatter_2_distrib.py
2.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
"""
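local-twitter
We read two distributions from a file.
From the first one we take the top k topics, so we read the file only up to that index;
then we search for each n-gram in the second file and take its frequencies.
Finally, we draw a scatter plot of the two distributions.
Input: one preprocessed file where each line is X<TAB>Y<TAB>label.
@author: cristina muntean
@date: 28/06/16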
"""
import codecs
import logging
import sys
from collections import defaultdict
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from mpltools import style
def setStyle():
    """Apply the mpltools 'ggplot' style globally before any plotting."""
    theme = 'ggplot'
    style.use(theme)
def readPreprocessedData(filename):
    """
    Read a tab-separated file whose lines have the format: X <TAB> Y <TAB> label.

    Blank lines are skipped. Only the first three tab-separated fields are
    used; any further fields on a line are ignored.

    :param filename: path to a UTF-8 encoded text file
    :return: (X, Y, labels) -- X and Y as lists of floats, labels as a list
             of strings, all in file order
    :raises IndexError: if a non-blank line has fewer than three fields
    :raises ValueError: if the X or Y field cannot be parsed as a float
    """
    X = []
    Y = []
    labels = []
    # ``with`` guarantees the file handle is closed, even if parsing fails.
    with codecs.open(filename, "r", "utf-8") as handle:
        for line in handle:
            # Strip "\r\n" as well as "\n" so labels read from files with
            # Windows line endings do not keep a trailing carriage return.
            line = line.rstrip("\r\n")
            # Skip blank lines (e.g. a trailing empty line at end of file)
            # instead of failing on float("").
            if not line.strip():
                continue
            data = line.split("\t")
            X.append(float(data[0]))
            Y.append(float(data[1]))
            labels.append(data[2])
    return X, Y, labels
def scatter_plot(X, Y, labels, plotname):
    """
    Draw a scatter plot of Y against X with a y = x reference line and save it.

    Points are coloured by their polar angle ``arctan2(Y, X)``.

    :param X: sequence of x coordinates
    :param Y: sequence of y coordinates, same length as X
    :param labels: list of point labels, ordered like X and Y
                   (accepted for API compatibility; not drawn on the plot)
    :param plotname: output file path passed to ``savefig``
    :return: None
    """
    fig = plt.figure()
    ax = fig.add_subplot(111)
    T = np.arctan2(Y, X)
    ax.scatter(X, Y, s=75, c=T, alpha=.5)
    # Reference diagonal y = x (45 degrees). It always covers at least [0, 1],
    # which is roughly the range the previous fixed line used. It is also
    # stretched to the data's extent, so it stays visible for un-normalised
    # values (e.g. raw counts).
    lo, hi = 0.0, 1.0
    if len(X) and len(Y):
        lo = min(lo, float(np.min(X)), float(np.min(Y)))
        hi = max(hi, float(np.max(X)), float(np.max(Y)))
    ax.plot([lo, hi], [lo, hi])
    fig.tight_layout()
    fig.savefig(plotname)
    # pyplot keeps every figure alive until it is closed explicitly; release it
    # so that repeated calls do not accumulate figures in memory.
    plt.close(fig)
if __name__ == '__main__':
    # Script entry point: read the input TSV file, then plot it to the given path.
    logger = logging.getLogger("simple_plot_scatter_2_distrib.py")
    logging.basicConfig(level=logging.DEBUG, format="%(asctime)s;%(levelname)s;%(message)s")
    # print(...) with a single string argument produces the same output under
    # Python 2 and 3. The original print statements are a SyntaxError in Python 3.
    if len(sys.argv) != 3:
        print("You need to pass the following 2 params: <inputFile1> <plotname.pdf>")
        sys.exit(-1)
    inputFile = sys.argv[1]
    plotName = sys.argv[2]
    setStyle()
    X, Y, labels = readPreprocessedData(inputFile)
    # Same "a b c" output the Python 2 statement `print a, b, c` produced.
    print("%d %d %d" % (len(X), len(Y), len(labels)))
    scatter_plot(X, Y, labels, plotName)