all_in_one_plot_scatter_2_distrib.py 2.92 KB
"""
local-twitter

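We read 2 distributions from file.
From the first one we take the top k topics, i.e. its first k entries
(the file is expected to be sorted by decreasing frequency).
Then we look up each of these n-grams in the second file and take its frequency
(0 if it is missing there), and draw a scatter plot of the two frequencies.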


@author: cristina muntean
@date: 28/06/16
"""
import codecs
import logging
import sys
from collections import defaultdict
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from mpltools import style


def setStyle():
    """
    Switch the global plotting style to 'ggplot' via mpltools.

    Affects every figure created after this call.
    """
    theme = 'ggplot'
    style.use(theme)


def loadData(filename):
    """
    Load a tab-separated file into a pandas DataFrame.

    All columns are read as text (``dtype=str``); no numeric conversion
    is performed.

    :param filename: path of the tab-separated file
    :return: pandas.DataFrame built by ``pd.read_csv``
    """
    frame = pd.read_csv(filename, sep='\t', dtype=str)
    return frame

def readFromFile(filename):
    """
    Read a UTF-8, tab-separated ``<word>\\t<count>`` file.

    Lines that do not split into exactly two tab-separated fields
    (blank lines, malformed lines) are skipped silently.

    :param filename: path of a UTF-8 encoded text file
    :return: list of (word, int(count)) tuples, in file order
    :raises ValueError: if the second field of a two-field line is not an integer
    """
    import io

    docList = list()
    # io.open splits lines only on \n, \r and \r\n.  codecs.open would also
    # split on Unicode line boundaries such as U+2028 or U+0085, which can
    # occur inside n-grams and would corrupt or drop entries.  The ``with``
    # block guarantees the file handle is closed.
    with io.open(filename, "r", encoding="utf-8") as infile:
        for line in infile:
            fields = line.replace("\n", "").split("\t")
            if len(fields) == 2:
                word, counter = fields
                # int() tolerates surrounding whitespace such as a stray "\r".
                docList.append((word, int(counter)))
    return docList

def scatter_plot(X, Y, labels, plotname, annotate=False):
    """
    Draw a scatter plot of two aligned value sequences, save it and show it.

    Each point is coloured by ``arctan2(Y, X)``, i.e. its angle seen from
    the origin, so points with the same Y/X ratio share a colour.

    :param X: x coordinates (e.g. frequencies in the first distribution)
    :param Y: y coordinates, same length and order as X
    :param labels: list of labels, ordered like X and Y; drawn next to the
        points only when ``annotate`` is true
    :param plotname: output path handed to ``savefig``
    :param annotate: if true, write each label next to its point
        (default False, which keeps the previous output unchanged)
    :return: None
    """
    fig = plt.figure()
    ax = fig.add_subplot(111)

    # Colour encodes the angle from the origin, which highlights points
    # whose Y/X ratio differs from the rest.
    T = np.arctan2(Y, X)
    ax.scatter(X, Y, s=75, c=T, alpha=.5)

    if annotate:
        for label, x, y in zip(labels, X, Y):
            ax.annotate(label, xy=(x, y))

    fig.tight_layout()
    fig.savefig(plotname)
    plt.show()


if __name__ == '__main__':

    logger = logging.getLogger("all_in_one_plot_scatter_2_distrib.py")
    logging.basicConfig(level=logging.DEBUG, format="%(asctime)s;%(levelname)s;%(message)s")

    if len(sys.argv) != 6:
        print("You need to pass the following 5 params: <inputFile1> <inputFile2> <k> <plotname.pdf> <data-file>")
        sys.exit(-1)
    inputFile1 = sys.argv[1]
    inputFile2 = sys.argv[2]
    k = int(sys.argv[3])
    plotName = sys.argv[4]
    dataFileName = sys.argv[5]

    # An n-gram is reported as "local" when its count in the second
    # distribution is below this value.
    LOCAL_THRESHOLD = 10

    setStyle()
    a = readFromFile(inputFile1)
    b = readFromFile(inputFile2)
    # Single formatted string so the output is identical on Python 2 and 3.
    print("{} {}".format(len(a), len(b)))

    # word -> count in the second distribution (later duplicates win).
    bCounts = dict(b)

    # The first k entries of the first distribution and their counts there
    # and in the second distribution (0 when absent from the second).
    topK = a[:k]
    labels = [word for word, _ in topK]
    X = [counter for _, counter in topK]
    Y = [bCounts.get(word, 0) for word in labels]
    local_topics = [word for word, y in zip(labels, Y) if y < LOCAL_THRESHOLD]

    print(local_topics)
    print("{} {} {}".format(len(X), len(Y), len(labels)))

    # Unicode template: under Python 2 a byte-string template raises
    # UnicodeEncodeError as soon as a label contains a non-ASCII character.
    # The ``with`` block closes the file even if a write fails.
    with codecs.open(dataFileName, "w", "utf8") as dataFile:
        for x, y, label in zip(X, Y, labels):
            dataFile.write(u"{}\t{}\t{}\n".format(x, y, label))

    scatter_plot(X, Y, labels, plotName)