Spaces:
Sleeping
Sleeping
import streamlit as st | |
from easynmt import EasyNMT | |
from nltk import word_tokenize | |
from simalign import SentenceAligner | |
import re | |
import penman | |
import amrlib | |
from amrlib.graph_processing.annotator import add_lemmas | |
from amrlib.alignments.rbw_aligner import RBWAligner | |
@st.cache_resource
def load_easynmt():
    """Return the EasyNMT 'opus-mt' translator used for FR<->EN translation.

    ``st.cache_resource`` keeps a single instance per server process, so the
    model is built once instead of on every Streamlit rerun (every widget
    interaction, including the Summarize button, reruns the whole script).
    """
    return EasyNMT('opus-mt')
@st.cache_resource
def load_stog_model():
    """Return the amrlib sentence-to-graph (text -> AMR) model.

    The model is read from the local ``model_stog`` directory. It is cached
    with ``st.cache_resource`` so it is loaded once per server process rather
    than on every Streamlit rerun.
    """
    return amrlib.load_stog_model(model_dir='model_stog')
@st.cache_resource
def load_gtos_model():
    """Return the amrlib graph-to-sentence (AMR -> text) model.

    The model is read from the local ``model_gtos`` directory. It is cached
    with ``st.cache_resource`` so it is loaded once per server process rather
    than on every Streamlit rerun.
    """
    return amrlib.load_gtos_model(model_dir='model_gtos')
# Find a node corresponding targetWord in the graph:
def getTargetWordNode(segmentTokens, aligner, alignments, target):
    """Return the AMR node (``'<var> / <concept>'``) aligned to a French keyword.

    The keyword's position in the French tokens is mapped to an English token
    index through the SimAlign ``'mwmf'`` alignments. That index is then
    looked up in ``aligner.alignments`` to get the aligned AMR entry.

    Args:
        segmentTokens: tokenized French segment; ``target`` must be one of
            these tokens exactly.
        aligner: object whose ``alignments`` list is indexed by English token
            position and holds ``None`` where nothing is aligned
            (an RBWAligner in this app).
        alignments: SimAlign result dict; ``alignments['mwmf']`` holds
            ``(french_index, english_index)`` pairs.
        target: French keyword to look up.

    Returns:
        A string such as ``'c / fence'``, or ``'Error!'`` when the keyword is
        not a token, has no English alignment, maps to no AMR entry, or the
        entry cannot be parsed.
    """
    # Alignment between target word in French and its English instance not found
    notFound = 'Error!'
    if target not in segmentTokens:
        return notFound
    # Get target word in English (first 'mwmf' pair for this French index):
    targetIndexFr = segmentTokens.index(target)
    targetIndexesEn = [pair[1] for pair in alignments['mwmf'] if pair[0] == targetIndexFr]
    if not targetIndexesEn:
        return notFound
    alignedEntry = aligner.alignments[targetIndexesEn[0]]
    if alignedEntry is None:
        return notFound
    # Get a full name of the graph node. The entry presumably prints like
    # "('c', ':instance', 'fence')": splitting on commas, parentheses and
    # quotes leaves the variable first and the concept third.
    nodeConcepts = [part for part in re.split(r"[,(\"']", str(alignedEntry)) if part.strip() != '']
    if len(nodeConcepts) < 3:
        # Unexpected entry shape: report failure instead of raising IndexError.
        return notFound
    return nodeConcepts[0] + ' / ' + nodeConcepts[2]
# Extract a subgraph containing target word with full path (all the node) to it:
def getTargetWordSubGraphFullPath(amrGraph, target):
    """Return the AMR subgraph of ``target`` together with its path to the root.

    The graph text is tokenized into role labels (e.g. ``':ARG0 '``),
    parentheses and the remaining content. Tokens are then processed in
    three phases:

    * Before the target is reached: sub-nodes that close are removed together
      with their role label. Everything else on the path is kept, including
      ancestor concepts and constant-valued roles.
    * At the first token containing ``target``: the whole sub-node is copied.
    * After that sub-node closes: all remaining tokens are dropped.

    Ancestors that are still open are closed at the end, and the result is
    re-serialized with penman.

    Args:
        amrGraph: AMR graph in PENMAN notation. Lines starting with ``#``
            (metadata such as the ``# ::snt`` sentence) are ignored.
        target: substring identifying the node, e.g. ``'c / fence'``.

    Returns:
        The pruned graph as re-encoded by ``penman.encode``.

    Raises:
        ValueError: if no token of the graph contains ``target``.
    """
    # Keep graph lines only. The appended space keeps tokens from adjacent
    # lines apart once everything is concatenated again.
    lines = [line + ' ' for line in amrGraph.split('\n') if not line.startswith('#')]
    # Split on role labels (":ARG0 ", ":ARG0-of "), keeping the labels.
    pieces = []
    for line in lines:
        pieces += [p for p in re.split(r'(:\w+\s|:\w+-\w+\s)', line) if p.strip() != '']
    # Split again on parentheses, keeping each one as its own token.
    tokens = []
    for piece in pieces:
        tokens += [t for t in re.split(r'(\(|\))', piece) if t.strip() != '']
    if not any(target in token.strip() for token in tokens):
        # Without this check the root's ')' would pop an empty list below.
        raise ValueError(f'{target!r} not found in the AMR graph')

    openListGlobal = []  # one '(' per ancestor node still open on the path
    openList = []  # one '(' per node still open inside the target sub-node
    subGraph = ''  # text of the target sub-node
    subGraphGlobal = []  # tokens of the path from the root down to the target
    flag = False  # True while the target sub-node is being collected
    stop = False  # True once the target sub-node has been closed
    for token in tokens:
        if flag:
            if token == '(':
                openList.append('(')
                subGraph += token
            elif token == ')':
                openList.pop()
                if not openList:
                    flag = False
                    stop = True
                    subGraph += ')'
                    subGraphGlobal.append(subGraph)
                else:
                    subGraph += token
            else:
                subGraph += token
        elif stop:
            # Only the first matching sub-node is extracted; the rest is ignored.
            continue
        elif target in token.strip():
            # Assumes the matching token is a node's concept, right after its
            # '(' which is already on subGraphGlobal; count that '(' as open.
            flag = True
            subGraph += token
            openList.append('(')
        elif token == '(':
            openListGlobal.append('(')
            subGraphGlobal.append(token)
        elif token == ')':
            # This node closed before the target was reached, so it is not on
            # the path: drop its content, its '(' and the role label before it.
            openListGlobal.pop()
            while subGraphGlobal[-1] != '(':
                subGraphGlobal.pop()
            subGraphGlobal.pop()
            subGraphGlobal.pop()
        else:
            subGraphGlobal.append(token)
    # Close every ancestor node that is still open.
    subGraphGlobal.extend([')'] * len(openListGlobal))
    # Fix the formatting:
    return penman.encode(penman.decode(''.join(subGraphGlobal)))
@st.cache_resource
def _load_sentence_aligner():
    """Return the SimAlign word aligner, cached across Streamlit reruns.

    Built with ``model="bert"``, ``token_type="bpe"`` and the ``"mai"``
    matching methods. getTargetWordNode reads the ``'mwmf'`` result.
    """
    return SentenceAligner(model="bert", token_type="bpe", matching_methods="mai")


def _summarize(segmentFr, targetWord):
    """Summarize the French text around ``targetWord``.

    Pipeline:

    1. Translate FR -> EN.
    2. Parse EN -> AMR.
    3. Align FR/EN words to find the keyword's AMR node.
    4. Extract the subgraph on the path to that node.
    5. Generate EN text from the subgraph (AMR -> EN).
    6. Translate that text back EN -> FR.

    Args:
        segmentFr: French input text (newlines allowed).
        targetWord: French keyword, matched exactly against the tokens.

    Returns:
        ``(summary, None)`` on success, or ``(None, message)`` when the
        keyword cannot be located in the AMR graph.
    """
    # Fix input formatting: the models are fed a single line of text.
    segmentFr = segmentFr.replace('\n', ' ')
    # Translate segment into English:
    translator = load_easynmt()
    segmentEn = translator.translate(segmentFr, source_lang='fr', target_lang='en')
    # Get an AMR graph (one sentence in, one graph out):
    amrGraph = load_stog_model().parse_sents([segmentEn])[0]
    # Graph body without the '#' metadata lines (e.g. '# ::snt').
    # getTargetWordSubGraphFullPath also skips these lines.
    graphBody = '\n'.join(line for line in amrGraph.split('\n') if not line.startswith('#'))
    # Tokenize both sides: French with NLTK, English as the aligner's lemmas.
    segmentFrTokens = word_tokenize(segmentFr, language='french')
    aligner = RBWAligner.from_penman_w_json(add_lemmas(amrGraph, snt_key='snt'))
    segmentEnTokens = aligner.lemmas
    # Get alignments between original version and translation:
    alignments = _load_sentence_aligner().get_word_aligns(segmentFrTokens, segmentEnTokens)
    # Find the AMR node aligned to the keyword. If that fails, fall back to
    # the keyword itself when it appears verbatim in the graph body.
    targetNode = getTargetWordNode(segmentFrTokens, aligner, alignments, targetWord)
    alignmentFailed = targetNode == 'Error!'
    if alignmentFailed or targetNode not in graphBody:
        if targetWord in graphBody:
            targetNode = targetWord
        elif alignmentFailed:
            return None, 'Error! Alignment between target word in French and its English instance not found'
        else:
            return None, 'Error! Cannot find keyword in the graph'
    # Extract a subgraph containing the target with the full path to it.
    # ValueError covers a match in graphBody that spans several graph tokens.
    try:
        targetSubGraph = getTargetWordSubGraphFullPath(amrGraph, targetNode)
    except ValueError:
        return None, 'Error! Cannot find keyword in the graph'
    # Generate English text from the extracted subgraph:
    rulesEn, _ = load_gtos_model().generate([targetSubGraph])
    # Strip enumeration markers such as "1. ". The dot is escaped so that
    # ordinary numbers followed by a space (e.g. "10 m") are left intact.
    ruleEn = re.sub(r'\d+\. ', '', rulesEn[0])
    # Translate it back to French:
    return translator.translate(ruleEn, source_lang='en', target_lang='fr'), None


def main():
    """Render the page: text and keyword inputs, plus a Summarize button."""
    st.header('Abstract Meaning Representation based summary of French text', divider='blue')
    segmentFr = st.text_area(
        "Text to summarize:",
        "Article 2 : Occupations ou utilisations du sol soumises à des conditions particulières\n\n"
        "2) Dans les périmètres en bordure des cours d’eau définis dans les annexes sanitaires du PLU :\n\n"
        "− Seules les clôtures en grillage pourront être autorisées à condition qu'elles soient conçues de\n"
        "manière à ne pas faire obstacle au libre écoulement des eaux.",
        height=170,
    )
    ## Alternative example (swap in together with the keyword below):
    #segmentFr = st.text_area(
    #"Text to summarize:",
    #"Article 1: Le classement interdit tout changement d'affectation ou tout mode d'occupation du sol de nature à compromettre la conservation, la protection ou la création des boisements. Dans les bois, forêts ou parcs situés sur le territoire de communes où l'établissement d'un plan d'occupation des sols a été prescrit mais où ce plan n'a pas encore été rendu public, ainsi que dans tout espace boisé classé, les coupes et abattages d'arbres sont soumis à autorisation préalable.",
    #height=170,
    #)
    targetWord = st.text_input('Keyword:', 'clôtures')
    ##targetWord = st.text_input('Keyword:', 'compromettre')
    if st.button('Summarize'):
        # Surrounding spaces would otherwise prevent an exact token match, and
        # an empty keyword is a substring of everything.
        targetWord = targetWord.strip()
        if not segmentFr.strip() or not targetWord:
            st.write('Error! Please enter both a text and a keyword')
            return
        summary, error = _summarize(segmentFr, targetWord)
        if error is None:
            st.write("Summary: ", summary)
        else:
            st.write(error)
if __name__ == '__main__':
    # Build the page only when this file is executed as the main script,
    # not when it is imported.
    main()