---
library_name: peft
base_model: microsoft/phi-2
---
# Model Card for phi-2-mongodb
phi-2-mongodb is a fine-tuned version of microsoft/phi-2 that generates MongoDB pipeline queries from natural-language instructions. It was fine-tuned on a custom-curated dataset of natural-language-to-MongoDB queries, which I'll be releasing next week.
## Model Details
Further details about the fine-tuned model can be found at https://github.com/Chirayu-Tripathi/nl2query. The model can also be used via the nl2query library.
### Model Description
- **Fine-tuned by:** [`Chirayu Tripathi`](http://www.linkedin.com/in/chirayu-tripathi)
- **Developed by:** `Microsoft` (base model)
- **Language(s) (NLP):** English
- **License:** MIT
- **Finetuned from model:** [`microsoft/phi-2`](https://huggingface.co/microsoft/phi-2)
### Prompt Template
```python
prompt_template = f"""<s>
Task Description:
Your task is to create a MongoDB query that accurately fulfills the provided Instruct while strictly adhering to the given MongoDB schema. Ensure that the query solely relies on keys and columns present in the schema. Minimize the usage of lookup operations wherever feasible to enhance query efficiency.
MongoDB Schema:
{db_schema}
### Instruct:
{text}
### Output:
"""
```
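Because the template is a plain Python f-string, it can be wrapped in a small helper so the schema and the instruction are filled in one place. The sketch below is hypothetical: the `build_prompt` name and the option of passing the schema as a `dict` (serialized with `json.dumps`) are assumptions, not part of the original card. The template text itself is copied verbatim.

```python
import json

# Template text copied from the prompt template above. {db_schema} and {text}
# are filled with str.format; braces inside the substituted schema are safe
# because format() only parses the template, not the inserted values.
PROMPT_TEMPLATE = """<s>
Task Description:
Your task is to create a MongoDB query that accurately fulfills the provided Instruct while strictly adhering to the given MongoDB schema. Ensure that the query solely relies on keys and columns present in the schema. Minimize the usage of lookup operations wherever feasible to enhance query efficiency.
MongoDB Schema:
{db_schema}
### Instruct:
{text}
### Output:
"""


def build_prompt(schema, text):
    """Hypothetical helper: fill the prompt template.

    `schema` may be a JSON string (used as-is) or a dict, which is
    serialized with json.dumps(indent=2).
    """
    db_schema = schema if isinstance(schema, str) else json.dumps(schema, indent=2)
    return PROMPT_TEMPLATE.format(db_schema=db_schema, text=text)


# Usage: a tiny, made-up schema for illustration only.
example_schema = {
    "collections": [
        {
            "name": "shipwrecks",
            "document": {"properties": {"chart": {"bsonType": "string"}}},
        }
    ],
    "version": 1,
}
print(build_prompt(example_schema, "Count shipwrecks per chart"))
```

Passing a string keeps full control over the exact schema formatting, which may matter if the model is sensitive to how the schema was laid out during fine-tuning.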
## How to Get Started with the Model
The following sample loads the base model in 4-bit precision, attaches the adapter, and generates a query for an example schema and instruction.
```python
import torch
from peft import PeftModel
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
)
db_schema = '''{
  "collections": [
    {
      "name": "shipwrecks",
      "indexes": [
        {
          "key": {
            "_id": 1
          }
        },
        {
          "key": {
            "feature_type": 1
          }
        },
        {
          "key": {
            "chart": 1
          }
        },
        {
          "key": {
            "latdec": 1,
            "londec": 1
          }
        }
      ],
      "uniqueIndexes": [],
      "document": {
        "properties": {
          "_id": {
            "bsonType": "string"
          },
          "recrd": {
            "bsonType": "string"
          },
          "vesslterms": {
            "bsonType": "string"
          },
          "feature_type": {
            "bsonType": "string"
          },
          "chart": {
            "bsonType": "string"
          },
          "latdec": {
            "bsonType": "double"
          },
          "londec": {
            "bsonType": "double"
          },
          "gp_quality": {
            "bsonType": "string"
          },
          "depth": {
            "bsonType": "string"
          },
          "sounding_type": {
            "bsonType": "string"
          },
          "history": {
            "bsonType": "string"
          },
          "quasou": {
            "bsonType": "string"
          },
          "watlev": {
            "bsonType": "string"
          },
          "coordinates": {
            "bsonType": "array",
            "items": {
              "bsonType": "double"
            }
          }
        }
      }
    }
  ],
  "version": 1
}'''
text = 'Find the count of shipwrecks for each unique combination of "latdec" and "londec"'
prompt = f"""<s>
Task Description:
Your task is to create a MongoDB query that accurately fulfills the provided Instruct while strictly adhering to the given MongoDB schema. Ensure that the query solely relies on keys and columns present in the schema. Minimize the usage of lookup operations wherever feasible to enhance query efficiency.
MongoDB Schema:
{db_schema}
### Instruct:
{text}
### Output:
"""
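# This sample assumes a CUDA GPU: device_map={"": 0} below loads the model onto GPU 0.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
base_model_id = "microsoft/phi-2"
tokenizer = AutoTokenizer.from_pretrained(base_model_id, use_fast=True)

# 4-bit NF4 quantization with double quantization; computation runs in float16.
compute_dtype = torch.float16
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=True,
)

# revision="refs/pr/23" pins the base model to a specific pull-request revision.
# The flash_attn, flash_rotary and fused_dense options are kept from the original
# sample and are intended for the remote modeling code at that revision.
model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    trust_remote_code=True,
    quantization_config=bnb_config,
    revision="refs/pr/23",
    device_map={"": 0},
    torch_dtype="auto",
    flash_attn=True,
    flash_rotary=True,
    fused_dense=True,
)

# Attach the phi-2-mongodb PEFT adapter on top of the quantized base model.
adapter = "Chirayu/phi-2-mongodb"
model = PeftModel.from_pretrained(model, adapter).to(device)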
model_inputs = tokenizer(prompt, return_tensors="pt").to(device)
output = model.generate(
    **model_inputs,
    max_length=1024,  # counts prompt tokens as well as generated tokens
    no_repeat_ngram_size=10,
    repetition_penalty=1.02,
    pad_token_id=tokenizer.eos_token_id,
    eos_token_id=tokenizer.eos_token_id,
)[0]

# Decode only the newly generated tokens, then cut at the "</s>" marker if present.
prompt_length = model_inputs["input_ids"].shape[1]
query = tokenizer.decode(output[prompt_length:], skip_special_tokens=False)
stop_idx = query.find("</s>")
if stop_idx == -1:
    stop_idx = len(query)
print(query[:stop_idx].strip())
```
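The sample above truncates the decoded text at `"</s>"` only. Because decoding uses `skip_special_tokens=False` and generation stops at `tokenizer.eos_token_id`, the decoded text may also end with the base tokenizer's end-of-text token (`<|endoftext|>` for phi-2). The helper below is a hypothetical sketch that cuts at whichever stop marker appears first; the `extract_query` name and the marker list are assumptions you may need to adjust.

```python
# Hypothetical post-processing helper. The default stop markers are assumptions:
# "</s>" as used by the sample, and "<|endoftext|>" (the phi-2 tokenizer's EOS text).
DEFAULT_STOP_MARKERS = ("</s>", "<|endoftext|>")


def extract_query(generated, stop_markers=DEFAULT_STOP_MARKERS):
    """Return the generated text up to the earliest stop marker, stripped."""
    cut = len(generated)
    for marker in stop_markers:
        idx = generated.find(marker)
        if idx != -1:
            cut = min(cut, idx)
    return generated[:cut].strip()


# Usage with made-up strings (not real model generations):
print(extract_query(" db.shipwrecks.find({})</s> trailing text"))  # -> db.shipwrecks.find({})
print(extract_query("db.shipwrecks.countDocuments({})<|endoftext|>"))  # -> db.shipwrecks.countDocuments({})
print(extract_query("  no marker here  "))  # -> no marker here
```

In the sample, `print(extract_query(query))` would replace the final truncation step.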
### Framework versions
- PEFT 0.10.0