File size: 3,042 Bytes
bd1c71d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import gradio as gr
import requests
import subprocess
import time
import json

from ollama import chat
from ollama import ChatResponse


def _wait_for_ollama(host, timeout):
    """Poll *host* until the Ollama HTTP server answers, or raise after *timeout* seconds."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            requests.get(host, timeout=2).raise_for_status()
            return
        except requests.exceptions.RequestException:
            # Server not listening (or not healthy) yet; retry shortly.
            time.sleep(0.5)
    raise RuntimeError(f"Ollama server did not respond at {host} within {timeout} seconds")


def start_ollama_server(model="llama3.2:1b", *, host="http://localhost:11434", startup_timeout=30.0):
    """Start a background ``ollama serve`` process, pull *model*, and launch it.

    Child-process output is sent to ``DEVNULL`` rather than ``PIPE``. Nothing
    in this program reads those pipes, and an unread pipe eventually fills
    and blocks the child that writes to it.

    Args:
        model: Ollama model tag to pull and run.
        host: Base URL polled to detect when the server accepts requests.
        startup_timeout: Seconds to wait for the server before giving up.

    Returns:
        The ``subprocess.Popen`` handle of the ``ollama serve`` process.

    Raises:
        RuntimeError: If the server does not respond within *startup_timeout*.
        subprocess.CalledProcessError: If ``ollama pull`` exits non-zero.
    """
    # Start Ollama server in the background
    print("Starting Ollama server...")
    server = subprocess.Popen(
        ["ollama", "serve"],
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )
    # Wait until the server actually answers instead of sleeping a fixed time.
    # If another Ollama instance already owns the port, this still succeeds.
    _wait_for_ollama(host, startup_timeout)

    # Pull the required model
    print("Pulling the required model...")
    subprocess.run(["ollama", "pull", model], check=True)

    print("Starting the model...")
    subprocess.Popen(
        ["ollama", "run", model],
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )
    print("Ollama started model.")
    return server

# Send a post to the local Ollama model and return its YTA/NTA verdict.
def ask_ollama(question, *, model="llama3.2:1b"):
    """Classify an r/AmItheAsshole post as YTA or NTA using a local Ollama model.

    Args:
        question: The post text entered by the user.
        model: Ollama model tag used for the chat request.

    Returns:
        The model's raw reply text. The prompt asks for "YTA" or "NTA"
        followed by an explanation, but that format is not enforced here.
        If *question* is empty or whitespace-only, a short hint is returned
        instead and the model is not called.
    """
    if not question or not question.strip():
        # Nothing to classify; skip the round-trip to the model.
        return "Please enter some text to classify."

    # Built from adjacent string literals rather than an indented triple-quoted
    # f-string, so the function's indentation does not leak into the prompt.
    prompt = (
        "### You are an expert in the subreddit r/AmItheAsshole.\n\n"
        "### The task for you is to classify the given text content as YTA or NTA "
        "label and give an explanation for the same.\n\n"
        "### The output format is as follows:\n"
        '"YTA" or "NTA", explanation for the label.\n\n'
        f"### Input Text :  {question}\n"
    )

    response: ChatResponse = chat(model=model, messages=[
        {
            'role': 'user',
            'content': prompt,
        },
    ])
    content = response['message']['content']
    print(content)
    return content


# Gradio Interface
def gradio_interface(prompt):
    """Click handler for the UI: pass the textbox text to ``ask_ollama`` and return its reply."""
    answer = ask_ollama(prompt)
    return answer


# Build the Gradio app: a header, then an input textbox, an "Ask" button and
# a response textbox, each placed in its own row.
with gr.Blocks() as demo:
    gr.Markdown("# Ollama Server Interface")
    gr.Markdown("Ask questions and get responses from the Ollama server.")

    with gr.Row():
        input_prompt = gr.Textbox(
            label="Enter your question",
            placeholder="Type your question here...",
        )
    with gr.Row():
        submit_button = gr.Button("Ask")
    with gr.Row():
        output_response = gr.Textbox(lines=10, label="Response")

    # On click, the input text goes through gradio_interface and the returned
    # string is shown in the response box.
    submit_button.click(
        fn=gradio_interface,
        inputs=input_prompt,
        outputs=output_response,
    )

def main():
    """Start the local Ollama server, then serve the Gradio UI on 0.0.0.0:7860."""
    start_ollama_server()
    # share=True additionally requests a public Gradio share link.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)


# Launch the app
if __name__ == "__main__":
    main()