File size: 6,168 Bytes
0d3b8f7
 
 
 
 
 
 
3690e30
0d3b8f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6d16c16
0d3b8f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
#
# Copyright 2021 Systems & Technology Research. All rights reserved.
# This software and associated documentation is subject to the use restrictions stated in the LICENSE.txt file.
#


import hashlib
import html
import json
import math
import os
# import _json

import numpy as np
import pandas as pd
import streamlit as st
import torch
from colorcet import blues
from PIL import ImageColor
from transformers import RobertaTokenizerFast, RobertaForMaskedLM
# Torch device used for the model and its inputs; hard-coded to CPU.
device = "cpu"

# Default article text pre-filled in the sidebar text area further down.
sample_text="""SAN FRANCISCO — A Facebook-appointed panel of journalists, activists and lawyers on Wednesday upheld the social network’s ban of former President Donald J. Trump, ending any immediate return by Mr. Trump to mainstream social media and renewing a debate about tech power over online speech.
Facebook’s Oversight Board, which acts as a quasi-court over the company’s content decisions, ruled the social network was right to bar Mr. Trump after the insurrection in Washington in January, saying he “created an environment where a serious risk of violence was possible.” The panel said that ongoing risk “justified” the move.
But the board also kicked the case back to Facebook and its top executives. It said that an indefinite suspension was “not appropriate” because it was not a penalty defined in Facebook’s policies and that the company should apply a standard punishment, such as a time-bound suspension or a permanent ban. The board gave Facebook six months to make a final decision on Mr. Trump’s account status.
“Our sole job is to hold this extremely powerful organization, Facebook,  accountable,” Michael McConnell, co-chair of the Oversight Board, said on a call with reporters. The ban on Mr. Trump “did not meet these standards,” he said."""





# Show which device the app is running on in the sidebar.
st.sidebar.success(f"running on {device}")

def get_color(norm_value,cmap):
    """Pick a color from a sequential colormap for a normalized value.

    Args:
        norm_value: Position within the colormap, nominally in [0, 1].
            Values outside that range are clamped to the nearest end so a
            slightly out-of-range input neither wraps around to the wrong
            end of the colormap (negative index) nor raises IndexError.
        cmap: Indexable sequence of colors, ordered from low to high.

    Returns:
        The entry of ``cmap`` at index ``floor((len(cmap) - 1) * norm_value)``
        (after clamping).

    Raises:
        IndexError: If ``cmap`` is empty.
        ValueError: If ``norm_value`` is NaN (raised by ``math.floor``).
    """
    clamped = min(max(norm_value, 0.0), 1.0)
    idx = int(math.floor((len(cmap) - 1) * clamped))
    return cmap[idx]

def get_color_cat(idx,cmap):
    """Return the color for category ``idx``, cycling through ``cmap``.

    The index is taken modulo the number of colors, so category ids larger
    than the palette reuse colors from the start.
    """
    palette_size = len(cmap)
    wrapped = idx % palette_size
    return cmap[wrapped]

def make_html_text_with_color(text,color):
    """Wrap ``text`` in a ``<span>`` with a semi-transparent background color.

    Args:
        text: Text to display. It is HTML-escaped (``&``, ``<``, ``>``)
            because the result is meant to be rendered as raw HTML and the
            text may come straight from user input.
        color: Any color specification understood by
            ``PIL.ImageColor.getrgb`` (e.g. ``"yellow"``, ``"#ffcc00"``).

    Returns:
        An HTML string of the form
        ``<span style="background-color: rgba(r, g, b, 0.6)">text</span>``.

    Raises:
        ValueError: If ``color`` cannot be parsed by ``ImageColor.getrgb``.
    """
    # getrgb may return (r, g, b, a) for colors that carry an alpha channel;
    # keep only the RGB part so the CSS rgba() always has four components.
    r, g, b = ImageColor.getrgb(color)[:3]
    rgba = f"rgba({r}, {g}, {b}, 0.6)"
    return f'<span style="background-color: {rgba}">{html.escape(text, quote=False)}</span>'

def replace(text):
    """Clean a decoded token for display.

    Tokens that are exactly one of the tokenizer's special markers are
    turned into an empty string; in every other string the Unicode
    replacement character is removed.
    """
    special_tokens = ('<s>', '</s>', '<unk>', '<pad>', '<mask>')
    cleaned = "" if text in special_tokens else text
    return cleaned.replace("�","")

def make_full_html(tokens, values, cmap=("yellow",), bounds=None, categotical = True):
    """Render tokens as one HTML string, highlighting each by its value.

    Args:
        tokens: Iterable of token strings; each is cleaned with ``replace``.
        values: One value per token.
            In categorical mode a value ``>= 0`` is a category id used to
            pick a color via ``get_color_cat``; a negative value leaves the
            token un-highlighted.
            In continuous mode the values are normalized to [0, 1] and
            mapped onto ``cmap`` via ``get_color``.
        cmap: Sequence of colors. Defaults to a single yellow.
        bounds: Optional ``(vmin, vmax)`` used for normalization in
            continuous mode; values are clipped to this range. When omitted
            the min/max of ``values`` are used.
        categotical: (sic - the misspelled name is kept so keyword callers
            keep working.) ``True`` selects categorical mode, ``False``
            continuous mode.

    Returns:
        The concatenated HTML for all tokens. Un-highlighted tokens are
        HTML-escaped, since the result is intended to be rendered as raw HTML.
    """
    if categotical:
        return "".join([
            make_html_text_with_color(replace(t), get_color_cat(v, cmap))
            if v >= 0
            else html.escape(replace(t), quote=False)
            for t, v in zip(tokens, values)
        ])

    # Accept plain sequences as well as arrays.
    values = np.asarray(values, dtype=float)
    if values.size == 0:
        # Nothing to render; min()/max() would raise on an empty array.
        return ""

    if bounds is None:
        vmn = values.min()
        vmx = values.max()
        # The epsilon keeps the division defined when all values are equal.
        values = (values - vmn) / (vmx - vmn + 1e-6)
    else:
        vmn, vmx = bounds
        span = vmx - vmn
        if span > 0:
            values = (np.clip(values, vmn, vmx) - vmn) / span
        else:
            # Degenerate bounds: avoid dividing by zero and use the low end.
            values = np.zeros_like(values)
    return "".join([make_html_text_with_color(replace(t), get_color(v, cmap)) for t, v in zip(tokens, values)])


# Emotion labels offered in the UI. get_connotations uses a label's position
# in this list as the index into each lexicon entry's "Emo" vector, so the
# order presumably has to match the lexicon's encoding - verify before reordering.
emotions = ["anger", "joy", "fear", "trust", "anticipation", "sadness", "disgust", "surprise"]

# Path of the connotation lexicon CSV read by get_connotations, which uses
# its "word" and "conn" columns.
PATH_CONN = "noun_adj_conntation_lexicon.csv"

@st.cache(allow_output_mutation = True,hash_funcs={'_json.Scanner': hash})
def get_connotations(emotion, vocab):
    """Find the lexicon words tagged with ``emotion`` and mark them in ``vocab``.

    The lexicon CSV at ``PATH_CONN`` is loaded, its ``conn`` column is parsed
    as JSON, and the rows whose ``conn["Emo"]`` entry for this emotion equals
    1 are selected.

    Args:
        emotion: One of the labels in the module-level ``emotions`` list.
        vocab: Object providing ``.isin(...)`` whose result has ``.values``
            (the script passes the pandas Series built by ``get_model``).

    Returns:
        ``(word_set, vocab_mask)``: the set of matching lexicon words, and a
        torch tensor built from ``vocab.isin(word_set)`` marking which
        vocabulary entries are in that set.

    Raises:
        ValueError: If ``emotion`` is not in ``emotions``.
    """
    lexicon = pd.read_csv(PATH_CONN)
    lexicon.conn = lexicon.conn.apply(json.loads)

    # Position of the emotion inside each entry's "Emo" vector.
    emotion_idx = emotions.index(emotion)

    def has_emotion(entry):
        return entry["Emo"][emotion_idx]==1.

    selected_rows = lexicon.conn.apply(has_emotion)
    matching_words = lexicon.loc[selected_rows,"word"].values.tolist()
    word_set = set(matching_words)

    in_lexicon = vocab.isin(word_set)
    vocab_mask = torch.from_numpy(in_lexicon.values)
    return word_set, vocab_mask

@st.cache(allow_output_mutation = True)
def get_model():
    """Load the RoBERTa tokenizer and masked-LM model plus a cleaned vocabulary.

    Returns:
        ``(tokenizer, clean_vocab, model)`` where ``model`` is in eval mode
        with gradients disabled and moved to ``device``, and ``clean_vocab``
        is a pandas Series, sorted by token id, mapping each token id to the
        token's decoded string, stripped of surrounding whitespace and
        lower-cased.
    """
    tokenizer = RobertaTokenizerFast.from_pretrained('roberta-base')

    model = RobertaForMaskedLM.from_pretrained('roberta-base')
    model = model.eval()
    model = model.requires_grad_(False)
    model = model.to(device)

    # token id -> normalized surface form of the token
    id_to_text = {}
    for token, token_id in tokenizer.get_vocab().items():
        id_to_text[token_id] = tokenizer.convert_tokens_to_string(token).strip().lower()
    clean_vocab = pd.Series(id_to_text).sort_index()

    return tokenizer, clean_vocab, model


# Load the tokenizer, cleaned vocabulary and model (get_model is cached by Streamlit).
tokenizer, clean_vocab, model = get_model()


# Explicit call instead of a bare string literal, so the heading does not
# depend on Streamlit's "magic" rendering of top-level expressions.
st.markdown("## Change Connotation")

col1,col2 = st.columns(2)

emotion_source = col1.selectbox("Source Emotion", emotions, index = 1)
emotion_target = col2.selectbox("Target Emotion", emotions, index = 0)


# Per-vocabulary-entry masks marking the tokens whose cleaned text is a
# lexicon word for the selected source / target emotion.
_, emotion_words_source = get_connotations(emotion_source,clean_vocab)
_, emotion_words_target = get_connotations(emotion_target,clean_vocab)

# custom_input = st.sidebar.checkbox("Custom Input",value = True)

custom_input = True

if custom_input:
    article = st.sidebar.text_area("Paste Text Here", value =sample_text, height = 600)
else:
    # NOTE(review): get_articles / search_articles are not defined in this
    # file; this branch is unreachable while custom_input is hard-coded to
    # True and would raise NameError if it were ever enabled.
    articles = get_articles()
    keyword = st.sidebar.text_input("Keywords",value="virus")
    article = search_articles(keyword, articles)

# Tokenize the article; anything beyond 512 tokens is truncated.
inputs = tokenizer(article, max_length=512, truncation = True,return_tensors = "pt" )
original_input_ids = inputs["input_ids"][0].clone()
words = [tokenizer.convert_tokens_to_string(s) for s in tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])]

# Mark the positions whose token is a source-emotion word by looking each
# token id up in the vocabulary mask.
mask = emotion_words_source[original_input_ids]

if not mask.any():
    st.warning("no source words found, try another input")
    # Negative scores mean "no highlight" in make_full_html.
    scores = -np.ones(len(words))
    words_mod = words

else:
    # Mask out the source-emotion words and let the model re-predict them.
    inputs["input_ids"][0][mask] = tokenizer.mask_token_id

    with torch.no_grad():
        logits = model(**{k:v.to(device) for k,v in inputs.items()}).logits[0]
        # Only tokens tagged with the target emotion may be predicted ...
        logits[:,~emotion_words_target] = float("-inf")
        # ... and never the original token at that position.
        logits[mask,original_input_ids[mask]] = float("-inf")
        idx = logits[mask,:].argmax(-1).cpu()

    # Write the predicted replacement tokens back into the sequence.
    inputs["input_ids"][0,mask] = idx
    words_mod = [tokenizer.convert_tokens_to_string(s) for s in tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])]

    # 1 for replaced positions (highlighted), -1 for untouched ones.
    scores = np.where(mask.numpy(), 1, -1)


# Left column: original text with the source-emotion words highlighted.
with col1:
    html_str = make_full_html(words, scores,cmap=["blue"])
    st.markdown(html_str, unsafe_allow_html=True)

# Right column: modified text with the replacement words highlighted.
with col2:
    html_str = make_full_html(words_mod, scores,cmap=["yellow"])
    st.markdown(html_str, unsafe_allow_html=True)