File size: 7,893 Bytes
2714773 5782838 9d6743e 2714773 338e482 d7a34dd 2714773 d7a34dd 4050efc 5eca98b 9d6743e 2714773 ca38751 2714773 4050efc 2714773 8233187 4050efc 115834e 8233187 01b0bb9 b84957a 7931b6e 01b0bb9 b84957a 115834e 10b34aa a582020 10b34aa 01b0bb9 2714773 115834e 2714773 aa595f0 4da7ef5 2714773 4c1270f 4da7ef5 2714773 29d45a4 2714773 978fd4d 2714773 6651d18 2714773 5758bb4 2714773 0f46be8 2714773 6931a12 2714773 5758bb4 2714773 fd2320e ccddcad 29d45a4 8233187 115834e 8233187 a46b806 7b2f839 8233187 115834e 10b34aa 0bb8f5d 2714773 29d45a4 115834e d266202 fd2320e 20a1bcd fd2320e 4050efc fd2320e d266202 ccddcad bbb2bb6 ccddcad d266202 2714773 978fd4d 5758bb4 2714773 f216d8b 2714773 82f4770 29d45a4 9c35ba8 2714773 f216d8b 4fc0d85 da85880 ccddcad 4fc0d85 2714773 01b0bb9 2714773 01b0bb9 2714773 d266202 4050efc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 |
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import TapexTokenizer, BartForConditionalGeneration
import pandas as pd
import gradio as gr
import numpy as np
import time
import os
import random
#import pyodbc
# ---------------------------------------------------------------------------
# Disabled snippets kept for reference. They are comments only and never run.
#
# (a) Print every installed package as "name==version":
#
#     import pkg_resources
#     installed_packages = {pkg.key: pkg.version for pkg in pkg_resources.working_set}
#     for package, version in installed_packages.items():
#         print(f"{package}=={version}")
#
# (b) Read a table from SQL Server into a DataFrame
#     (requires `import pyodbc`; replace the placeholders with real values):
#
#     server = 'your_server'
#     database = 'your_database'
#     username = 'your_username'
#     password = 'your_password'
#     driver = 'SQL Server'  # depends on the ODBC driver installed on your system
#     connection_string = f'DRIVER={{{driver}}};SERVER={server};DATABASE={database};UID={username};PWD={password}'
#     conn = pyodbc.connect(connection_string)
#     query = 'SELECT * FROM your_table_name'  # replace with your own query
#     df = pd.read_sql_query(query, conn)
#     conn.close()
#
# (c) Build a random 3,000-row table with 20 numeric columns plus
#     randomized "year" and "city" columns:
#
#     num_records = 3000
#     num_columns = 20
#     data = {f"column_{i}": np.random.randint(0, 100, num_records) for i in range(num_columns)}
#     years = list(range(2000, 2023))
#     cities = ["New York", "Los Angeles", "Chicago", "Houston", "Miami"]
#     data["year"] = [random.choice(years) for _ in range(num_records)]
#     data["city"] = [random.choice(cities) for _ in range(num_records)]
#     table = pd.DataFrame(data)
#
# (d) Load the table from an uploaded CSV file:
#
#     table = pd.read_csv(csv_file.name, delimiter=",")
#     table.fillna(0, inplace=True)
#     table = table.astype(str)
# ---------------------------------------------------------------------------

# Inline sample table that the TAPEX model answers questions about:
# six rows with a "year" column and a "city" column.
data = dict(
    year=[1896, 1900, 1904, 2004, 2008, 2012],
    city=["athens", "paris", "st. louis", "athens", "beijing", "london"],
)
table = pd.DataFrame(data)
# Conversational model behind the "Chatbot" tab (used by chat()).
chatbot_model_name = "microsoft/DialoGPT-medium"
tokenizer, model = (
    AutoTokenizer.from_pretrained(chatbot_model_name),
    AutoModelForCausalLM.from_pretrained(chatbot_model_name),
)

# Table question-answering model behind the "SQL Chat" tab (used by sqlquery()).
sql_model_name = "microsoft/tapex-large-finetuned-wtq"
sql_tokenizer, sql_model = (
    TapexTokenizer.from_pretrained(sql_model_name),
    BartForConditionalGeneration.from_pretrained(sql_model_name),
)

# Transcript of the SQL tab as (sender, message) pairs; sqlquery() appends to it.
# NOTE(review): this is module-level, so every session served by this process
# shares one transcript - confirm that is intended.
conversation_history = []
def chat(input, history=None):
    """Generate the next DialoGPT reply and return the updated transcript.

    Args:
        input: The user's new message. The name is kept for signature
            compatibility; it shadows the builtin only inside this function.
        history: Token ids of the conversation so far, exactly as returned in
            the previous call's second result (``model.generate(...).tolist()``,
            i.e. a one-row nested list). ``None`` or an empty list starts a new
            conversation.

    Returns:
        ``(response, history)`` where ``response`` is a list of
        ``(user_text, bot_text)`` tuples for the Gradio ``"chatbot"`` output,
        and ``history`` is the updated token-id list to feed back as state.
    """
    # A fresh Gradio "state" may arrive as None. A None default also avoids
    # sharing one mutable list object across calls.
    if history is None:
        history = []

    # Encode the new message, terminated by EOS (turns are later split on it).
    new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors="pt")

    # Prepend earlier turns only when there are any. On the first turn this
    # avoids relying on torch.cat's special case for 1-D empty tensors.
    if history:
        bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
    else:
        bot_input_ids = new_user_input_ids

    # NOTE(review): in transformers, max_length bounds the whole sequence,
    # prompt included. Once the history nears 1000 tokens there is no room
    # left for a reply, so consider trimming old turns.
    history = model.generate(
        bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id
    ).tolist()

    # The decoded text alternates user/bot turns separated by EOS markers.
    # Pair consecutive turns as (user, bot) tuples for the chatbot component.
    turns = tokenizer.decode(history[0]).split("<|endoftext|>")
    response = [(turns[i], turns[i + 1]) for i in range(0, len(turns) - 1, 2)]
    return response, history
def sqlquery(input):
    """Answer a natural-language question about ``table`` with TAPEX.

    The question and the model's answer are appended to the module-level
    ``conversation_history`` list, and the whole transcript is returned.

    Args:
        input: The question typed by the user. Blank input (``None``, ``""``
            or whitespace only) is not sent to the model and not recorded.

    Returns:
        The transcript as newline-separated ``"User: ..."`` / ``"Bot: ..."``
        lines, for display in a plain Gradio Textbox.
    """
    # NOTE(review): conversation_history lives at module level, so every
    # session served by this process shares, and keeps growing, one transcript.
    question = (input or "").strip()
    if question:
        sql_encoding = sql_tokenizer(table=table, query=question, return_tensors="pt")
        sql_outputs = sql_model.generate(**sql_encoding)
        # batch_decode yields one string per generated sequence. Join them so
        # the transcript shows plain text rather than a list repr like "['x']".
        answer = ", ".join(
            sql_tokenizer.batch_decode(sql_outputs, skip_special_tokens=True)
        )
        conversation_history.append(("User", question))
        conversation_history.append(("Bot", answer))
    return "\n".join(f"{sender}: {msg}" for sender, msg in conversation_history)
# Shared CSS that hides the Gradio footer on both tabs.
_HIDE_FOOTER_CSS = ".footer {display:none !important}"

# Tab "Chatbot": free-form chat with DialoGPT. The "state" input/output pair
# carries chat()'s token-id history from one turn to the next.
chat_interface = gr.Interface(
    fn=chat,
    theme="default",
    css=_HIDE_FOOTER_CSS,
    inputs=["text", "state"],
    outputs=["chatbot", "state"],
    title="ST Chatbot",
    description="Type your message in the box above, and the chatbot will respond.",
)

# Tab "SQL Chat": questions about `table`, answered by TAPEX via sqlquery().
sql_interface = gr.Interface(
    fn=sqlquery,
    theme="default",
    css=_HIDE_FOOTER_CSS,
    # gr.Textbox has no `prompt` argument; `label` is the caption shown with the box.
    inputs=gr.Textbox(label="You:"),
    outputs=gr.Textbox(),
    title="ST SQL Chat",
    description="Type your message in the box above, and the chatbot will respond.",
)

# Both interfaces served as tabs of a single app.
combine_interface = gr.TabbedInterface(
    interface_list=[sql_interface, chat_interface],
    tab_names=["SQL Chat", "Chatbot"],
)
if __name__ == "__main__":
    # Start the Gradio server for the tabbed app ("SQL Chat" + "Chatbot").
    combine_interface.launch()
|