doubleyyh commited on
Commit
e286ce9
·
verified ·
1 Parent(s): c49c9a9

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +134 -0
README.md ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language: en
3
+ tags:
4
+ - rag
5
+ - context-compression
6
+ - gemma
7
+ license: apache-2.0
8
+ datasets:
9
+ - hotpotqa
10
+ base_model:
11
+ - google/gemma-2b-it
12
+ ---
13
+
14
+ # EXIT: Context-Aware Extractive Compression for RAG
15
+
16
+ EXIT is a context-aware extractive compression model that improves the efficiency and effectiveness of Retrieval-Augmented Generation (RAG) by intelligently selecting relevant sentences while preserving contextual dependencies.
17
+
18
+ [[Paper]](https://arxiv.org/abs/2412.12559) [[GitHub]](https://github.com/ThisIsHwang/EXIT)
19
+
20
+ ## Model Description
21
+
22
+ EXIT is designed to:
23
+ - Compress retrieved documents while preserving critical information
24
+ - Consider full document context when evaluating sentence importance
25
+ - Enable parallelizable, context-aware extraction
26
+ - Adapt dynamically to query complexity
27
+ - Balance compression ratio and answer accuracy
28
+
29
+ ## Task and Intended Use
30
+
31
+ EXIT is trained to classify sentences as either relevant or irrelevant for answering a query based on their content and surrounding context. It is specifically designed for:
32
+
33
+ - RAG context compression
34
+ - Open-domain question answering
35
+ - Both single-hop and multi-hop queries
36
+
37
+ ## Quickstart
38
+
39
+ ```python
40
+ import torch
41
+ from transformers import AutoModelForCausalLM, AutoTokenizer
42
+ from peft import PeftModel
43
+ import spacy
44
+
45
+ # 1. Load models
46
+ base_model = AutoModelForCausalLM.from_pretrained(
47
+ "google/gemma-2b-it",
48
+ device_map="auto",
49
+ torch_dtype=torch.float16
50
+ )
51
+ exit_model = PeftModel.from_pretrained(
52
+ base_model,
53
+ "doubleyyh/exit-gemma-2b"
54
+ )
55
+ tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
56
+
57
+ # 2. Initialize sentence splitter
58
+ nlp = spacy.load("en_core_web_sm", disable=[
59
+ "tok2vec", "tagger", "parser", "attribute_ruler",
60
+ "lemmatizer", "ner"
61
+ ])
62
+ nlp.enable_pipe("senter")
63
+
64
+ # 3. Input
65
+ query = "How do solid-state drives (SSDs) improve computer performance?"
66
+ context = """
67
+ Solid-state drives use flash memory to store data without moving parts.
68
+ Unlike traditional hard drives, SSDs have no mechanical components.
69
+ The absence of physical movement allows for much faster data access speeds.
70
+ I bought my computer last week.
71
+ SSDs significantly reduce boot times and application loading speeds.
72
+ They consume less power and are more reliable than mechanical drives.
73
+ The price of SSDs has decreased significantly in recent years.
74
+ """
75
+
76
+ # 4. Process sentences
77
+ def get_relevance(query: str, context: str, sentence: str, tau: float = 0.5) -> bool:
78
+ prompt = f'''<start_of_turn>user
79
+ Query:
80
+ {query}
81
+ Full context:
82
+ {context}
83
+ Sentence:
84
+ {sentence}
85
+ Is this sentence useful in answering the query? Answer only "Yes" or "No".<end_of_turn>
86
+ <start_of_turn>model
87
+ '''
88
+ inputs = tokenizer(prompt, return_tensors="pt").to(exit_model.device)
89
+
90
+ with torch.no_grad():
91
+ outputs = exit_model(**inputs)
92
+ yes_id = tokenizer.encode("Yes", add_special_tokens=False)
93
+ no_id = tokenizer.encode("No", add_special_tokens=False)
94
+ logits = outputs.logits[0, -1, [yes_id, no_id]]
95
+ prob = torch.softmax(logits, dim=0)[0].item()
96
+
97
+ return prob >= tau
98
+
99
+ # 5. Compress document
100
+ sentences = [sent.text.strip() for sent in nlp(context).sents]
101
+ compressed = [sent for sent in sentences if get_relevance(query, context, sent)]
102
+ compressed_text = " ".join(compressed)
103
+
104
+ print(f"Compressed text ({len(compressed)}/{len(sentences)} sentences):", compressed_text)
105
+ ```
106
+
107
+ ## Training Data
108
+
109
+ The model was trained on the HotpotQA dataset using:
110
+ - Positive examples: Sentences marked as supporting facts
111
+ - Hard negatives: Sentences from same documents but not supporting facts
112
+ - Random negatives: Sentences from unrelated documents
113
+
114
+ ## Parameters
115
+
116
+ - Base model: Gemma-2b-it
117
+ - Training method: PEFT/LoRA
118
+ - Recommended tau threshold: 0.5
119
+
120
+ ## Limitations
121
+
122
+ - Currently optimized for English text only
123
+ - No support for cross-lingual compression
124
+
125
+ ## Citation
126
+
127
+ ```bibtex
128
+ @article{hwang2024exit,
129
+ title={EXIT: Context-Aware Extractive Compression for Enhancing Retrieval-Augmented Generation},
130
+ author={Hwang, Taeho and Cho, Sukmin and Jeong, Soyeong and Song, Hoyun and Han, SeungYoon and Park, Jong C.},
131
+ journal={arXiv preprint arXiv:2412.12559},
132
+ year={2024}
133
+ }
134
+ ```