Upload 11 files
- .gitattributes +2 -0
- my_model/LLAMA2/LLAMA2_config.py +15 -0
- my_model/LLAMA2/LLAMA2_model.py +173 -0
- my_model/extract_objects.py +45 -0
- my_model/fine_tuner/fine_tuner.py +347 -0
- my_model/fine_tuner/fine_tuning_config.py +114 -0
- my_model/fine_tuner/fine_tuning_data/fine_tuning_data_detic.csv +3 -0
- my_model/fine_tuner/fine_tuning_data/fine_tuning_data_yolov5.csv +3 -0
- my_model/fine_tuner/fine_tuning_data/read_me.txt +8 -0
- my_model/fine_tuner/fine_tuning_data_handler.py +182 -0
- my_model/object_detection.py +259 -0
- my_model/utilities.py +278 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+my_model/fine_tuner/fine_tuning_data/fine_tuning_data_detic.csv filter=lfs diff=lfs merge=lfs -text
+my_model/fine_tuner/fine_tuning_data/fine_tuning_data_yolov5.csv filter=lfs diff=lfs merge=lfs -text
my_model/LLAMA2/LLAMA2_config.py
ADDED
@@ -0,0 +1,15 @@
# Configuration parameters for the LLaMA-2 model
import torch
import os

MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"
TOKENIZER_NAME = "meta-llama/Llama-2-7b-chat-hf"
QUANTIZATION = '4bit'  # Options: '4bit', '8bit', or None
FROM_SAVED = False
MODEL_PATH = None
TRUST_REMOTE = False
USE_FAST = True
ADD_EOS_TOKEN = True
# ACCESS_TOKEN = "xx"  # My HF read-only token, to be added here if needed
ACCESS_TOKEN = os.getenv('HUGGINGFACE_TOKEN')  # read as a secret on the HF Space; referenced by LLAMA2_model.py
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
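As a rough guide to what the QUANTIZATION setting buys, a back-of-the-envelope estimate of weight memory only (illustrative figures, ignoring activations, the KV cache, and framework overhead):

    # Illustrative only: approximate weight memory for a 7B-parameter model per QUANTIZATION option.
    params = 7e9
    for mode, bytes_per_weight in {None: 2, '8bit': 1, '4bit': 0.5}.items():
        print(f"{mode}: ~{params * bytes_per_weight / 1e9:.1f} GB")  # None ~14 GB, 8bit ~7 GB, 4bit ~3.5 GB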
my_model/LLAMA2/LLAMA2_model.py
ADDED
@@ -0,0 +1,173 @@
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from typing import Optional
import bitsandbytes  # only for use on GPU
import accelerate  # only for use on GPU
from my_model.LLAMA2 import LLAMA2_config as config  # Importing LLAMA2 configuration file
import warnings

# Suppress only FutureWarning from transformers
warnings.filterwarnings("ignore", category=FutureWarning, module="transformers")


class Llama2ModelManager:
    """
    Manages loading and configuring the LLaMA-2 model and tokenizer.

    Attributes:
        device (str): Device to use for the model ('cuda' or 'cpu').
        model_name (str): Name or path of the pre-trained model.
        tokenizer_name (str): Name or path of the tokenizer.
        quantization (str): Specifies the quantization level ('4bit', '8bit', or None).
        from_saved (bool): Flag to load the model from a saved path.
        model_path (str or None): Path to the saved model if `from_saved` is True.
        trust_remote (bool): Whether to trust remote code when loading the tokenizer.
        use_fast (bool): Whether to use the fast version of the tokenizer.
        add_eos_token (bool): Whether to add an EOS token to the tokenizer.
        access_token (str): Access token for the Hugging Face Hub.
        model (AutoModelForCausalLM or None): Loaded model, initially None.
        tokenizer (AutoTokenizer or None): Loaded tokenizer, initially None.
    """

    def __init__(self) -> None:
        """
        Initializes the Llama2ModelManager class with configuration settings.
        """
        self.device: str = config.DEVICE
        self.model_name: str = config.MODEL_NAME
        self.tokenizer_name: str = config.TOKENIZER_NAME
        self.quantization: str = config.QUANTIZATION
        self.from_saved: bool = config.FROM_SAVED
        self.model_path: Optional[str] = config.MODEL_PATH
        self.trust_remote: bool = config.TRUST_REMOTE
        self.use_fast: bool = config.USE_FAST
        self.add_eos_token: bool = config.ADD_EOS_TOKEN
        self.access_token: str = config.ACCESS_TOKEN
        self.model: Optional[AutoModelForCausalLM] = None
        self.tokenizer: Optional[AutoTokenizer] = None

    def create_bnb_config(self) -> BitsAndBytesConfig:
        """
        Creates a BitsAndBytes configuration based on the quantization setting.

        Returns:
            BitsAndBytesConfig: Configuration for a BitsAndBytes-quantized model.
        """
        if self.quantization == '4bit':
            return BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16
            )
        elif self.quantization == '8bit':
            # 8-bit loading does not take the 4-bit-style options (double quant, nf4, compute dtype).
            return BitsAndBytesConfig(load_in_8bit=True)

    def load_model(self) -> AutoModelForCausalLM:
        """
        Loads the LLaMA-2 model based on the specified configuration.
        If the model is already loaded, returns the existing model.

        Returns:
            AutoModelForCausalLM: Loaded LLaMA-2 model.
        """
        if self.model is not None:
            print("Model is already loaded.")
            return self.model

        if self.from_saved:
            self.model = AutoModelForCausalLM.from_pretrained(self.model_path, device_map="auto")
        else:
            bnb_config = None if self.quantization is None else self.create_bnb_config()
            self.model = AutoModelForCausalLM.from_pretrained(self.model_name, device_map="auto",
                                                              quantization_config=bnb_config,
                                                              torch_dtype=torch.float16,
                                                              token=self.access_token)

        if self.model is not None:
            print(f"LLAMA2 Model loaded successfully in {self.quantization} quantization.")
        else:
            print("LLAMA2 Model failed to load.")
        return self.model

    def load_tokenizer(self) -> AutoTokenizer:
        """
        Loads the tokenizer for the LLaMA-2 model with the specified configuration.

        Returns:
            AutoTokenizer: Loaded tokenizer for the LLaMA-2 model.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, use_fast=self.use_fast,
                                                       token=self.access_token,
                                                       trust_remote_code=self.trust_remote,
                                                       add_eos_token=self.add_eos_token)

        if self.tokenizer is not None:
            print("LLAMA2 Tokenizer loaded successfully.")
        else:
            print("LLAMA2 Tokenizer failed to load.")

        return self.tokenizer

    def load_model_and_tokenizer(self, for_fine_tuning):
        """
        Loads the LLAMA2 model and tokenizer in one method and adds special tokens if the purpose is fine-tuning.

        :param for_fine_tuning: True for fine-tuning, False otherwise.
        :return: LLAMA2 model and tokenizer.
        """
        if for_fine_tuning:
            self.tokenizer = self.load_tokenizer()
            self.model = self.load_model()
            self.add_special_tokens()
        else:
            self.tokenizer = self.load_tokenizer()
            self.model = self.load_model()

        return self.model, self.tokenizer

    def add_special_tokens(self, tokens: Optional[list[str]] = None) -> None:
        """
        Adds special tokens to the tokenizer and updates the model's token embeddings if the model is loaded.
        Does nothing unless the tokenizer is loaded.

        Args:
            tokens (list of str, optional): Special tokens to add. Defaults to a predefined set.

        Returns:
            None
        """
        if self.tokenizer is None:
            print("Tokenizer is not loaded. Cannot add special tokens.")
            return

        if tokens is None:
            tokens = ['[CAP]', '[/CAP]', '[QES]', '[/QES]', '[OBJ]', '[/OBJ]']

        # Update the tokenizer with new tokens
        print(f"Original vocabulary size: {len(self.tokenizer)}")
        print(f"Adding the following tokens: {tokens}")
        self.tokenizer.add_tokens(tokens, special_tokens=True)
        self.tokenizer.add_special_tokens({'pad_token': '<pad>'})
        print(f"Adding Padding Token {self.tokenizer.pad_token}")
        self.tokenizer.padding_side = "right"
        print(f'Padding side: {self.tokenizer.padding_side}')

        # Resize the model token embeddings if the model is loaded
        if self.model is not None:
            self.model.resize_token_embeddings(len(self.tokenizer))
            self.model.config.pad_token_id = self.tokenizer.pad_token_id

        print(f'Updated Vocabulary Size: {len(self.tokenizer)}')
        print(f'Padding Token: {self.tokenizer.pad_token}')
        print(f'Special Tokens: {self.tokenizer.added_tokens_decoder}')


if __name__ == "__main__":
    LLAMA2_manager = Llama2ModelManager()
    LLAMA2_model = LLAMA2_manager.load_model()  # First time loading the model
    LLAMA2_tokenizer = LLAMA2_manager.load_tokenizer()
    LLAMA2_manager.add_special_tokens()  # operates on the manager's own model and tokenizer
my_model/extract_objects.py
ADDED
@@ -0,0 +1,45 @@
from object_detection import ObjectDetector
import os


def detect_objects_for_image(image_name, image_directory, detector):
    """Runs object detection for a single image file and returns the detections string."""
    image_path = os.path.join(image_directory, image_name)
    if os.path.exists(image_path):
        image = detector.process_image(image_path)
        detected_objects_str, _ = detector.detect_objects(image)
        return detected_objects_str
    else:
        return "Image not found"


def add_detected_objects_to_dataframe(df, image_directory, detector):
    """
    Adds a column to the DataFrame with detected objects for each image specified in the 'image_name' column.

    Parameters:
        df (pd.DataFrame): DataFrame containing a column 'image_name' with image filenames.
        image_directory (str): Path to the directory containing images.
        detector (ObjectDetector): An instance of the ObjectDetector class.

    Returns:
        pd.DataFrame: The original DataFrame with an additional column 'detected_objects'.
    """

    # Ensure 'image_name' column exists in the DataFrame
    if 'image_name' not in df.columns:
        raise ValueError("DataFrame must contain an 'image_name' column.")

    # Apply the detection function to each row in the DataFrame
    df['detected_objects'] = df['image_name'].apply(
        lambda image_name: detect_objects_for_image(image_name, image_directory, detector)
    )

    return df


# Example usage (assuming the function will be used in a context where 'detector' is defined and configured):
# df_images = pd.DataFrame({"image_name": ["image1.jpg", "image2.jpg", ...]})
# image_directory = "path/to/image_directory"
# updated_df = add_detected_objects_to_dataframe(df_images, image_directory, detector)
# updated_df.head()
my_model/fine_tuner/fine_tuner.py
ADDED
@@ -0,0 +1,347 @@
# Main fine-tuning script for meta-llama/Llama-2-7b-chat-hf

# This script is the central executable for fine-tuning large language models, specifically the LLAMA2 model.
# It covers the entire fine-tuning process, from data preparation to the final model training.
# The script leverages the 'FinetuningDataHandler' class for data loading, inspection, preparation, and splitting,
# ensuring that the dataset is correctly processed and prepared for effective training.

# The fine-tuning itself is managed by the Finetuner class, which trains the model using specific
# training arguments and datasets. Advanced configurations for Quantized Low-Rank Adaptation (QLoRA) and
# Parameter-Efficient Fine-Tuning (PEFT) are used to optimize the training process on limited hardware resources.

# The script is designed to be executed as a standalone process, providing an end-to-end solution for fine-tuning
# LLMs. It is part of a larger project aimed at adapting the language model to the OK-VQA dataset.

# Ensure all dependencies are installed and the required files are in place before running this script.
# The configurations for the fine-tuning process are defined in the 'fine_tuning_config.py' file.

# ---------- Please run this file for the full fine-tuning process to start ----------#
# ---------- Please ensure this is run on a GPU ----------#


import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, TRANSFORMERS_CACHE
from trl import SFTTrainer
from datasets import Dataset, load_dataset
from peft import LoraConfig, prepare_model_for_kbit_training, PeftModel
import fine_tuning_config as config
from typing import List
import bitsandbytes  # only on GPU
import gc
import os
import shutil
from my_model.LLAMA2.LLAMA2_model import Llama2ModelManager
from fine_tuning_data_handler import FinetuningDataHandler


class QLoraConfig:
    """
    Configures QLoRA (Quantized Low-Rank Adaptation) parameters for efficient model fine-tuning.
    LoRA allows adapting large language models with a minimal number of trainable parameters.

    Attributes:
        lora_config (LoraConfig): Configuration object for LoRA parameters.
    """

    def __init__(self) -> None:
        """
        Initializes QLoraConfig with specific LoRA parameters.
        """
        # Please refer to the config file 'fine_tuning_config.py' for the QLoRA argument descriptions.
        self.lora_config = LoraConfig(
            lora_alpha=config.LORA_ALPHA,
            lora_dropout=config.LORA_DROPOUT,
            r=config.LORA_R,
            bias="none",  # bias is already accounted for in the LLAMA2 pre-trained model layers.
            task_type="CAUSAL_LM",
            target_modules=['up_proj', 'down_proj', 'k_proj', 'q_proj', 'v_proj', 'o_proj']  # modules for fine-tuning.
        )


class Finetuner:
    """
    The Finetuner class manages the fine-tuning process of a pre-trained language model using specific
    training arguments and datasets. It is designed to adapt a pre-trained model to a specific dataset
    to enhance its performance on similar data.

    This class not only facilitates the fine-tuning of LLAMA2 but also includes advanced
    resource management capabilities. It provides methods for deleting model and trainer objects,
    clearing GPU memory, and cleaning up Hugging Face's Transformers cache. These functionalities
    make the Finetuner class especially useful in environments with limited computational resources
    or when managing multiple models or training sessions.

    Additionally, the class supports configurations for Quantized Low-Rank Adaptation (QLoRA)
    to fine-tune models with minimal trainable parameters, and Parameter-Efficient Fine-Tuning (PEFT)
    for training efficiency on limited hardware.

    Attributes:
        base_model (AutoModelForCausalLM): The pre-trained language model to be fine-tuned.
        tokenizer (AutoTokenizer): The tokenizer associated with the model.
        train_dataset (Dataset): The dataset used for training.
        eval_dataset (Dataset): The dataset used for evaluation.
        training_arguments (TrainingArguments): Configuration for training the model.

    Key Methods:
        - load_LLAMA2_for_finetuning: Loads the LLAMA2 model and tokenizer for fine-tuning.
        - train: Trains the model using the PEFT configuration.
        - delete_model: Deletes a specified model attribute.
        - delete_trainer: Deletes a specified trainer object.
        - clear_training_resources: Clears GPU memory.
        - clear_cache_and_collect_garbage: Clears the Transformers cache and performs garbage collection.
        - find_all_linear_names: Identifies linear layer names suitable for LoRA application.
        - print_trainable_parameters: Prints the number of trainable parameters in the model.
    """

    def __init__(self, train_dataset: Dataset, eval_dataset: Dataset) -> None:
        """
        Initializes the Finetuner class with the datasets; the model and tokenizer are loaded internally.

        Args:
            train_dataset (Dataset): The dataset for training the model.
            eval_dataset (Dataset): The dataset for evaluating the model.
        """

        self.base_model, self.tokenizer = self.load_LLAMA2_for_finetuning()
        self.merged_model = None
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        # Please refer to the config file 'fine_tuning_config.py' for the training argument descriptions.
        self.training_arguments = TrainingArguments(
            output_dir=config.OUTPUT_DIR,
            num_train_epochs=config.NUM_TRAIN_EPOCHS,
            per_device_train_batch_size=config.PER_DEVICE_TRAIN_BATCH_SIZE,
            per_device_eval_batch_size=config.PER_DEVICE_EVAL_BATCH_SIZE,
            gradient_accumulation_steps=config.GRADIENT_ACCUMULATION_STEPS,
            fp16=config.FP16,
            bf16=config.BF16,
            evaluation_strategy=config.Evaluation_STRATEGY,
            eval_steps=config.EVALUATION_STEPS,
            max_grad_norm=config.MAX_GRAD_NORM,
            learning_rate=config.LEARNING_RATE,
            weight_decay=config.WEIGHT_DECAY,
            optim=config.OPTIM,
            lr_scheduler_type=config.LR_SCHEDULER_TYPE,
            max_steps=config.MAX_STEPS,
            warmup_ratio=config.WARMUP_RATIO,
            group_by_length=config.GROUP_BY_LENGTH,
            save_steps=config.SAVE_STEPS,
            logging_steps=config.LOGGING_STEPS,
            report_to="tensorboard"
        )

    def load_LLAMA2_for_finetuning(self):
        """
        Loads the LLAMA2 model and tokenizer, specifically configured for fine-tuning.
        This method ensures the model is ready to be adapted to a specific task or dataset.

        Returns:
            Tuple[AutoModelForCausalLM, AutoTokenizer]: The loaded model and tokenizer.
        """

        llm_manager = Llama2ModelManager()
        base_model, tokenizer = llm_manager.load_model_and_tokenizer(for_fine_tuning=True)

        return base_model, tokenizer

    def find_all_linear_names(self) -> List[str]:
        """
        Identifies all linear layer names in the model that are suitable for applying LoRA.

        Returns:
            List[str]: A list of linear layer names.
        """
        cls = bitsandbytes.nn.Linear4bit
        lora_module_names = set()
        for name, module in self.base_model.named_modules():
            if isinstance(module, cls):
                names = name.split('.')
                lora_module_names.add(names[0] if len(names) == 1 else names[-1])

        # We don't want to train these two modules, to avoid computational overhead.
        lora_module_names -= {'lm_head', 'gate_proj'}
        return list(lora_module_names)

    def print_trainable_parameters(self, use_4bit: bool = False) -> None:
        """
        Calculates and prints the number of trainable parameters in the model.

        Args:
            use_4bit (bool): If True, calculates the parameter count considering 4-bit quantization.
        """
        trainable_params = sum(p.numel() for p in self.base_model.parameters() if p.requires_grad)
        if use_4bit:
            trainable_params //= 2  # integer division keeps the ',d' format below valid

        total_params = sum(p.numel() for p in self.base_model.parameters())
        print(f"All Parameters: {total_params:,d} || Trainable Parameters: {trainable_params:,d} "
              f"|| Trainable Parameters %: {100 * trainable_params / total_params:.2f}%")

    def train(self, peft_config: LoraConfig) -> None:
        """
        Trains the model using the specified PEFT (Parameter-Efficient Fine-Tuning) configuration.

        Args:
            peft_config (LoraConfig): Configuration for the PEFT training process.
        """
        self.base_model.config.use_cache = False
        # Set pretraining_tp to 1 so the standard (non tensor-parallel) linear computation is used during fine-tuning.
        self.base_model.config.pretraining_tp = 1
        # Prepare the quantized model for k-bit (here 4-bit) training with bitsandbytes.
        self.base_model = prepare_model_for_kbit_training(self.base_model)
        self.trainer = SFTTrainer(
            model=self.base_model,
            train_dataset=self.train_dataset,
            eval_dataset=self.eval_dataset,
            peft_config=peft_config,
            dataset_text_field='text',
            max_seq_length=config.MAX_TOKEN_COUNT,
            tokenizer=self.tokenizer,
            args=self.training_arguments,
            packing=config.PACKING
        )
        self.trainer.train()

    def save_model(self):
        """
        Saves the fine-tuned adapter to the specified directory.

        This method saves the weights and configuration of the fine-tuned adapter.
        The save directory and filename are determined by the configuration provided in
        the 'fine_tuning_config.py' file. It is useful for persisting the fine-tuned model
        for later use or evaluation.

        The saved model can be easily loaded using Hugging Face's model loading utilities.
        """

        self.fine_tuned_adapter_name = config.ADAPTER_SAVE_NAME
        self.trainer.model.save_pretrained(self.fine_tuned_adapter_name)

    def merge_weights(self):
        """
        Merges the weights of the fine-tuned adapter with the base model.

        This method integrates the fine-tuned adapter weights into the base model,
        resulting in a single consolidated model. The merged model can then be used
        for inference or further training.

        After merging, the adapter weights are no longer separate from the
        base model, enabling more efficient storage and deployment. The merged model
        is stored in the 'self.merged_model' attribute of the Finetuner class.
        """

        self.merged_model = PeftModel.from_pretrained(self.base_model, self.fine_tuned_adapter_name)
        self.merged_model = self.merged_model.merge_and_unload()

    def delete_model(self, model_name: str):
        """
        Deletes a specified model attribute.

        Args:
            model_name (str): The name of the model attribute to delete.
        """
        try:
            if hasattr(self, model_name) and getattr(self, model_name) is not None:
                delattr(self, model_name)
                print(f"Model '{model_name}' has been deleted.")
            else:
                print(f"Warning: Model '{model_name}' has already been cleared or does not exist.")
        except Exception as e:
            print(f"Error occurred while deleting model '{model_name}': {str(e)}")

    def delete_trainer(self, trainer_name: str):
        """
        Deletes a specified trainer object.

        Args:
            trainer_name (str): The name of the trainer object to delete.
        """
        try:
            if hasattr(self, trainer_name) and getattr(self, trainer_name) is not None:
                delattr(self, trainer_name)
                print(f"Trainer object '{trainer_name}' has been deleted.")
            else:
                print(f"Warning: Trainer object '{trainer_name}' has already been cleared or does not exist.")
        except Exception as e:
            print(f"Error occurred while deleting trainer object '{trainer_name}': {str(e)}")

    def clear_training_resources(self):
        """
        Clears GPU memory.
        """
        try:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                print("GPU memory has been cleared.")
        except Exception as e:
            print(f"Error occurred while clearing GPU memory: {str(e)}")

    def clear_cache_and_collect_garbage(self):
        """
        Clears Hugging Face's Transformers cache and runs garbage collection.
        """
        try:
            if os.path.exists(TRANSFORMERS_CACHE):
                shutil.rmtree(TRANSFORMERS_CACHE, ignore_errors=True)
                print("Transformers cache has been cleared.")

            gc.collect()
            print("Garbage collection has been executed.")
        except Exception as e:
            print(f"Error occurred while clearing cache and collecting garbage: {str(e)}")


def fine_tune(save_fine_tuned_adapter=False, merge=False, delete_trainer_after_fine_tune=False):
    """
    Conducts the fine-tuning process of a pre-trained language model using the specified configurations.
    This function encompasses the complete workflow of fine-tuning, including data handling, training,
    and optional steps like saving the fine-tuned adapter and merging weights.

    Args:
        save_fine_tuned_adapter (bool): If True, saves the fine-tuned adapter after training.
        merge (bool): If True, merges the weights of the fine-tuned adapter into the base model.
        delete_trainer_after_fine_tune (bool): If True, deletes the trainer object after fine-tuning to free up resources.

    Returns:
        The fine-tuned model after the fine-tuning process. This is either the merged model
        or the trained model, depending on the provided arguments.

    The function starts by preparing the training and evaluation datasets using the `FinetuningDataHandler`.
    It then sets up the QLoRA configuration for the fine-tuning process. The actual training is carried out by
    the `Finetuner` class. After training, based on the arguments, the function can save the fine-tuned adapter,
    merge the adapter weights with the base model, and clean up resources by deleting the trainer object.
    """

    data_handler = FinetuningDataHandler()
    fine_tuning_data_train, fine_tuning_data_eval = data_handler.inspect_prepare_split_data()
    qlora = QLoraConfig()
    peft_config = qlora.lora_config
    tuner = Finetuner(fine_tuning_data_train, fine_tuning_data_eval)
    tuner.train(peft_config=peft_config)
    if save_fine_tuned_adapter:
        tuner.save_model()

    if merge:
        tuner.merge_weights()

    if delete_trainer_after_fine_tune:
        tuner.delete_trainer("trainer")

    tuner.delete_model("base_model")  # We always delete this as it is not required after the merge.

    if tuner.merged_model is not None:
        return tuner.merged_model
    else:
        return tuner.trainer.model


if __name__ == "__main__":
    fine_tune()
my_model/fine_tuner/fine_tuning_config.py
ADDED
@@ -0,0 +1,114 @@
# Configurable parameters for fine-tuning

import os


# *** Dataset ***
# Base directory where the script is running
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Path to the folder containing the data files, relative to the configuration file
DATA_FOLDER = 'fine_tuning_data'
# Full path to the data folder
DATA_FOLDER_PATH = os.path.join(BASE_DIR, DATA_FOLDER)
# Path to the dataset file (CSV format)
DATASET_FILE = os.path.join(DATA_FOLDER_PATH, 'fine_tuning_data_yolov5.csv')  # or 'fine_tuning_data_detic.csv'


# *** Fine-tuned Adapter ***
TRAINED_ADAPTER_NAME = 'fine_tuned_adapter'  # name of the fine-tuned adapter.
FINE_TUNED_ADAPTER_FOLDER = 'fine_tuned_model'
FINE_TUNED_ADAPTER_PATH = os.path.join(BASE_DIR, FINE_TUNED_ADAPTER_FOLDER)
ADAPTER_SAVE_NAME = os.path.join(FINE_TUNED_ADAPTER_PATH, TRAINED_ADAPTER_NAME)


# Proportion of the dataset to include in the test split (e.g., 0.1 for 10%)
TEST_SIZE = 0.1

# Seed for random operations to ensure reproducibility
SEED = 123

# *** QLoRA Configuration Parameters ***
# LoRA attention dimension: the rank r of the low-rank update matrices
LORA_R = 64

# Alpha parameter for LoRA scaling: controls the scaling of LoRA weights
LORA_ALPHA = 32

# Dropout probability for LoRA layers: probability of dropping a unit in LoRA layers
LORA_DROPOUT = 0.05


# *** TrainingArguments Configuration Parameters for the Transformers library ***
# Output directory to save model predictions and checkpoints
OUTPUT_DIR = "./TUNED_MODEL_LLAMA"

# Number of epochs to train the model
NUM_TRAIN_EPOCHS = 1

# Enable mixed-precision training using fp16 (set to True for faster training)
FP16 = True

# Enable mixed-precision training using bf16 (set to True if using an A100 GPU)
BF16 = False

# Batch size per GPU/device for training
PER_DEVICE_TRAIN_BATCH_SIZE = 16

# Batch size per GPU/device for evaluation
PER_DEVICE_EVAL_BATCH_SIZE = 8

# Number of update steps to accumulate gradients before performing a backward/update pass
GRADIENT_ACCUMULATION_STEPS = 1

# Enable gradient checkpointing to reduce memory usage at the cost of a slight slowdown
GRADIENT_CHECKPOINTING = True

# Maximum gradient norm for gradient clipping to prevent exploding gradients
MAX_GRAD_NORM = 0.3

# Initial learning rate for the AdamW optimizer
LEARNING_RATE = 2e-4

# Weight decay coefficient for regularization (applied to all layers except bias/LayerNorm weights)
WEIGHT_DECAY = 0.01

# Optimizer type, here using 'paged_adamw_8bit' for efficient training
OPTIM = "paged_adamw_8bit"

# Learning rate scheduler type (e.g., 'linear', 'cosine', etc.)
LR_SCHEDULER_TYPE = "linear"

# Maximum number of training steps; overrides 'num_train_epochs' if set to a positive number.
# Setting MAX_STEPS = -1 in the training arguments for SFTTrainer means that the number of steps will be determined by
# the number of epochs, the size of the dataset, the batch size, and the number of GPUs. This is the default behavior
# when MAX_STEPS is not specified or is set to a negative value.
MAX_STEPS = -1

# Ratio of the total number of training steps used for linear warmup
WARMUP_RATIO = 0.03

# Whether to group sequences of similar length into batches to save memory and increase speed
GROUP_BY_LENGTH = False

# Save a model checkpoint every X update steps
SAVE_STEPS = 50

# Log training information every X update steps
LOGGING_STEPS = 25

# Whether SFTTrainer should pack multiple short samples into a single sequence
PACKING = False

# Evaluation strategy during training ("steps", "epoch", "no")
Evaluation_STRATEGY = "steps"

# Number of update steps between two evaluations if `evaluation_strategy="steps"`.
# Will default to the same value as `logging_steps` if not set.
EVALUATION_STEPS = 5

# Maximum number of tokens per sample in the dataset
MAX_TOKEN_COUNT = 1024


if __name__ == "__main__":
    pass
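Since MAX_STEPS is -1, the total number of optimizer steps follows from the other settings above. A minimal sketch of that arithmetic, assuming a single GPU and a hypothetical dataset size of 9,000 rows (not the actual size of the fine-tuning CSVs):

    import math
    num_samples = 9000   # hypothetical example value
    num_gpus = 1         # assumed single-GPU run
    # PER_DEVICE_TRAIN_BATCH_SIZE (16) * GRADIENT_ACCUMULATION_STEPS (1) * num_gpus
    steps_per_epoch = math.ceil(num_samples / (16 * 1 * num_gpus))
    total_steps = 1 * steps_per_epoch  # NUM_TRAIN_EPOCHS * steps_per_epoch = 563 optimizer steps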
my_model/fine_tuner/fine_tuning_data/fine_tuning_data_detic.csv
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:77089f24dd5414b0d1dcb5b8f3b34aac3daea86e68c1c70e2da6490482ac9d4b
size 54670629
my_model/fine_tuner/fine_tuning_data/fine_tuning_data_yolov5.csv
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a44d22827c212a9d7a30bb3fd94cb7d7ad82a968a55eaa09e0ff5a61f85fde05
size 14547559
my_model/fine_tuner/fine_tuning_data/read_me.txt
ADDED
@@ -0,0 +1,8 @@
The data files 'fine_tuning_data_detic.csv' and 'fine_tuning_data_yolov5.csv' are the result of the preparation and
filtering performed in the steps below:

- Generate the captions for all the images.
- Delete all samples with corrupted or rubbish data. (Please refer to the report for details.)
- Run the object detection models ('yolov5' and 'detic') and generate the corresponding objects for the images of the remaining samples.
- Convert the question, answer, caption, and objects, together with the system prompt, into the desired template for all
the samples. (Please refer to the report for the detailed template design.)
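The last step above is where each sample becomes the single 'text' field that the data handler and SFTTrainer consume. The exact template is documented in the report; the sketch below is only an illustration that reuses the special tokens added in LLAMA2_model.py ([CAP], [QES], [OBJ]) — the field order and the idea of a separate system prompt argument are assumptions, not the project's actual template.

    def build_training_text(system_prompt, caption, question, objects, answer):
        # Hypothetical layout for one fine-tuning sample; see the report for the real template design.
        return (f"{system_prompt}\n"
                f"[CAP]{caption}[/CAP]\n"
                f"[QES]{question}[/QES]\n"
                f"[OBJ]{objects}[/OBJ]\n"
                f"{answer}")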
my_model/fine_tuner/fine_tuning_data_handler.py
ADDED
@@ -0,0 +1,182 @@
from my_model.utilities import is_pycharm
import seaborn as sns
from transformers import AutoTokenizer
from datasets import Dataset, load_dataset
import fine_tuning_config as config
from my_model.LLAMA2.LLAMA2_model import Llama2ModelManager
from typing import Tuple


class FinetuningDataHandler:
    """
    A class dedicated to handling data for fine-tuning language models. It manages loading,
    inspecting, preparing, and splitting the dataset, and is specifically designed to filter out
    data samples that exceed a specified token count limit. This is crucial for models with
    token count constraints, helps keep GPU RAM usage under control, and ensures efficient and
    effective model fine-tuning.

    Attributes:
        tokenizer (AutoTokenizer): Tokenizer used for tokenizing the dataset.
        dataset_file (str): File path to the dataset.
        max_token_count (int): Maximum allowable token count per data sample.

    Methods:
        load_llm_tokenizer(): Loads the LLM tokenizer and adds special tokens, if not already loaded.
        load_dataset(): Loads the dataset from a specified file path.
        plot_tokens_count_distribution(token_counts, title): Plots the distribution of token counts in the dataset.
        filter_dataset_by_indices(dataset, valid_indices): Filters the dataset based on valid indices, removing samples exceeding token limits.
        get_token_counts(dataset): Calculates token counts for each sample in the dataset.
        prepare_dataset(): Tokenizes and filters the dataset, preparing it for training. Also visualizes the token count distribution before and after filtering.
        split_dataset_for_train_eval(dataset): Divides the dataset into training and evaluation sets.
        inspect_prepare_split_data(): Coordinates the data preparation and splitting process for fine-tuning.
    """

    def __init__(self, tokenizer: AutoTokenizer = None, dataset_file: str = config.DATASET_FILE) -> None:
        """
        Initializes the FinetuningDataHandler class.

        Args:
            tokenizer (AutoTokenizer): Tokenizer to use for tokenizing the dataset.
            dataset_file (str): Path to the dataset file.
        """
        self.tokenizer = tokenizer  # The tokenizer used for processing the dataset.
        self.dataset_file = dataset_file  # Path to the fine-tuning dataset file.
        self.max_token_count = config.MAX_TOKEN_COUNT  # Max token count for filtering.

    def load_llm_tokenizer(self):
        """
        Loads the LLM tokenizer and adds special tokens, if not already loaded.
        If the tokenizer is already loaded, this method does nothing.
        """

        if self.tokenizer is None:
            llm_manager = Llama2ModelManager()  # Initialize the Llama2 model manager.
            # We only need the tokenizer for the data inspection, not the model itself.
            self.tokenizer = llm_manager.load_tokenizer()
            llm_manager.add_special_tokens()  # Add special tokens specific to the LLAMA2 vocab for efficient tokenization.

    def load_dataset(self) -> Dataset:
        """
        Loads the dataset from the specified file path. The dataset is expected to be in CSV format.

        Returns:
            Dataset: The loaded dataset, ready for processing.
        """
        return load_dataset('csv', data_files=self.dataset_file)

    def plot_tokens_count_distribution(self, token_counts: list, title: str = "Token Count Distribution") -> None:
        """
        Plots the distribution of token counts in the dataset for visualization purposes.

        Args:
            token_counts (list): List of token counts, each count representing the number of tokens in a dataset sample.
            title (str): Title for the plot, highlighting the nature of the distribution.
        """

        if is_pycharm():  # Ensuring compatibility with PyCharm's environment for interactive plots.
            import matplotlib
            matplotlib.use('TkAgg')  # Set the backend to 'TkAgg'
        import matplotlib.pyplot as plt
        sns.set_style("whitegrid")
        plt.figure(figsize=(15, 6))
        plt.hist(token_counts, bins=50, color='#3498db', edgecolor='black')
        plt.title(title, fontsize=16)
        plt.xlabel("Number of Tokens", fontsize=14)
        plt.ylabel("Number of Samples", fontsize=14)
        plt.xticks(fontsize=12)
        plt.yticks(fontsize=12)
        plt.tight_layout()
        plt.show()

    def filter_dataset_by_indices(self, dataset: Dataset, valid_indices: list) -> Dataset:
        """
        Filters the dataset based on a list of valid indices. This method is used to exclude
        data samples that have a token count exceeding the specified maximum token count.

        Args:
            dataset (Dataset): The dataset to be filtered.
            valid_indices (list): Indices of samples with token counts within the limit.

        Returns:
            Dataset: Filtered dataset containing only samples with valid indices.
        """
        return dataset['train'].select(valid_indices)  # Select only samples with valid indices based on token count.

    def get_token_counts(self, dataset):
        """
        Calculates and returns the token counts for each sample in the dataset.
        This function assumes the dataset has a 'train' split and a 'text' field.

        Args:
            dataset (Dataset): The dataset for which to count tokens.

        Returns:
            List[int]: List of token counts per sample in the dataset.
        """

        if 'train' in dataset:
            return [len(self.tokenizer.tokenize(s)) for s in dataset["train"]["text"]]
        else:
            # After filtering out samples with an unacceptable token count, the dataset is already flattened
            # (effectively dataset['train']), so it can be indexed by 'text' directly.
            return [len(self.tokenizer.tokenize(s)) for s in dataset["text"]]

    def prepare_dataset(self) -> Tuple[Dataset, Dataset]:
        """
        Prepares the dataset for fine-tuning by tokenizing the data and filtering out samples
        that exceed the maximum used context window (configurable through max_token_count).
        It also visualizes the token count distribution before and after filtering.

        Returns:
            Tuple[Dataset, Dataset]: The train and evaluation datasets, post-filtering.
        """
        dataset = self.load_dataset()
        self.load_llm_tokenizer()

        # Count tokens in each dataset sample before filtering.
        token_counts_before_filtering = self.get_token_counts(dataset)
        # Plot the token count distribution before filtering for visualization.
        self.plot_tokens_count_distribution(token_counts_before_filtering, "Token Count Distribution Before Filtration")
        # Identify valid indices based on the max token count.
        valid_indices = [i for i, count in enumerate(token_counts_before_filtering) if count <= self.max_token_count]
        # Filter the dataset to exclude samples with excessive token counts.
        filtered_dataset = self.filter_dataset_by_indices(dataset, valid_indices)

        token_counts_after_filtering = self.get_token_counts(filtered_dataset)
        self.plot_tokens_count_distribution(token_counts_after_filtering, "Token Count Distribution After Filtration")

        return self.split_dataset_for_train_eval(filtered_dataset)  # Split the dataset into training and evaluation.

    def split_dataset_for_train_eval(self, dataset) -> Tuple[Dataset, Dataset]:
        """
        Splits the dataset into training and evaluation datasets.

        Args:
            dataset (Dataset): The dataset to split.

        Returns:
            Tuple[Dataset, Dataset]: The split training and evaluation datasets.
        """
        split_data = dataset.train_test_split(test_size=config.TEST_SIZE, shuffle=True, seed=config.SEED)
        train_data, eval_data = split_data['train'], split_data['test']
        return train_data, eval_data

    def inspect_prepare_split_data(self) -> Tuple[Dataset, Dataset]:
        """
        Orchestrates the process of inspecting, preparing, and splitting the dataset for fine-tuning.

        Returns:
            Tuple[Dataset, Dataset]: The prepared training and evaluation datasets.
        """
        return self.prepare_dataset()


# Example usage
if __name__ == "__main__":

    # Please uncomment the lines below to test the data preparation.
    # data_handler = FinetuningDataHandler()
    # fine_tuning_data_train, fine_tuning_data_eval = data_handler.inspect_prepare_split_data()
    # print(fine_tuning_data_train, fine_tuning_data_eval)
    pass
my_model/object_detection.py
ADDED
|
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from transformers import AutoImageProcessor, AutoModelForObjectDetection
|
| 3 |
+
import torch
|
| 4 |
+
from PIL import Image, ImageDraw, ImageFont
|
| 5 |
+
import numpy as np
|
| 6 |
+
import cv2
|
| 7 |
+
import os
|
| 8 |
+
from utilities import get_path, show_image
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class ObjectDetector:
|
| 12 |
+
"""
|
| 13 |
+
A class for detecting objects in images using models like Detic and YOLOv5.
|
| 14 |
+
|
| 15 |
+
This class supports loading and using different object detection models to identify objects
|
| 16 |
+
in images and draw bounding boxes around them.
|
| 17 |
+
|
| 18 |
+
Attributes:
|
| 19 |
+
model (torch.nn.Module): The loaded object detection model.
|
| 20 |
+
processor (transformers.AutoImageProcessor): Processor for the Detic model.
|
| 21 |
+
model_name (str): Name of the model used for detection.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
def __init__(self):
|
| 25 |
+
"""
|
| 26 |
+
Initializes the ObjectDetector class with default values.
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
self.model = None
|
| 30 |
+
self.processor = None
|
| 31 |
+
self.model_name = None
|
| 32 |
+
|
| 33 |
+
def load_model(self, model_name='detic', pretrained=True, model_version='yolov5s'):
|
| 34 |
+
"""
|
| 35 |
+
Load the specified object detection model.
|
| 36 |
+
|
| 37 |
+
Args:
|
| 38 |
+
model_name (str): Name of the model to load. Options are 'detic' and 'yolov5'.
|
| 39 |
+
pretrained (bool): Boolean indicating if a pretrained model should be used.
|
| 40 |
+
model_version (str): Version of the YOLOv5 model, applicable only when using YOLOv5.
|
| 41 |
+
|
| 42 |
+
Raises:
|
| 43 |
+
ValueError: If an unsupported model name is provided.
|
| 44 |
+
"""
|
| 45 |
+
|
| 46 |
+
self.model_name = model_name
|
| 47 |
+
if model_name == 'detic':
|
| 48 |
+
self._load_detic_model(pretrained)
|
| 49 |
+
elif model_name == 'yolov5':
|
| 50 |
+
self._load_yolov5_model(pretrained, model_version)
|
| 51 |
+
else:
|
| 52 |
+
raise ValueError(f"Unsupported model name: {model_name}")
|
| 53 |
+
|
| 54 |
+
def _load_detic_model(self, pretrained):
|
| 55 |
+
"""
|
| 56 |
+
Load the Detic model.
|
| 57 |
+
|
| 58 |
+
Args:
|
| 59 |
+
pretrained (bool): If True, load a pretrained model.
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
try:
|
| 63 |
+
model_path = get_path('deformable-detr-detic', 'models')
|
| 64 |
+
self.processor = AutoImageProcessor.from_pretrained(model_path)
|
| 65 |
+
self.model = AutoModelForObjectDetection.from_pretrained(model_path)
|
| 66 |
+
except Exception as e:
|
| 67 |
+
print(f"Error loading Detic model: {e}")
|
| 68 |
+
raise
|
| 69 |
+
|
| 70 |
+
def _load_yolov5_model(self, pretrained, model_version):
|
| 71 |
+
"""
|
| 72 |
+
Load the YOLOv5 model.
|
| 73 |
+
|
| 74 |
+
Args:
|
| 75 |
+
pretrained (bool): If True, load a pretrained model.
|
| 76 |
+
model_version (str): Version of the YOLOv5 model.
|
| 77 |
+
"""
|
| 78 |
+
|
| 79 |
+
try:
|
| 80 |
+
model_path = get_path('yolov5', 'models')
|
| 81 |
+
if model_path and os.path.exists(model_path):
|
| 82 |
+
self.model = torch.hub.load(model_path, model_version, pretrained=pretrained, source='local')
|
| 83 |
+
else:
|
| 84 |
+
self.model = torch.hub.load('ultralytics/yolov5', model_version, pretrained=pretrained)
|
| 85 |
+
except Exception as e:
|
| 86 |
+
print(f"Error loading YOLOv5 model: {e}")
|
| 87 |
+
raise
|
| 88 |
+
|
| 89 |
+
def process_image(self, image_path):
|
| 90 |
+
"""
|
| 91 |
+
Process the image from the given path.
|
| 92 |
+
|
| 93 |
+
Args:
|
| 94 |
+
image_path (str): Path to the image file.
|
| 95 |
+
|
| 96 |
+
Returns:
|
| 97 |
+
Image.Image: Processed image in RGB format.
|
| 98 |
+
|
| 99 |
+
Raises:
|
| 100 |
+
Exception: If an error occurs during image processing.
|
| 101 |
+
"""
|
| 102 |
+
|
| 103 |
+
try:
|
| 104 |
+
with Image.open(image_path) as image:
|
| 105 |
+
return image.convert("RGB")
|
| 106 |
+
except Exception as e:
|
| 107 |
+
print(f"Error processing image: {e}")
|
| 108 |
+
raise
|
| 109 |
+
|
| 110 |
+
def detect_objects(self, image, threshold=0.4):
|
| 111 |
+
"""
|
| 112 |
+
Detect objects in the given image using the loaded model.
|
| 113 |
+
|
| 114 |
+
Args:
|
| 115 |
+
image (Image.Image): Image in which to detect objects.
|
| 116 |
+
threshold (float): Model detection confidence.
|

        Returns:
            tuple: A tuple containing a string representation and a list of detected objects.

        Raises:
            ValueError: If the model is not loaded or the model name is unsupported.
        """

        if self.model_name == 'detic':
            return self._detect_with_detic(image, threshold)
        elif self.model_name == 'yolov5':
            return self._detect_with_yolov5(image, threshold)
        else:
            raise ValueError("Model not loaded or unsupported model name")

    def _detect_with_detic(self, image, threshold):
        """
        Detect objects using the Detic model.

        Args:
            image (Image.Image): The image in which to detect objects.
            threshold (float): The confidence threshold for detections.

        Returns:
            tuple: A tuple containing a string representation and a list of detected objects.
            Each object in the list is represented as a tuple (label_name, box_rounded, certainty).
        """

        inputs = self.processor(images=image, return_tensors="pt")
        outputs = self.model(**inputs)
        target_sizes = torch.tensor([image.size[::-1]])
        results = self.processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=threshold)[0]

        detected_objects_str = ""
        detected_objects_list = []
        for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
            if score >= threshold:
                label_name = self.model.config.id2label[label.item()]
                box_rounded = [round(coord, 2) for coord in box.tolist()]
                certainty = round(score.item() * 100, 2)
                detected_objects_str += f"{{object: {label_name}, bounding box: {box_rounded}, certainty: {certainty}%}}\n"
                detected_objects_list.append((label_name, box_rounded, certainty))
        return detected_objects_str, detected_objects_list

    def _detect_with_yolov5(self, image, threshold):
        """
        Detect objects using the YOLOv5 model.

        Args:
            image (Image.Image): The image in which to detect objects.
            threshold (float): The confidence threshold for detections.

        Returns:
            tuple: A tuple containing a string representation and a list of detected objects.
            Each object in the list is represented as a tuple (label_name, box_rounded, certainty).
        """

        cv2_img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
        results = self.model(cv2_img)

        detected_objects_str = ""
        detected_objects_list = []
        for *bbox, conf, cls in results.xyxy[0]:
            if conf >= threshold:
                label_name = results.names[int(cls)]
                box_rounded = [round(coord.item(), 2) for coord in bbox]
                certainty = round(conf.item() * 100, 2)
                detected_objects_str += f"{{object: {label_name}, bounding box: {box_rounded}, certainty: {certainty}%}}\n"
                detected_objects_list.append((label_name, box_rounded, certainty))
        return detected_objects_str, detected_objects_list

    def draw_boxes(self, image, detected_objects, show_confidence=True):
        """
        Draw bounding boxes around detected objects in the image.

        Args:
            image (Image.Image): Image on which to draw.
            detected_objects (list): List of detected objects.
            show_confidence (bool): Whether to show confidence scores.

        Returns:
            Image.Image: Image with drawn boxes.
        """

        draw = ImageDraw.Draw(image)
        try:
            font = ImageFont.truetype("arial.ttf", 15)
        except IOError:
            font = ImageFont.load_default()

        colors = ["red", "green", "blue", "yellow", "purple", "orange"]
        label_color_map = {}

        for label_name, box, score in detected_objects:
            if label_name not in label_color_map:
                label_color_map[label_name] = colors[len(label_color_map) % len(colors)]

            color = label_color_map[label_name]
            draw.rectangle(box, outline=color, width=3)

            label_text = f"{label_name}"
            if show_confidence:
                label_text += f" ({round(score, 2)}%)"
            draw.text((box[0], box[1]), label_text, fill=color, font=font)

        return image


def detect_and_draw_objects(image_path, model_type='yolov5', threshold=0.2, show_confidence=True):
    """
    Detects objects in an image, draws bounding boxes around them, and returns the processed image and a string description.

    Args:
        image_path (str): Path to the image file.
        model_type (str): Type of model to use for detection ('yolov5' or 'detic').
        threshold (float): Detection threshold.
        show_confidence (bool): Whether to show confidence scores on the output image.

    Returns:
        tuple: A tuple containing the processed Image.Image and a string of detected objects.
    """

    detector = ObjectDetector()
    detector.load_model(model_type)
    image = detector.process_image(image_path)
    detected_objects_string, detected_objects_list = detector.detect_objects(image, threshold=threshold)
    image_with_boxes = detector.draw_boxes(image, detected_objects_list, show_confidence=show_confidence)
    return image_with_boxes, detected_objects_string


# Example usage
if __name__ == "__main__":
    pass

    # 'Sample_Images' is the folder containing sample images for demo.
    image_path = get_path('horse.jpg', 'Sample_Images')
    processed_image, objects_string = detect_and_draw_objects(image_path,
                                                              model_type='detic',
                                                              threshold=0.2,
                                                              show_confidence=False)
    show_image(processed_image)
    print("Detected Objects:", objects_string)
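For orientation, a minimal sketch of how the ObjectDetector class above might be driven directly; the image path and threshold below are placeholders for illustration, not values taken from the project:

# Hypothetical usage sketch -- assumes a local sample image and a successful YOLOv5 hub download.
detector = ObjectDetector()
detector.load_model('yolov5', model_version='yolov5s')
img = detector.process_image('some_image.jpg')  # placeholder path
objects_str, objects_list = detector.detect_objects(img, threshold=0.3)
annotated = detector.draw_boxes(img, objects_list, show_confidence=True)
print(objects_str)  # one "{object: ..., bounding box: ..., certainty: ...%}" line per detection

The same flow with 'detic' goes through the local deformable-detr-detic checkpoint resolved by get_path(), as in the __main__ example above.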
my_model/utilities.py
ADDED
@@ -0,0 +1,278 @@
import pandas as pd
from collections import Counter
import json
import os
from PIL import Image
import numpy as np
import torch
import matplotlib.pyplot as plt
from IPython import get_ipython
import sys


class VQADataProcessor:
    """
    A class to process the OK-VQA dataset.

    Attributes:
        questions_file_path (str): The file path for the questions JSON file.
        annotations_file_path (str): The file path for the annotations JSON file.
        questions (list): List of questions extracted from the JSON file.
        annotations (list): List of annotations extracted from the JSON file.
        df_questions (DataFrame): DataFrame created from the questions list.
        df_answers (DataFrame): DataFrame created from the annotations list.
        merged_df (DataFrame): DataFrame resulting from merging questions and answers.
    """

    def __init__(self, questions_file_path, annotations_file_path):
        """
        Initializes the VQADataProcessor with file paths for questions and annotations.

        Parameters:
            questions_file_path (str): The file path for the questions JSON file.
            annotations_file_path (str): The file path for the annotations JSON file.
        """
        self.questions_file_path = questions_file_path
        self.annotations_file_path = annotations_file_path
        self.questions, self.annotations = self.read_json_files()
        self.df_questions = pd.DataFrame(self.questions)
        self.df_answers = pd.DataFrame(self.annotations)
        self.merged_df = None

    def read_json_files(self):
        """
        Reads the JSON files for questions and annotations.

        Returns:
            tuple: A tuple containing two lists: questions and annotations.
        """
        with open(self.questions_file_path, 'r') as file:
            data = json.load(file)
            questions = data['questions']

        with open(self.annotations_file_path, 'r') as file:
            data = json.load(file)
            annotations = data['annotations']

        return questions, annotations

    @staticmethod
    def find_most_frequent(my_list):
        """
        Finds the most frequent item in a list.

        Parameters:
            my_list (list): A list of items.

        Returns:
            The most frequent item in the list. Returns None if the list is empty.
        """
        if not my_list:
            return None
        counter = Counter(my_list)
        most_common = counter.most_common(1)
        return most_common[0][0]

    def merge_dataframes(self):
        """
        Merges the questions and answers DataFrames on 'question_id' and 'image_id'.
        """
        self.merged_df = pd.merge(self.df_questions, self.df_answers, on=['question_id', 'image_id'])

    def join_words_with_hyphen(self, sentence):
        """Joins the words of a multi-word answer with hyphens so it forms a single token."""

        return '-'.join(sentence.split())

    def process_answers(self):
        """
        Processes the answers by extracting raw and processed answers and finding the most frequent ones.
        """
        if self.merged_df is not None:
            self.merged_df['raw_answers'] = self.merged_df['answers'].apply(lambda x: [ans['raw_answer'] for ans in x])
            self.merged_df['processed_answers'] = self.merged_df['answers'].apply(
                lambda x: [ans['answer'] for ans in x])
            self.merged_df['most_frequent_raw_answer'] = self.merged_df['raw_answers'].apply(self.find_most_frequent)
            self.merged_df['most_frequent_processed_answer'] = self.merged_df['processed_answers'].apply(
                self.find_most_frequent)
            self.merged_df.drop(columns=['answers'], inplace=True)
        else:
            print("DataFrames have not been merged yet.")
            return

        # Apply the function to the 'most_frequent_processed_answer' column
        self.merged_df['single_word_answers'] = self.merged_df['most_frequent_processed_answer'].apply(
            self.join_words_with_hyphen)

    def get_processed_data(self):
        """
        Retrieves the processed DataFrame.

        Returns:
            DataFrame: The processed DataFrame. Returns None if the DataFrame is empty or not processed.
        """
        if self.merged_df is not None:
            return self.merged_df
        else:
            print("DataFrame is empty or not processed yet.")
            return None

    def save_to_csv(self, df, saved_file_name):
        """
        Saves the given DataFrame to a CSV file, appending '.csv' to the name if it is missing.
        Falls back to 'data.csv' when no file name is provided.
        """
        if saved_file_name is not None:
            if ".csv" not in saved_file_name:
                df.to_csv(saved_file_name + ".csv", index=None)

            else:
                df.to_csv(saved_file_name, index=None)

        else:
            df.to_csv("data.csv", index=None)

    def display_dataframe(self):
        """
        Displays the processed DataFrame.
        """
        if self.merged_df is not None:
            print(self.merged_df)
        else:
            print("DataFrame is empty.")


def process_okvqa_dataset(questions_file_path, annotations_file_path, save_to_csv=False, saved_file_name=None):
    """
    Processes the OK-VQA dataset given the file paths for questions and annotations.

    Parameters:
        questions_file_path (str): The file path for the questions JSON file.
        annotations_file_path (str): The file path for the annotations JSON file.
        save_to_csv (bool): Whether to save the processed DataFrame to a CSV file.
        saved_file_name (str): File name to use when saving the CSV file.

    Returns:
        DataFrame: The processed DataFrame containing merged and processed VQA data.
    """
    # Create an instance of the class
    processor = VQADataProcessor(questions_file_path, annotations_file_path)

    # Process the data
    processor.merge_dataframes()
    processor.process_answers()

    # Retrieve the processed DataFrame
    processed_data = processor.get_processed_data()

    if save_to_csv:
        processor.save_to_csv(processed_data, saved_file_name)

    return processed_data


def show_image(image):
    """
    Display an image in various environments (Jupyter, PyCharm, Hugging Face Spaces).
    Handles different types of image inputs (file path, PIL Image, numpy array, OpenCV, PyTorch tensor).

    Args:
        image (str or PIL.Image or numpy.ndarray or torch.Tensor): The image to display.
    """
    in_jupyter = is_jupyter_notebook()
    in_colab = is_google_colab()

    # Convert image to PIL Image if it's a file path, numpy array, or PyTorch tensor
    if isinstance(image, str):
        if os.path.isfile(image):
            image = Image.open(image)
        else:
            raise ValueError("File path provided does not exist.")
    elif isinstance(image, np.ndarray):
        if image.ndim == 3 and image.shape[2] in [3, 4]:
            image = Image.fromarray(image[..., ::-1] if image.shape[2] == 3 else image)
        else:
            image = Image.fromarray(image)
    elif torch.is_tensor(image):
        image = Image.fromarray(image.permute(1, 2, 0).numpy().astype(np.uint8))

    # Display the image
    if in_jupyter or in_colab:
        from IPython.display import display
        display(image)
    else:
        image.show()


def show_image_with_matplotlib(image):
    """Displays an image with matplotlib, accepting a file path, PIL Image, numpy array, or PyTorch tensor."""
    if isinstance(image, str):
        image = Image.open(image)
    elif isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    elif torch.is_tensor(image):
        image = Image.fromarray(image.permute(1, 2, 0).numpy().astype(np.uint8))

    plt.imshow(image)
    plt.axis('off')  # Turn off axis numbers
    plt.show()


def is_jupyter_notebook():
    """
    Check if the code is running in a Jupyter notebook.

    Returns:
        bool: True if running in a Jupyter notebook, False otherwise.
    """
    try:
        from IPython import get_ipython
        if 'IPKernelApp' not in get_ipython().config:
            return False
        if 'ipykernel' in str(type(get_ipython())):
            return True  # Running in Jupyter Notebook
    except (NameError, AttributeError):
        return False  # Not running in Jupyter Notebook

    return False  # Default to False if none of the above conditions are met


def is_pycharm():
    return 'PYCHARM_HOSTED' in os.environ


def is_google_colab():
    return 'COLAB_GPU' in os.environ or 'google.colab' in sys.modules


def get_path(name, path_type):
    """
    Generates a path for models, images, or data based on the specified type.

    Args:
        name (str): The name of the model, image, or data folder/file.
        path_type (str): The type of path needed ('models', 'images', or 'data').

    Returns:
        str: The full path to the specified resource.
    """
    # Get the current working directory (assumed to be inside 'code' folder)
    current_dir = os.getcwd()

    # Get the directory one level up (the parent directory)
    parent_dir = os.path.dirname(current_dir)

    # Construct the path to the specified folder
    folder_path = os.path.join(parent_dir, path_type)

    # Construct the full path to the specific resource
    full_path = os.path.join(folder_path, name)

    return full_path


if __name__ == "__main__":
    pass
    #val_data = process_okvqa_dataset('OpenEnded_mscoco_val2014_questions.json', 'mscoco_val2014_annotations.json', save_to_csv=True, saved_file_name="okvqa_val.csv")
    #train_data = process_okvqa_dataset('OpenEnded_mscoco_train2014_questions.json', 'mscoco_train2014_annotations.json', save_to_csv=True, saved_file_name="okvqa_train.csv")
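For orientation, a minimal sketch of how these utilities might be called end to end; the JSON file names mirror the commented examples above and are placeholders for wherever the OK-VQA files actually live:

# Hypothetical usage sketch -- assumes the OK-VQA question/annotation JSON files are available locally.
train_df = process_okvqa_dataset('OpenEnded_mscoco_train2014_questions.json',
                                 'mscoco_train2014_annotations.json',
                                 save_to_csv=True, saved_file_name='okvqa_train.csv')
print(train_df.head())  # includes the generated most_frequent_* and single_word_answers columns

# get_path() resolves resources relative to the parent of the current working directory,
# e.g. <parent>/Sample_Images/horse.jpg; show_image() then picks a display backend
# (IPython display in notebooks/Colab, PIL's viewer elsewhere).
sample = get_path('horse.jpg', 'Sample_Images')
show_image(sample)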