rockerritesh commited on
Commit
5f0af37
·
verified ·
1 Parent(s): a69d5f3

Rename app.py to main.py

Browse files
Files changed (2) hide show
  1. app.py +0 -79
  2. main.py +41 -0
app.py DELETED
@@ -1,79 +0,0 @@
1
- import gradio as gr
2
- from typing import List
3
- import requests
4
- import json
5
- import subprocess
6
- from multiprocessing import Process
7
-
8
- def get_text_embedding(text: str, model: str = "mxbai-embed-large", api_url: str = "http://localhost:11434/api/embeddings") -> List[float]:
9
- """
10
- Sends a prompt to the embedding API and retrieves the embedding.
11
-
12
- Args:
13
- text (str): The text to embed.
14
- model (str): The model to use for generating the embedding (default is "mxbai-embed-large").
15
- api_url (str): The API endpoint URL (default is "http://localhost:11434/api/embeddings").
16
-
17
- Returns:
18
- list: A list of floats representing the embedding vector.
19
-
20
- Raises:
21
- Exception: If the API request fails.
22
- """
23
- payload = {
24
- "model": model,
25
- "prompt": text
26
- }
27
-
28
- try:
29
- response = requests.post(api_url, data=json.dumps(payload), headers={"Content-Type": "application/json"})
30
- response.raise_for_status() # Raise an error for non-200 status codes
31
- data = response.json()
32
- return data.get("embedding", [])
33
- except requests.exceptions.RequestException as e:
34
- raise Exception(f"Error communicating with the embedding API: {e}")
35
-
36
- def process_text_to_embedding(text: str) -> str:
37
- """Process the text input and return the embedding as a string."""
38
- try:
39
- embedding = get_text_embedding(text)
40
- return json.dumps(embedding, indent=2)
41
- except Exception as e:
42
- return f"Error: {str(e)}"
43
-
44
- def run_ollama_serve():
45
- subprocess.run(["ollama", "serve"], check=True)
46
-
47
- # Create processes
48
- serve_process = Process(target=run_ollama_serve)
49
-
50
- # Start processes
51
- serve_process.start()
52
-
53
- # subprocess.run(["sudo", "apt", "install", "-y", "pciutils", "lshw"], check=True)
54
- # subprocess.run(["curl", "-fsSL", "https://ollama.com/install.sh", "|", "sh"], shell=True, check=True)
55
- subprocess.run(["ollama", "pull", "snowflake-arctic-embed2"], check=True)
56
-
57
- # Define the Gradio interface
58
- def main():
59
- title = "Text Embedding Generator"
60
- description = "Enter a text input, and this tool will generate an embedding using the specified model via API."
61
-
62
- with gr.Blocks() as demo:
63
- gr.Markdown(f"# {title}")
64
- gr.Markdown(description)
65
-
66
- with gr.Row():
67
- text_input = gr.Textbox(label="Input Text", placeholder="Enter your text here")
68
-
69
- with gr.Row():
70
- output = gr.Textbox(label="Embedding Output", lines=10)
71
-
72
- submit_button = gr.Button("Generate Embedding")
73
-
74
- submit_button.click(fn=process_text_to_embedding, inputs=[text_input], outputs=[output])
75
-
76
- demo.launch()
77
-
78
- if __name__ == "__main__":
79
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # main.py
2
+ from fastapi import FastAPI, HTTPException
3
+ from pydantic import BaseModel
4
+ from typing import List
5
+ from vllm import LLM
6
+ import numpy as np
7
+
8
+ # Initialize the model
9
+ llm = LLM(model='BAAI/bge-base-en-v1.5', task="embed")
10
+
11
+ # Initialize FastAPI app
12
+ app = FastAPI()
13
+
14
+ # Define request schemas
15
+ class DocumentsRequest(BaseModel):
16
+ documents: List[str]
17
+
18
+ class QueryRequest(BaseModel):
19
+ query: str
20
+
21
+ # API to embed documents
22
+ @app.post("/embed_documents")
23
+ def embed_documents(request: DocumentsRequest):
24
+ try:
25
+ docs = request.documents
26
+ docs_embd = llm.encode(docs)
27
+ docs_embd = [doc.outputs.data.numpy().tolist() for doc in docs_embd]
28
+ return {"embeddings": docs_embd}
29
+ except Exception as e:
30
+ raise HTTPException(status_code=500, detail=f"Error embedding documents: {str(e)}")
31
+
32
+ # API to embed query
33
+ @app.post("/embed_query")
34
+ def embed_query(request: QueryRequest):
35
+ try:
36
+ query = request.query
37
+ query_embd = llm.encode(query)
38
+ query_embd = query_embd[0].outputs.data.numpy().tolist()
39
+ return {"embedding": query_embd}
40
+ except Exception as e:
41
+ raise HTTPException(status_code=500, detail=f"Error embedding query: {str(e)}")