# amrdemo / app.py
# Last commit e56df02 ("Update app") by koptelovmax; file size 8.28 kB.
import streamlit as st
from easynmt import EasyNMT
from nltk import word_tokenize
from simalign import SentenceAligner
import re
import penman
import amrlib
from amrlib.graph_processing.annotator import add_lemmas
from amrlib.alignments.rbw_aligner import RBWAligner
def getTargetWordNode(segmentTokens, aligner, alignments, target):
    """Find the AMR graph node aligned with a French target word.

    The French word is mapped to an English token through the word
    alignment. The English token is then mapped to a graph element through
    the AMR aligner.

    Args:
        segmentTokens: tokenized French segment (list of str).
        aligner: object exposing an ``alignments`` list indexed by English
            token position. Each entry is ``None`` or an object whose ``str()``
            presumably looks like ``('f', ':instance', 'fence')`` (a
            source/role/target triple).
        alignments: word alignments as returned by simalign's
            ``get_word_aligns``. The ``'mwmf'`` entry is a list of
            ``(french_index, english_index)`` pairs.
        target: French word to look up (must match a token exactly).

    Returns:
        ``"<variable> / <concept>"`` (e.g. ``"f / fence"``), or the sentinel
        string ``'Error!'`` when the French word, its English counterpart or
        its graph alignment cannot be found.
    """
    error = 'Error!'
    if target not in segmentTokens:
        return error
    targetIndexFr = segmentTokens.index(target)

    # English positions aligned with the French word; the first one is used.
    targetIndexesEn = [en for fr, en in alignments['mwmf'] if fr == targetIndexFr]
    if not targetIndexesEn:
        return error

    nodeAlignment = aligner.alignments[targetIndexesEn[0]]
    if nodeAlignment is None:
        return error

    # Split the textual form of the alignment on commas, '(' and quotes.
    # Keep the non-blank pieces: variable, role and concept (a trailing ')'
    # may remain as a 4th piece).
    nodeConcepts = [part for part in re.split('[,("\']', str(nodeAlignment))
                    if part.strip()]
    if len(nodeConcepts) < 3:
        # Not shaped like (variable, role, concept): no usable node name.
        return error
    return nodeConcepts[0] + ' / ' + nodeConcepts[2]
def getTargetWordSubGraphFullPath(amrGraph, target):
    """Return the AMR subgraph at ``target`` together with the path leading to it.

    Algorithm:
    - Tokenize the graph text into role labels, parentheses and node/value text.
    - Push tokens seen before the first token containing ``target`` onto a
      list that mirrors the path from the root. A ')' that closes a branch
      not leading to the target pops that branch off again.
    - Take the tokens from the matching one up to the ')' that balances it as
      the extracted subgraph.
    - Close the ancestors that are still open, then normalize the text with
      penman.

    Args:
        amrGraph: AMR graph in PENMAN notation. Lines starting with '#'
            (metadata such as "# ::snt ...") and blank lines are ignored.
        target: substring identifying the node to extract, e.g. "f / fence",
            or a constant value that appears in the graph.

    Returns:
        The PENMAN string produced by ``penman.encode``.

    Raises:
        ValueError: if no token of the graph contains ``target``.
    """
    # Graph lines without metadata or blank lines. The trailing space keeps
    # text from adjacent lines apart once tokens are concatenated again.
    lines = [line + ' ' for line in amrGraph.split('\n')
             if line and not line.startswith('#')]

    # Split off role labels such as ":ARG0 " or ":ARG0-of " (kept as tokens)...
    pieces = []
    for line in lines:
        pieces += [p for p in re.split(r'(:\w+\s|:\w+-\w+\s)', line) if p.strip()]
    # ...then make every parenthesis a token of its own.
    tokens = []
    for piece in pieces:
        tokens += [t for t in re.split(r'(\(|\))', piece) if t.strip()]

    # Without a match, the path bookkeeping below would pop past the root
    # and fail with an opaque IndexError.
    if not any(target in token.strip() for token in tokens):
        raise ValueError(f'target {target!r} not found in the AMR graph')

    openListGlobal = []  # '(' of ancestors that are still open
    openList = []        # '(' still open inside the extracted subgraph
    subGraph = ''
    subGraphGlobal = []  # tokens of the path from the root to the target
    inSubGraph = False
    done = False         # set once the extracted subgraph has been closed
    for token in tokens:
        if inSubGraph:
            subGraph += token
            if token == '(':
                openList.append('(')
            elif token == ')':
                openList.pop()
                if not openList:
                    # The target subgraph is complete: attach it to the path.
                    inSubGraph = False
                    done = True
                    subGraphGlobal.append(subGraph)
        elif done:
            # Only the first match is extracted. A later token containing the
            # target must not restart extraction: it would be glued onto the
            # finished subgraph and make the result unparsable.
            continue
        elif target in token.strip():
            inSubGraph = True
            subGraph += token
            openList.append('(')
        elif token == '(':
            openListGlobal.append('(')
            subGraphGlobal.append(token)
        elif token == ')':
            # A branch closed without containing the target. Drop it, including
            # its opening '(' and the role label that precedes it.
            openListGlobal.pop()
            while subGraphGlobal[-1] != '(':
                subGraphGlobal.pop()
            subGraphGlobal.pop()
            subGraphGlobal.pop()
        else:
            subGraphGlobal.append(token)

    # Close the ancestors that are still open.
    # NOTE(review): openListGlobal also counts the '(' that opens the matched
    # node (or its parent, for a constant). The subgraph's own final ')'
    # already closes it, so one ')' too many is appended here. penman.decode
    # appears to ignore trailing tokens -- confirm before changing this.
    subGraphGlobal += [')'] * len(openListGlobal)
    resultGraph = ''.join(subGraphGlobal)

    # Round-trip through penman to normalize the formatting.
    return penman.encode(penman.decode(resultGraph))
@st.cache_resource
def _loadTranslator():
    """Return the EasyNMT translator, cached so reruns do not reload it."""
    return EasyNMT('opus-mt')


@st.cache_resource
def _loadTextToGraphModel():
    """Return the amrlib sentence-to-graph model, cached across reruns."""
    return amrlib.load_stog_model(model_dir='model_stog')


@st.cache_resource
def _loadGraphToTextModel():
    """Return the amrlib graph-to-sentence model, cached across reruns."""
    return amrlib.load_gtos_model(model_dir='model_gtos')


@st.cache_resource
def _loadWordAligner():
    """Return the simalign word aligner, cached across reruns."""
    return SentenceAligner(model="bert", token_type="bpe", matching_methods="mai")


def main():
    """Streamlit page: summarize French text around a keyword through AMR.

    Pipeline:
    1. Translate the French text to English.
    2. Parse the English text into an AMR graph.
    3. Locate the graph node aligned with the French keyword.
    4. Extract the subgraph around that node.
    5. Generate English text from the subgraph and translate it back to French.
    """
    st.header('Abstract Meaning Representation based summary of French text', divider='blue')
    # Alternative example inputs (text -> keyword):
    #   "Article 1 : Occupations ou utilisations du sol interdites\n\n"
    #   "1) Dans l’ensemble de la zone sont interdits :\n\n"
    #   "Les pylônes et poteaux, supports d’enseignes et d’antennes d’émission ou de réception de \n"
    #   "signaux radioélectriques."  -> 'Occupations'
    #   "Article 1: Le classement interdit tout changement d'affectation ou tout mode d'occupation du sol de nature à compromettre la conservation, la protection ou la création des boisements. Dans les bois, forêts ou parcs situés sur le territoire de communes où l'établissement d'un plan d'occupation des sols a été prescrit mais où ce plan n'a pas encore été rendu public, ainsi que dans tout espace boisé classé, les coupes et abattages d'arbres sont soumis à autorisation préalable."  -> 'compromettre'
    segmentFr = st.text_area(
        "Text to summarize:",
        "Article 2 : Occupations ou utilisations du sol soumises à des conditions particulières\n\n"
        "2) Dans les périmètres en bordure des cours d’eau définis dans les annexes sanitaires du PLU :\n\n"
        "− Seules les clôtures en grillage pourront être autorisées à condition qu'elles soient conçues de\n"
        "manière à ne pas faire obstacle au libre écoulement des eaux.",
        height=170,
    )
    targetWord = st.text_input('Keyword:', 'clôtures')
    if not st.button('Summarize'):
        return

    # Flatten the text area's line breaks into spaces.
    segmentFr = segmentFr.replace('\n', ' ')
    targetWord = targetWord.strip()
    if not segmentFr.strip() or not targetWord:
        st.write('Error! Please provide both a text and a keyword')
        return

    # Translate the segment into English and parse it into an AMR graph.
    model = _loadTranslator()
    segmentEn = model.translate(segmentFr, source_lang='fr', target_lang='en')
    amrGraph = _loadTextToGraphModel().parse_sents([segmentEn])[0]

    # French tokens vs. English lemmas of the graph's sentence, word-aligned.
    segmentFrTokens = word_tokenize(segmentFr, language='french')
    penmanGraph = add_lemmas(amrGraph, snt_key='snt')
    aligner = RBWAligner.from_penman_w_json(penmanGraph)
    segmentEnTokens = aligner.lemmas
    alignments = _loadWordAligner().get_word_aligns(segmentFrTokens, segmentEnTokens)

    targetNode = getTargetWordNode(segmentFrTokens, aligner, alignments, targetWord)

    # Search only the graph body (no '#' metadata lines, lines kept apart),
    # i.e. the same text getTargetWordSubGraphFullPath tokenizes.
    graphBody = '\n'.join(line for line in amrGraph.split('\n')
                          if not line.startswith('#'))
    if targetNode not in graphBody:
        if targetWord in graphBody:
            # Fall back to the raw keyword when it appears verbatim in the graph.
            targetNode = targetWord
        else:
            st.write('Error! Cannot find keyword in the graph')
            return
    if targetNode == 'Error!':
        st.write('Error! Alignment between target word in French and its English instance not found')
        return

    # Cut out the subgraph with the full path to the target and verbalize it.
    targetSubGraph = getTargetWordSubGraphFullPath(amrGraph, targetNode)
    rulesEn, _ = _loadGraphToTextModel().generate([targetSubGraph])
    # Remove enumeration markers such as "1. " (the dot must be literal, and
    # multi-digit numbers are covered too).
    summaryEn = re.sub(r'\d+\. ', '', rulesEn[0])

    # Translate the summary back to French.
    rulesFr = model.translate(summaryEn, source_lang='en', target_lang='fr')
    st.write("Summary: ", rulesFr)


if __name__ == "__main__":
    main()