# Hosting-page status banner captured with this file (not code): "Spaces: Runtime error"
#
# Copyright 2021 Systems & Technology Research. All rights reserved.
# This software and associated documentation is subject to the use restrictions stated in the LICENSE.txt file.
#
# import _json
import hashlib
import html
import json
import math
import os

import numpy as np
import pandas as pd
import streamlit as st
import torch
from colorcet import blues
from PIL import ImageColor
from transformers import RobertaForMaskedLM, RobertaTokenizerFast
# Torch device used for the model and its inputs.
device = "cpu"

# Default article shown in the sidebar text area. Paragraphs are separated by
# single newlines.
sample_text = """SAN FRANCISCO — A Facebook-appointed panel of journalists, activists and lawyers on Wednesday upheld the social network’s ban of former President Donald J. Trump, ending any immediate return by Mr. Trump to mainstream social media and renewing a debate about tech power over online speech.
Facebook’s Oversight Board, which acts as a quasi-court over the company’s content decisions, ruled the social network was right to bar Mr. Trump after the insurrection in Washington in January, saying he “created an environment where a serious risk of violence was possible.” The panel said that ongoing risk “justified” the move.
But the board also kicked the case back to Facebook and its top executives. It said that an indefinite suspension was “not appropriate” because it was not a penalty defined in Facebook’s policies and that the company should apply a standard punishment, such as a time-bound suspension or a permanent ban. The board gave Facebook six months to make a final decision on Mr. Trump’s account status.
“Our sole job is to hold this extremely powerful organization, Facebook, accountable,” Michael McConnell, co-chair of the Oversight Board, said on a call with reporters. The ban on Mr. Trump “did not meet these standards,” he said."""

st.sidebar.success(f"running on {device}")
def get_color(norm_value, cmap):
    """Map a value in [0, 1] onto one entry of a sequential colour map.

    The value is scaled onto the index range ``0 .. len(cmap) - 1`` and
    rounded down, so only ``norm_value == 1`` reaches the last entry.
    """
    scaled_position = (len(cmap) - 1) * norm_value
    return cmap[int(math.floor(scaled_position))]
def get_color_cat(idx, cmap):
    """Return the categorical colour for category *idx*, cycling through *cmap*."""
    wrapped_index = idx % len(cmap)
    return cmap[wrapped_index]
def make_html_text_with_color(text, color, alpha=0.6):
    """Wrap *text* in a ``<span>`` whose background is *color* at opacity *alpha*.

    Args:
        text: Content inserted verbatim into the span. It is NOT escaped here;
            callers must HTML-escape untrusted text themselves.
        color: Any colour string accepted by ``PIL.ImageColor.getrgb``
            (e.g. ``"yellow"``, ``"#ff8800"``, ``"#ff880080"``).
        alpha: Background opacity written into the CSS ``rgba()`` value;
            defaults to 0.6.

    Returns:
        The HTML ``<span>`` string.
    """
    # getrgb returns an RGBA 4-tuple for inputs such as "#rrggbbaa"; keep only
    # the RGB part so rgba() always receives exactly four components.
    red, green, blue = ImageColor.getrgb(color)[:3]
    return f'<span style="background-color: rgba({red}, {green}, {blue}, {alpha})">{text}</span>'
def replace(text):
    """Return *text* cleaned for display.

    RoBERTa special tokens (``<s>``, ``</s>``, ``<unk>``, ``<pad>``,
    ``<mask>``) become the empty string, and every "�" character is removed.
    """
    special_tokens = ("<s>", "</s>", "<unk>", "<pad>", "<mask>")
    visible = "" if text in special_tokens else text
    return visible.replace("�", "")
def make_full_html(tokens, values, cmap=("yellow",), bounds=None, categotical=True):
    """Render *tokens* as one HTML string, highlighting each by its *values* entry.

    Every token is passed through ``replace`` (special tokens dropped) and then
    HTML-escaped. The tokens come from user-pasted text and the result is shown
    with ``st.markdown(..., unsafe_allow_html=True)``.

    Args:
        tokens: Sequence of token strings.
        values: Per-token values, zipped with *tokens*.
        cmap: Sequence of colour strings.
        bounds: Optional ``(vmin, vmax)`` used to clip and normalise values in
            the continuous mode; when omitted, the data min/max are used.
        categotical: (sic — name kept for caller compatibility)
            - True: *values* are integer category ids. Ids ``>= 0`` are coloured
              with ``cmap[id % len(cmap)]``; negative ids are left unhighlighted.
            - False: values are normalised to [0, 1] and mapped onto *cmap* as a
              sequential colour map.

    Returns:
        The concatenated HTML string ("" for empty input).
    """
    texts = [html.escape(replace(t)) for t in tokens]

    if categotical:
        return "".join(
            make_html_text_with_color(text, get_color_cat(v, cmap)) if v >= 0 else text
            for text, v in zip(texts, values)
        )

    values = np.asarray(values, dtype=float)
    if values.size == 0:
        return ""
    if bounds is None:
        vmn, vmx = values.min(), values.max()
        # The epsilon keeps a constant input from dividing by zero.
        values = (values - vmn) / (vmx - vmn + 1e-6)
    else:
        vmn, vmx = bounds
        values = np.clip(values, vmn, vmx)
        span = vmx - vmn
        # Degenerate bounds (vmin == vmax) would give 0/0 = NaN and crash
        # get_color; treat every value as the bottom of the colour map instead.
        values = (values - vmn) / span if span else np.zeros_like(values)
    return "".join(
        make_html_text_with_color(text, get_color(v, cmap))
        for text, v in zip(texts, values)
    )
# Emotion labels offered in the UI. A label's position in this list is the
# index get_connotations reads from each lexicon row's "Emo" list.
emotions = [
    "anger",
    "joy",
    "fear",
    "trust",
    "anticipation",
    "sadness",
    "disgust",
    "surprise",
]
# Path of the connotation lexicon CSV read by get_connotations.
# NOTE(review): "conntation" is presumably the actual on-disk file name —
# do not correct the spelling here without renaming the file.
PATH_CONN = "noun_adj_conntation_lexicon.csv"
def get_connotations(emotion, vocab):
    """Find the lexicon words tagged with *emotion* and mark them in *vocab*.

    Args:
        emotion: One of the labels in ``emotions``. Its position selects which
            entry of each row's ``"Emo"`` list is tested.
        vocab: pandas Series of vocabulary strings (e.g. ``clean_vocab``).

    Returns:
        ``(word_set, vocab_mask)``:
            - ``word_set``: the set of lexicon ``word`` values whose ``"Emo"``
              entry equals 1.
            - ``vocab_mask``: a boolean torch tensor, aligned position-wise with
              *vocab*, that is True where the vocabulary string is in ``word_set``.

    Raises:
        ValueError: if *emotion* is not in ``emotions``.
    """
    lexicon = pd.read_csv(PATH_CONN)
    # Each row's "conn" column holds a JSON object; decode it into a dict.
    lexicon.conn = lexicon.conn.apply(json.loads)
    emotion_pos = emotions.index(emotion)
    tagged = lexicon.conn.apply(lambda entry: entry["Emo"][emotion_pos] == 1.)
    word_set = set(lexicon.loc[tagged, "word"].to_numpy().tolist())
    in_lexicon = vocab.isin(word_set).to_numpy()
    return word_set, torch.from_numpy(in_lexicon)
# Streamlit re-executes this whole script on every widget interaction, so an
# uncached get_model() reloads RoBERTa from disk on each rerun. Use
# st.cache_resource when this Streamlit version has it, otherwise the older
# st.experimental_singleton, otherwise fall back to no caching.
_cache_model = (
    getattr(st, "cache_resource", None)
    or getattr(st, "experimental_singleton", None)
    or (lambda func: func)
)


@_cache_model
def get_model():
    """Load the RoBERTa tokenizer, a cleaned vocabulary, and the masked-LM model.

    Returns:
        ``(tokenizer, clean_vocab, model)``:
            - ``clean_vocab``: a pandas Series indexed by token id (sorted), whose
              values are each token's decoded text, stripped and lower-cased.
            - ``model``: ``roberta-base`` in eval mode with gradients disabled,
              moved to ``device``.

        When caching is active, these objects are shared across reruns, so
        callers should treat them as read-only.
    """
    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
    model = (
        RobertaForMaskedLM.from_pretrained("roberta-base")
        .eval()
        .requires_grad_(False)
        .to(device)
    )
    clean_vocab = pd.Series(
        {
            token_id: tokenizer.convert_tokens_to_string(token).strip().lower()
            for token, token_id in tokenizer.get_vocab().items()
        }
    ).sort_index()
    return tokenizer, clean_vocab, model
tokenizer, clean_vocab, model = get_model()

# Explicit call instead of a bare (placeholder-less) f-string, which only
# renders when Streamlit "magic" is enabled.
st.markdown("## Change Connotation")
col1, col2 = st.columns(2)
emotion_source = col1.selectbox("Source Emotion", emotions, index=1)
emotion_target = col2.selectbox("Target Emotion", emotions, index=0)

# Boolean vocabulary masks for the selected emotions (the word sets are unused).
# NOTE: "taget" spelling kept because later code refers to this global by name.
_, emotion_words_source = get_connotations(emotion_source, clean_vocab)
_, emotion_words_taget = get_connotations(emotion_target, clean_vocab)
# The article always comes from the sidebar text area. The former alternative
# branch called get_articles()/search_articles(), which are not defined in this
# file and would raise NameError, so that dead path has been removed.
article = st.sidebar.text_area("Paste Text Here", value=sample_text, height=600)

# Truncated to at most 512 tokens (max_length below).
inputs = tokenizer(article, max_length=512, truncation=True, return_tensors="pt")
# Untouched copy: inputs["input_ids"] is overwritten in place further down.
original_input_ids = inputs["input_ids"][0].clone()
words = [
    tokenizer.convert_tokens_to_string(token)
    for token in tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
]
# Mark every input token whose vocabulary entry carries the source emotion.
# emotion_words_source is a boolean mask over the vocabulary, so indexing it with
# the token ids gives the per-position mask directly.
mask = emotion_words_source[inputs["input_ids"][0]]
if not mask.any():
    st.warning("no source words found, try another input")
    scores = -np.ones(len(words))
    words_mod = words
else:
    # Predict masked-out words.
    inputs["input_ids"][0][mask] = tokenizer.mask_token_id
    with torch.no_grad():
        logits = model(**{k: v.to(device) for k, v in inputs.items()}).logits[0]
    # Only target-emotion words are admissible, and a word may not be
    # "replaced" by itself.
    logits[:, ~emotion_words_taget] = float("-inf")
    logits[mask, original_input_ids[mask]] = float("-inf")
    best = logits[mask, :].max(-1)
    idx = best.indices.cpu()
    # A row that is entirely -inf has no admissible candidate. Its argmax would
    # be id 0 (<s>), which replace() renders as "" and so silently deletes the
    # word; keep the original token there instead.
    no_candidate = torch.isinf(best.values).cpu()
    idx[no_candidate] = original_input_ids[mask][no_candidate]
    inputs["input_ids"][0, mask] = idx
    words_mod = [
        tokenizer.convert_tokens_to_string(s)
        for s in tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    ]
    # Category 0 highlights masked positions; -1 leaves a token unhighlighted.
    scores = mask.numpy().astype(int)
    scores[scores == 0] = -1
# Left column: original text with source-emotion words highlighted in blue.
# Right column: rewritten text with the replacements highlighted in yellow.
for column, shown_words, colour in ((col1, words, "blue"), (col2, words_mod, "yellow")):
    with column:
        html_str = make_full_html(shown_words, scores, cmap=[colour])
        st.markdown(html_str, unsafe_allow_html=True)