import json
import os
import shutil

import requests
import gradio as gr
from huggingface_hub import Repository
from text_generation import Client

from share_btn import community_icon_html, loading_icon_html, share_js, share_btn_css

# Access token for the hosted inference endpoint; None when the variable is unset.
HF_TOKEN = os.environ.get("HF_TOKEN", None)
API_URL = "https://api-inference.huggingface.co/models/codellama/CodeLlama-7b-hf"

# Fill-in-the-middle control tokens. generate() assembles an infilling prompt as
#   FIM_PREFIX + prefix + FIM_SUFFIX + suffix + FIM_MIDDLE
# NOTE(review): these literals follow the published Code Llama infilling format
# ("<PRE> {prefix} <SUF>{suffix} <MID>") — confirm against the deployed tokenizer.
FIM_PREFIX = "<PRE> "
FIM_MIDDLE = " <MID>"
FIM_SUFFIX = " <SUF>"

# Marker a user places in the prompt at the position that should be filled in.
FIM_INDICATOR = "<FILL_ME>"

# Token texts that generate() treats as the end of a completion.
EOS_STRING = "</s>"
EOT_STRING = "<EOT>"

# Visual theme for the demo: Gradio's Monochrome preset with indigo/blue/slate
# hues, the small corner radius, and a font list that starts with the
# "Open Sans" Google font followed by generic sans-serif families.
theme = gr.themes.Monochrome(
    font=[gr.themes.GoogleFont("Open Sans"), "ui-sans-serif", "system-ui", "sans-serif"],
    radius_size=gr.themes.sizes.radius_sm,
    primary_hue="indigo",
    secondary_hue="blue",
    neutral_hue="slate",
)

# Streaming text-generation client for API_URL.
# The Authorization header is only attached when HF_TOKEN is set; without this
# guard an unset token would be sent as the literal string "Bearer None".
client = Client(
    API_URL,
    headers={"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else None,
)


def generate(
    prompt, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
):
    """Stream a completion for ``prompt``, yielding the text accumulated so far.

    When ``prompt`` contains ``FIM_INDICATOR`` it is split there into a prefix
    and a suffix, and the request sent to the client is
    ``FIM_PREFIX + prefix + FIM_SUFFIX + suffix + FIM_MIDDLE``. In that mode the
    yielded text starts from the prefix, and the suffix is appended once the
    model emits an end marker. Otherwise the yielded text starts from the
    prompt itself.

    Args:
        prompt: Text to complete, optionally containing one ``FIM_INDICATOR``.
        temperature: Sampling temperature; values below 0.01 are raised to 0.01.
        max_new_tokens: Upper bound on generated tokens (coerced to ``int``).
        top_p: Nucleus-sampling probability mass (coerced to ``float``).
        repetition_penalty: Penalty passed through to the client (coerced to
            ``float``).

    Yields:
        The accumulated output string after each streamed token.

    Raises:
        ValueError: If ``prompt`` contains more than one ``FIM_INDICATOR``.
            Because this is a generator, the error surfaces on first iteration.
    """
    # UI sliders may deliver numbers as either int or float; normalise them.
    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)
    max_new_tokens = int(max_new_tokens)
    repetition_penalty = float(repetition_penalty)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,  # fixed seed: the same request samples the same way each time
    )

    fim_mode = FIM_INDICATOR in prompt
    if fim_mode:
        parts = prompt.split(FIM_INDICATOR)
        # Exactly one indicator yields exactly two parts.
        if len(parts) != 2:
            raise ValueError(f"Only one {FIM_INDICATOR} allowed in prompt!")
        prefix, suffix = parts
        prompt = f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}"

    stream = client.generate_stream(prompt, **generate_kwargs)

    # In FIM mode the visible text is rebuilt as prefix + infill + suffix, so the
    # control tokens added to the request never reach the caller.
    output = prefix if fim_mode else prompt

    end_markers = (EOS_STRING, EOT_STRING)
    for response in stream:
        token_text = response.token.text
        if token_text in end_markers:
            if fim_mode:
                # Close the infill and stop here: anything streamed after the
                # end marker must not be appended behind the suffix, and the
                # suffix must be appended only once.
                output += suffix
                yield output
            return output
        output += token_text
        yield output
    return output


# Sample prompts offered in the UI. Adjacent string literals are concatenated
# by the parser, so each entry is still a single string.
examples = [
    # Python: complete code from a descriptive comment.
    (
        "X_train, y_train, X_test, y_test = train_test_split(X, y, test_size=0.1)\n\n"
        "# Train a logistic regression model, predict the labels on the test set and compute the accuracy score"
    ),
    # JavaScript: complete a function body from its doc comment.
    "// Returns every other value in the array as a new array.\nfunction everyOther(arr) {",
    # Plain-text completion.
    "Poor English: She no went to the market. Corrected English:",
    # Infilling: the <FILL_ME> marker sits in the otherwise empty `if` body, so
    # this example asks the model to fill in the middle of the function.
    (
        "def alternating(list1, list2):\n"
        "   results = []\n"
        "   for i in range(min(len(list1), len(list2))):\n"
        "       results.append(list1[i])\n"
        "       results.append(list2[i])\n"
        "   if len(list1) > len(list2):\n"
        "       <FILL_ME>\n"
        "   else:\n"
        "       results.extend(list2[i+1:])\n"
        "   return results"
    ),
]


def process_example(args):
    """Run ``generate`` on an example prompt and return its final output.

    The generator is drained so that only the last (complete) text is returned.

    Args:
        args: The example prompt string.

    Returns:
        The last value yielded by ``generate``. If the generator yields nothing
        (no tokens were streamed), the prompt is returned unchanged instead of
        failing on an unbound loop variable.
    """
    result = args
    for result in generate(args):
        pass
    return result


# Renders the prompt textarea (element id "q-input") in a monospace font.
monospace_css = """
#q-input textarea {
    font-family: monospace, 'Consolas', Courier, monospace;
}
"""

# Complete stylesheet for the page, assembled in order: hide elements with the
# "generating" class, the share-button styles, the monospace rule above, and a
# black default text colour for the Gradio container.
css = "".join(
    [
        ".generating {visibility: hidden}",
        share_btn_css,
        monospace_css,
        ".gradio-container {color: black}",
    ]
)

# Introductory text rendered with gr.Markdown at the top of the page.
description = """

🦙 CodeLlama Playground

This is a demo to generate text and code with the following Code Llama model (7B). Please note that this model is not designed for instruction purposes but for code completion. If you're looking for instruction or want to chat with a fine-tuned model, you can visit the Code Llama Org and select an instruct model.

"""

# Page layout and event wiring.
# NOTE(review): the nesting of the layout containers below was reconstructed
# from a whitespace-collapsed source — confirm the page renders as intended.
with gr.Blocks(theme=theme, analytics_enabled=False, css=css) as demo:
    with gr.Column():
        gr.Markdown(description)
        with gr.Row():
            with gr.Column():
                # Prompt input, trigger button and generated-code display.
                instruction = gr.Textbox(
                    placeholder="Enter your code here",
                    lines=5,
                    label="Input",
                    elem_id="q-input",
                )
                submit = gr.Button("Generate", variant="primary")
                output = gr.Code(elem_id="q-output", lines=30, label="Output")

                # Sampling controls, collapsed by default, in two columns.
                with gr.Row():
                    with gr.Column():
                        with gr.Accordion("Advanced settings", open=False):
                            with gr.Row():
                                column_1, column_2 = gr.Column(), gr.Column()
                                with column_1:
                                    temperature = gr.Slider(
                                        label="Temperature",
                                        value=0.2,
                                        minimum=0.0,
                                        maximum=1.0,
                                        step=0.05,
                                        interactive=True,
                                        info="Higher values produce more diverse outputs",
                                    )
                                    # NOTE(review): minimum=0 lets a user request
                                    # zero new tokens — confirm the backend
                                    # accepts that value.
                                    max_new_tokens = gr.Slider(
                                        label="Max new tokens",
                                        value=256,
                                        minimum=0,
                                        maximum=8192,
                                        step=64,
                                        interactive=True,
                                        info="The maximum numbers of new tokens",
                                    )
                                with column_2:
                                    top_p = gr.Slider(
                                        label="Top-p (nucleus sampling)",
                                        value=0.90,
                                        minimum=0.0,
                                        maximum=1,
                                        step=0.05,
                                        interactive=True,
                                        info="Higher values sample more low-probability tokens",
                                    )
                                    repetition_penalty = gr.Slider(
                                        label="Repetition penalty",
                                        value=1.2,
                                        minimum=1.0,
                                        maximum=2.0,
                                        step=0.05,
                                        interactive=True,
                                        info="Penalize repeated tokens",
                                    )

                # Share-to-community widgets.
                with gr.Group(elem_id="share-btn-container"):
                    community_icon = gr.HTML(community_icon_html, visible=True)
                    loading_icon = gr.HTML(loading_icon_html, visible=True)
                    share_button = gr.Button(
                        "Share to community", elem_id="share-btn", visible=True
                    )

                # Clickable sample prompts; results are computed on demand
                # (cache_examples=False) through process_example.
                gr.Examples(
                    examples=examples,
                    inputs=[instruction],
                    cache_examples=False,
                    fn=process_example,
                    outputs=[output],
                )

    # The Generate button streams generate()'s yielded text into the output box.
    submit.click(
        generate,
        inputs=[instruction, temperature, max_new_tokens, top_p, repetition_penalty],
        outputs=[output],
    )
    # The share button runs only the client-side script in share_js.
    share_button.click(None, [], [], _js=share_js)

demo.queue(concurrency_count=16).launch(debug=True)