KN123 commited on
Commit
a711d1c
·
verified ·
1 Parent(s): e8e4100

Upload 3 files

Browse files
Files changed (3) hide show
  1. Dockerfile +30 -0
  2. app.py +126 -0
  3. requirements.txt +8 -0
Dockerfile ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.9
2
+
3
+ # Set working directory
4
+ WORKDIR /code
5
+
6
+ # Copy requirements file and install dependencies
7
+ COPY ./requirements.txt /code/requirements.txt
8
+ RUN pip install --no-cache-dir --upgrade -r requirements.txt
9
+
10
+ # Create a non-root user
11
+ RUN useradd user
12
+
13
+ # Set environment variables
14
+ ENV HOME=/home/user \
15
+ PATH=/home/user/.local/bin:$PATH
16
+
17
+ # Switch to the non-root user
18
+ USER user
19
+
20
+ # Set working directory for the application
21
+ WORKDIR $HOME/app
22
+
23
+ # Copy application code into the container
24
+ COPY --chown=user . $HOME/app
25
+
26
+ # Expose port
27
+ EXPOSE 7860
28
+
29
+ # Command to run the FastAPI application using uvicorn
30
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
+ from transformers import GenerationConfig
5
+ from time import perf_counter
6
+ import json
7
+
8
+ from typing import List, Dict
9
+ import time
10
+ import datetime
11
+ import uvicorn
12
+ import torch
13
+
14
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
15
+ # device = 'cpu'
16
+ print("LLM using", device)
17
+
18
+ REMOTE_PATH = "KN123/nl2csv4instructions-TinyLlama-v2.0"
19
+ LOCAL_PATH = "nl2csv4instructions-TinyLlama-v2.0"
20
+ print("🟢 Fetching the models ....")
21
+ model = AutoModelForCausalLM.from_pretrained(REMOTE_PATH, device_map = device)
22
+ tokenizer = AutoTokenizer.from_pretrained(REMOTE_PATH)
23
+ print("🚀 Ready! nl2csv4instructions at your service!")
24
+
25
+
26
+ # def get_prompt(tables, question):
27
+ # prompt = f"""Made. tables: {tables}. question: {question}"""
28
+ # # print(prompt)
29
+ # return prompt
30
+
31
+ # def prepare_input(question: str, tables: Dict[str, List[str]]):
32
+ # tables = [f"""{table_name}({",".join(tables[table_name])})""" for table_name in tables]
33
+ # # print(tables)
34
+ # tables = ", ".join(tables)
35
+ # # print(tables)
36
+ # prompt = get_prompt(tables, question)
37
+ # # print(prompt)
38
+ # input_ids = tokenizer(prompt, max_length=512, return_tensors="pt").input_ids
39
+ # # print(input_ids)
40
+ # return input_ids
41
+
42
+ # def inference(question: str, tables: Dict[str, List[str]]) -> str:
43
+ # input_data = prepare_input(question=question, tables=tables)
44
+ # input_data = input_data.to(model.device)
45
+ # outputs = model.generate(inputs=input_data, num_beams=10, top_k=10, max_length=512)
46
+ # # print("Outputs", outputs)
47
+ # result = tokenizer.decode(token_ids=outputs[0], skip_special_tokens=True)
48
+ # return result
49
+
50
+ def parse(output):
51
+ # Find the index of '<|assistant|>'
52
+ end_tag_index = output.find('<|assistant|>')
53
+ extracted_text = ""
54
+ if end_tag_index != -1:
55
+ # Extract text after '<|assistant|>'
56
+ extracted_text = output[end_tag_index + len('<|assistant|>'):]
57
+ # Remove any leading '\n' characters
58
+ extracted_text = extracted_text.lstrip('\n')
59
+ #print(extracted_text)
60
+ else:
61
+ print("End tag '<|assistant|>' not found in output.")
62
+
63
+ return extracted_text
64
+
65
+ def formatted_prompt(question)-> str:
66
+ return f"<|user|>\n{question}</s>\n<|assistant|>"
67
+
68
+ def generate_response(user_input):
69
+
70
+ prompt = formatted_prompt(user_input)
71
+
72
+ inputs = tokenizer([prompt], return_tensors="pt")
73
+ generation_config = GenerationConfig(penalty_alpha=0.6,do_sample = True,
74
+ top_k=5,temperature=0.1,repetition_penalty=1.2,pad_token_id=tokenizer.eos_token_id,max_new_tokens=20,
75
+ )
76
+ start_time = perf_counter()
77
+
78
+ # inputs = tokenizer(prompt, return_tensors="pt").to('cuda')
79
+ inputs = tokenizer(prompt, return_tensors="pt").to(device=device)
80
+
81
+ outputs = model.generate(**inputs, generation_config=generation_config)
82
+ llm_output = (tokenizer.decode(outputs[0], skip_special_tokens=True))
83
+ # print(llm_output)
84
+ output_time = perf_counter() - start_time
85
+ output_time = round(output_time,2)
86
+ # print(f"Time taken for inference: {output_time} seconds")
87
+ res = {}
88
+ res['llm_output'] = llm_output
89
+ res['time_taken'] = output_time
90
+ res['parsed_text'] = parse(llm_output)
91
+ return res
92
+
93
+ app = FastAPI()
94
+ app.add_middleware(
95
+ CORSMiddleware,
96
+ allow_origins=["*"], # Allows all origins
97
+ allow_credentials=True,
98
+ allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], # Allows all methods
99
+ allow_headers=["*"], # Allows all headers
100
+ )
101
+
102
+ @app.get("/")
103
+ def home():
104
+ return {
105
+ "message" : "Hello there! Everything is working fine!",
106
+ "api-version": "2.0.0",
107
+ "role": "nl2csv4instructions",
108
+ "description": "This api can be used to convert natural language into instructions in the form of comma separate values, Ex: pick the items F-1222 and E-2222 or ask to reset all the settings."
109
+ }
110
+
111
+ @app.get("/test-generate")
112
+ def generate():
113
+ res = generate_response(user_input='set the color Blue to F-2244')
114
+ return res
115
+
116
+ @app.post("/generate")
117
+ def generate(request_body:str):
118
+ print("Request Got: ", request_body)
119
+ res = generate_response(user_input=request_body)
120
+ print(res)
121
+ return res
122
+
123
+
124
+
125
+ if __name__ == "__main__":
126
+ uvicorn.run(app, host="127.0.0.1", port=7860)
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ requests
3
+ uvicorn[standard]
4
+ sentencepiece
5
+ torch
6
+ transformers
7
+ aiohttp
8
+ accelerate