| |
| |
| |
| |
| |
| |
| |
| |
| |
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow.compat.v1 as tf
import tensorflow_hub as hub
| |
| |
| |
| |
# Graph mode is required: the similarity code builds tf.placeholder ops and
# evaluates them in a tf.Session, neither of which works under eager execution.
# `tf` is already tensorflow.compat.v1, so no extra `.compat.v1` is needed.
tf.disable_eager_execution()

# Location of the TF-Hub module on disk. Set HUB_MODEL_PATH to point at a
# different copy instead of editing this file; the original path is the default.
MODEL_PATH = os.environ.get("HUB_MODEL_PATH", r"D:\nlp\model")
embed = hub.Module(MODEL_PATH)
| |
| |
# Cached (placeholder, encodings, global_init, table_init) ops, built on first
# use. Reusing them keeps repeated calls from adding fresh nodes to the default
# graph every time.
_similarity_ops = None


def _get_similarity_ops():
    """Return the cached similarity ops, creating them on the first call."""
    global _similarity_ops
    if _similarity_ops is None:
        placeholder = tf.placeholder(tf.string, shape=None)
        encodings = embed(placeholder)
        # Initializers are created after embed() so they cover the variables
        # and lookup tables the module application just added.
        global_init = tf.global_variables_initializer()
        table_init = tf.tables_initializer()
        _similarity_ops = (placeholder, encodings, global_init, table_init)
    return _similarity_ops


def convert_text_2_dot_vector(messages):
    """Embed *messages* with ``embed`` and return their pairwise inner products.

    Args:
        messages: Sequence of strings fed to the module-level ``embed`` hub
            module through a string placeholder.

    Returns:
        ``np.inner(E, E)`` where ``E`` is the array ``embed`` produces for
        *messages*. Assuming ``E`` has one row per message (TODO confirm for
        the loaded module), entry ``[i, j]`` is the dot product of the
        embeddings of ``messages[i]`` and ``messages[j]``. The matrix is also
        printed to stdout.
    """
    placeholder, encodings, global_init, table_init = _get_similarity_ops()
    with tf.Session() as session:
        # Run the two initializers in the same order as before: variables,
        # then tables.
        session.run(global_init)
        session.run(table_init)
        message_embeddings_ = session.run(encodings,
                                          feed_dict={placeholder: messages})
    corr = np.inner(message_embeddings_, message_embeddings_)
    print(corr)
    return corr
| |
| |
def heatmap(x_labels, y_labels, values):
    """Draw *values* as an annotated heatmap and display it via ``plt.show()``.

    Args:
        x_labels: Column tick labels. Cells in columns
            ``0 .. len(x_labels) - 1`` are annotated.
        y_labels: Row tick labels. Cells in rows
            ``0 .. len(y_labels) - 1`` are annotated.
        values: 2-D numeric array-like indexed as ``values[row, col]``.
            Nested lists are accepted as well; they are converted with
            ``np.asarray``.

    Each annotated cell shows its value formatted with two decimals.
    """
    # Convert up front so tuple indexing grid[row, col] also works for
    # nested lists, not only for ndarrays.
    grid = np.asarray(values)
    fig, ax = plt.subplots()
    ax.imshow(grid)

    ax.set_xticks(np.arange(len(x_labels)))
    ax.set_yticks(np.arange(len(y_labels)))
    ax.set_xticklabels(x_labels)
    ax.set_yticklabels(y_labels)

    # Rotate column labels 45 degrees, right-aligned at their anchor point,
    # so long labels do not overlap.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", fontsize=10,
             rotation_mode="anchor")

    # Annotate each labelled cell in row-major order.
    # ax.text takes (x, y), i.e. (col, row).
    for row, col in np.ndindex(len(y_labels), len(x_labels)):
        ax.text(col, row, "%.2f" % grid[row, col],
                ha="center", va="center", color="w", fontsize=6)

    fig.tight_layout()
    plt.show()
| |
| |
if __name__ == '__main__':
    # Demo input: three pairs of sentences, one pair per topic
    # (phones, weather, food/health).
    sample_sentences = [
        # phones
        "My phone is not good.",
        "Your cellphone looks great.",
        # weather
        "Will it snow tomorrow?",
        "Recently a lot of hurricanes have hit the US",
        # food / health
        "An apple a day, keeps the doctors away",
        "Eating strawberries is healthy",
    ]

    similarity_matrix = convert_text_2_dot_vector(sample_sentences)

    # The same sentences label both axes, so cell (r, c) compares
    # sentence r with sentence c.
    heatmap(sample_sentences, sample_sentences, similarity_matrix)