File size: 3,233 Bytes
b1c0569
 
 
 
49240f9
2effa61
2352bc9
 
b1c0569
2352bc9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b1c0569
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49240f9
b1c0569
75bced6
b1c0569
 
 
 
 
 
 
 
5947c02
 
8369f94
16d2859
2352bc9
 
 
 
b1c0569
49240f9
0e8e661
49240f9
 
 
 
b1c0569
49240f9
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import os
from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs
from knowledge_storm.lm import OpenAIModel
from knowledge_storm.rm import YouRM
import spaces
import gradio as gr
import json
import re

def convert_references_to_links(text, json_data):
    """Rewrite numeric citation markers in *text* as markdown links.

    Each ``[N]`` marker is replaced with ``[[N]](url)`` using the
    ``url_to_unified_index`` mapping in *json_data* (url -> citation index),
    and a plain reference list is appended after the article body.

    Args:
        text: Article text containing ``[N]`` citation markers.
        json_data: Dict with a ``url_to_unified_index`` key mapping each
            source URL to its integer citation index.

    Returns:
        Markdown string: the linked text, a blank line, then one
        ``[N] url`` line per source, sorted by index.
    """
    url_mapping = json_data['url_to_unified_index']

    # Invert the mapping once (index -> url) instead of linearly scanning
    # every (url, index) pair for each citation marker in the text.
    # setdefault preserves the original "first matching url wins" behavior
    # if two urls ever share an index.
    index_to_url = {}
    for url, index in url_mapping.items():
        index_to_url.setdefault(str(index), url)

    def replace_reference(match):
        url = index_to_url.get(match.group(1))
        if url:
            return f'[{match.group(0)}]({url})'
        # Unknown citation number: leave the marker untouched.
        return match.group(0)

    # Replace references in the text
    processed_text = re.sub(r'\[(\d+)\]', replace_reference, text)

    # Reference list: one "[N] url" line per source, ordered by index.
    reference_list = [f"[{index}] {url}"
                      for url, index in sorted(url_mapping.items(), key=lambda x: x[1])]

    return f"{processed_text}\n\n" + "\n".join(reference_list)
    
# --- STORM pipeline configuration (runs once at import time) ---
lm_configs = STORMWikiLMConfigs()
openai_kwargs = {
    'api_key': os.getenv("OPENAI_API_KEY"),
    'temperature': 1.0,
    'top_p': 0.9,
}

# STORM lets every pipeline stage run on a different model so cost and
# quality can be balanced per component (e.g. a cheaper/faster model for
# the conversation simulator, a stronger one for citation-bearing article
# generation). This demo wires GPT-4o into every stage.
gpt_35 = OpenAIModel(model='gpt-3.5-turbo', max_tokens=500, **openai_kwargs)
gpt_4 = OpenAIModel(model='gpt-4o', max_tokens=3000, **openai_kwargs)
# NOTE(review): gpt_35 is instantiated but never assigned to any stage below.
for configure_stage in (
    lm_configs.set_conv_simulator_lm,
    lm_configs.set_question_asker_lm,
    lm_configs.set_outline_gen_lm,
    lm_configs.set_article_gen_lm,
    lm_configs.set_article_polish_lm,
):
    configure_stage(gpt_4)

# Engine arguments (see STORMWikiRunnerArguments for all options) plus the
# You.com retrieval module that feeds search results into the pipeline.
engine_args = STORMWikiRunnerArguments("outputs")
rm = YouRM(ydc_api_key=os.getenv('YDC_API_KEY'), k=engine_args.search_top_k)
runner = STORMWikiRunner(engine_args, lm_configs, rm)

@spaces.GPU
def generate_article(prompt, progress=gr.Progress(track_tqdm=True)):
    """Run the full STORM pipeline for *prompt* and return the article.

    Executes research, outline generation, article generation, and
    polishing, then reads the generated article and its reference index
    from the STORM output folder and converts citations into markdown
    links.

    Args:
        prompt: The topic to write an article about.
        progress: Gradio progress tracker (mirrors STORM's tqdm bars).

    Returns:
        The article as a markdown string, titled with *prompt* and ending
        with a numbered reference list.
    """
    # runner.run's return value is unused; the pipeline communicates via
    # files written under the outputs directory.
    runner.run(
        topic=prompt,
        do_research=True,
        do_generate_outline=True,
        do_generate_article=True,
        do_polish_article=True,
    )
    runner.post_run()
    runner.summary()
    # STORM writes artifacts under outputs/<topic with spaces as underscores>.
    # NOTE(review): assumes the runner sanitizes the topic the same way —
    # confirm against knowledge_storm's output-folder naming.
    article_dir = os.path.join('outputs', prompt.replace(" ", "_"))
    with open(os.path.join(article_dir, 'storm_gen_article.txt'), 'r', encoding='utf-8') as file:
        content = file.read()
    with open(os.path.join(article_dir, 'url_to_info.json'), 'r', encoding='utf-8') as file:
        references_json = json.load(file)
    return convert_references_to_links(f'# {prompt}\n\n' + content, references_json)

# Minimal Gradio UI: one topic box, a generate button, markdown output.
with gr.Blocks() as demo:
    gr.Markdown("# Omnipedia article generation demo (Storm GPT-4 + You)")
    topic_box = gr.Textbox(label="Prompt")
    article_md = gr.Markdown(label="Output")
    generate_btn = gr.Button("Generate")
    # Clicking the button runs the full STORM pipeline and renders the result.
    generate_btn.click(fn=generate_article, inputs=topic_box, outputs=article_md)

if __name__ == "__main__":
    demo.launch()