simple_plot_scatter_2_distrib.py 2.58 KB
"""
local-twitter

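We read 2 distributions from file.
From the first one we take the top k topics, so we read that file only up to
a given index; then we search for each n-gram in the second file and take its
frequencies. Finally we draw a scatter plot.
(In its current form the script reads a single preprocessed file whose
tab-separated lines hold X, Y and a label.)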


@author: cristina muntean
@date: 28/06/16
"""
import codecs
import logging
import sys
from collections import defaultdict
import pandas as pd

import numpy as np
import matplotlib.pyplot as plt
from mpltools import style


def setStyle():
    """Apply the mpltools 'ggplot' style; the main script calls this once before plotting."""
    style_name = 'ggplot'
    style.use(style_name)

def readPreprocessedData(filename):
    """
    Read a tab-separated file of points: X Y label, one point per line.

    Blank lines are skipped, and both LF and CRLF line endings are accepted.
    This means a trailing empty line or a file written on Windows neither
    breaks parsing nor leaves a stray carriage return on the labels.

    :param filename: path to a UTF-8 encoded text file
    :return: X, Y, labels - three parallel lists (X and Y as floats, labels
        as text strings taken from the third tab-separated field)
    :raises ValueError: if one of the first two fields of a line is not a number
    :raises IndexError: if a non-blank line has fewer than three fields
    """
    X = list()
    Y = list()
    labels = list()

    # The with-block guarantees the file handle is closed, even on a parse error.
    with codecs.open(filename, "r", "utf-8") as infile:
        for line in infile:
            # Strip only the line terminator; other whitespace stays part of the label.
            line = line.rstrip("\r\n")
            if not line.strip():
                continue
            data = line.split("\t")
            X.append(float(data[0]))
            Y.append(float(data[1]))
            labels.append(data[2])
    return X, Y, labels

def scatter_plot(X, Y, labels, plotname):
    """
    Scatter-plot Y against X, overlay the y = x diagonal, and save to a file.

    Each point is coloured by its angle arctan2(Y, X). For positive data this
    angle is above pi/4 for points above the diagonal and below pi/4 for
    points under it.

    :param X: x coordinates (sequence of numbers)
    :param Y: y coordinates, same length as X
    :param labels: list of labels, ordered like X and Y; currently unused
        (point annotation is disabled) but kept for interface compatibility
    :param plotname: output path passed to savefig
    :return: None
    """
    X = np.asarray(X, dtype=float)
    Y = np.asarray(Y, dtype=float)

    fig, ax = plt.subplots()
    try:
        T = np.arctan2(Y, X)
        ax.scatter(X, Y, s=75, c=T, alpha=.5)

        # Diagonal line (45 degrees). It always covers [0, 1], as before, and
        # is extended to the data's range so it stays visible for count-scale
        # data instead of collapsing into a dot near the origin.
        lo, hi = 0.0, 1.0
        if X.size:
            lo = min(lo, X.min(), Y.min())
            hi = max(hi, X.max(), Y.max())
        ax.plot([lo, hi], [lo, hi])

        fig.tight_layout()
        fig.savefig(plotname)
    finally:
        # Release the figure; otherwise every call leaves one figure open.
        plt.close(fig)


if __name__ == '__main__':

    logger = logging.getLogger("simple_plot_scatter_2_distrib.py")
    logging.basicConfig(level=logging.DEBUG, format="%(asctime)s;%(levelname)s;%(message)s")

    # Expected invocation: <script> <inputFile> <plotname.pdf>
    if len(sys.argv) != 3:
        # Usage errors go to stderr with a conventional non-zero status
        # (sys.exit(-1) would surface as 255 on POSIX shells).
        sys.stderr.write("You need to pass the following 2 params: <inputFile> <plotname.pdf>\n")
        sys.exit(1)
    inputFile = sys.argv[1]
    plotName = sys.argv[2]

    setStyle()
    X, Y, labels = readPreprocessedData(inputFile)
    # Sanity check: the three parallel lists should have the same length.
    logger.info("Read %d x-values, %d y-values, %d labels", len(X), len(Y), len(labels))
    scatter_plot(X, Y, labels, plotName)

    # for (x,y,label) in zip(X,Y,labels):
    #     if x > 20000 and y < 45000 : print label, x


    # Yprime = [y-x for (x,y,label) in zip(X,Y, labels) ]
    # print X[:10]
    # print Y[:10]
    # print Yprime[:10]
    # scatter_plot(X, Yprime, labels, plotName.replace(".pdf", "") + "Y_prime.pdf")