1. Code download and dependency installation

```shell
git clone https://github.com/princeton-nlp/SimCSE.git
cd SimCSE/
# torch 1.7.1 must be installed first
pip install -r requirements.txt
```
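Optionally, a quick sanity check (a hypothetical snippet, not part of the repository) to confirm that the pinned torch version and the main dependencies from requirements.txt import cleanly:

```python
# Hypothetical sanity check (not part of the SimCSE repo): confirm that the
# pinned torch version and the main dependencies from requirements.txt import.
import torch
import transformers
import datasets

print("torch:", torch.__version__)            # the note above assumes 1.7.1
print("transformers:", transformers.__version__)
print("datasets:", datasets.__version__)
print("CUDA available:", torch.cuda.is_available())
```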

2. Evaluation

(1) Commands

```shell
cd /path/to/SimCSE
# jy: download the evaluation datasets
cd SentEval/data/downstream/
bash download_dataset.sh
cd ../../../
# You can evaluate any transformers-based pre-trained models using this evaluation code.
python evaluation.py \
    --model_name_or_path princeton-nlp/sup-simcse-bert-base-uncased \
    --pooler cls \
    --task_set sts \
    --mode test
```
  • The test results are as follows:

![image.png](https://cdn.nlark.com/yuque/0/2022/png/25833371/1651015125526-decc9cfc-7455-4e7e-85b9-502102f8d871.png)

(2) Parameter descriptions

  • --model_name_or_path: The name or path of a transformers-based pre-trained checkpoint.

  • --pooler: Pooling method. Now we support
    • cls (default): Use the representation of [CLS] token. A linear+activation layer is applied after the representation (it’s in the standard BERT implementation). If you use supervised SimCSE, you should use this option.
    • cls_before_pooler: Use the representation of [CLS] token without the extra linear+activation. If you use unsupervised SimCSE, you should take this option.
    • avg: Average embeddings of the last layer. If you use checkpoints of SBERT/SRoBERTa, you should use this option.
      • SBERT/SRoBERTa paper: [2019-11-03] Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks
    • avg_top2: Average embeddings of the last two layers.
    • avg_first_last: Average embeddings of the first and last layers. If you use vanilla BERT or RoBERTa, this works the best.
  • --mode: Evaluation mode
    • test (default): To faithfully reproduce our results, you should use this option.
    • dev: Report the development set results.
      • Note that in STS tasks, only STS-B and SICK-R have development sets, so we only report their numbers.
      • It also takes a fast mode for transfer tasks, so the running time is much shorter than the test mode (though numbers are slightly lower).
    • fasttest: It is the same as test, but with a fast mode so the running time is much shorter, but the reported numbers may be lower (only for transfer tasks).
  • --task_set: What set of tasks to evaluate on (if set, it will override --tasks)
    • sts (default): Evaluate on STS tasks, including STS 12~16, STS-B and SICK-R. This is the most commonly-used set of tasks to evaluate the quality of sentence embeddings.
    • transfer: Evaluate on transfer tasks.
    • full: Evaluate on both STS and transfer tasks.
    • na: Manually set tasks by --tasks.
  • --tasks: Specify which dataset(s) to evaluate on. Will be overridden if --task_set is not na. See the code for a full list of tasks.
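As a concrete illustration of how the `--pooler` options above map onto a checkpoint's outputs, here is a minimal sketch (not the `evaluation.py` code path) that loads the supervised checkpoint with plain `transformers` and scores one sentence pair. The mapping of `cls` to BERT's `pooler_output` follows the README description (linear + activation on top of `[CLS]`) and should hold for the released checkpoints; everything else is standard transformers usage.

```python
# Minimal sketch (assumes network access to download the checkpoint); this is
# only an illustration of the pooler choices, not evaluation.py itself.
import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("princeton-nlp/sup-simcse-bert-base-uncased")
model = AutoModel.from_pretrained("princeton-nlp/sup-simcse-bert-base-uncased")
model.eval()

sentences = ["A man is playing a guitar.", "Someone is playing an instrument."]
batch = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")

with torch.no_grad():
    outputs = model(**batch)

# "cls" pooler (recommended for supervised SimCSE): [CLS] + linear + activation,
# which for these checkpoints should correspond to `pooler_output`.
emb_cls = outputs.pooler_output
# "cls_before_pooler" (recommended for unsupervised SimCSE): raw [CLS] hidden state.
emb_cls_before = outputs.last_hidden_state[:, 0]
# "avg": mask-aware average of the last layer's token embeddings (as for SBERT/SRoBERTa).
mask = batch["attention_mask"].unsqueeze(-1)
emb_avg = (outputs.last_hidden_state * mask).sum(1) / mask.sum(1)

cos = torch.nn.functional.cosine_similarity(emb_cls[0], emb_cls[1], dim=0)
print(f"cosine similarity (cls pooler): {cos.item():.4f}")
```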

3. Train

(1) Commands

```shell
cd /path/to/SimCSE

# Download the training data
cd data
# For supervised SimCSE: SNLI and MNLI datasets
bash download_nli.sh
# For unsupervised SimCSE: 1 million sentences from English Wikipedia
bash download_wiki.sh
cd ../

# jy: training amounts to running the train.py script with the appropriate arguments

# unsupervised: a single-GPU (or CPU) example for the unsupervised version
# Note: if no GPU is available, the --fp16 argument passed to train.py must be commented out.
bash run_unsup_example.sh

# supervised SimCSE: a multi-GPU example for the supervised version
bash run_sup_example.sh
```

## (2) Parameter descriptions (i.e., the arguments passed to train.py)

- `--train_file`: Training file path. You can use the provided Wikipedia or NLI data, or your own data in the same format (see the toy example after this list). Supported formats:
  - "txt" files: one line per sentence
  - "csv" files:
    - 2-column: pair data with no hard negative
    - 3-column: pair data with one corresponding hard negative instance
- `--model_name_or_path`: Pre-trained checkpoints to start with. Supported:
  - BERT-based models: `bert-base-uncased`, `bert-large-uncased`, etc.
  - RoBERTa-based models: `roberta-base`, `roberta-large`, etc.
- `--temp`: Temperature for the contrastive loss (default: 0.05).
- `--pooler_type`: Pooling method. Same as `--pooler` in the evaluation part.
- `--mlp_only_train`: For unsupervised SimCSE, it works better to train the model with the MLP layer but test it without the MLP. Use this argument when training unsupervised SimCSE models.
- `--hard_negative_weight`: If using hard negatives (i.e., there are 3 columns in the training file), this is the logarithm of the weight. For example, if the weight is 1, this argument should be set to 0 (the default value).
- `--do_mlm`: Whether to use the MLM auxiliary objective. If True:
  - `--mlm_weight`: Weight for the MLM objective.
  - `--mlm_probability`: Masking rate for the MLM objective.
- All other arguments are standard Huggingface transformers training arguments. Commonly used ones include `--output_dir`, `--learning_rate`, and `--per_device_train_batch_size`.
- The example scripts also evaluate the model on the STS-B development set (download the dataset following the evaluation section) and save the best checkpoint.
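To make the supervised `--train_file` format concrete, the snippet below writes a toy 3-column CSV (sentence, positive, hard negative). The file name and sentences are made up, and the header names are an assumption; match them against the header of the downloaded `data/nli_for_simcse.csv` before reusing this.

```python
# Hypothetical example: build a tiny 3-column CSV in the supervised-SimCSE
# format (pair + one hard negative per row). Dropping the last column would
# give the 2-column (no hard negative) variant.
import csv

rows = [
    # (sentence, entailed/positive sentence, contradicting/hard-negative sentence)
    ("A man is playing a guitar.", "A person plays an instrument.", "Nobody is making any music."),
    ("Two dogs run across the field.", "Animals are running outside.", "The dogs are sleeping indoors."),
]

with open("data/my_toy_nli.csv", "w", newline="", encoding="utf-8") as f:
    writer = csv.writer(f)
    writer.writerow(["sent0", "sent1", "hard_neg"])  # assumed header; check nli_for_simcse.csv
    writer.writerows(rows)
# Then pass it with:  --train_file data/my_toy_nli.csv
```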
## (3) Summary of training arguments

```shell
usage: train.py --xxx xxx_val [...]

optional arguments:
  -h, --help            show this help message and exit
  --model_name_or_path MODEL_NAME_OR_PATH
                        The model checkpoint for weights initialization. Don't
                        set if you want to train a model from scratch.
  --model_type MODEL_TYPE
                        If training from scratch, pass a model type from the
                        list: layoutlm, distilbert, albert, bart, camembert,
                        xlm-roberta, longformer, roberta, squeezebert, bert,
                        mobilebert, flaubert, xlm, electra, reformer, funnel,
                        mpnet, tapas
  --config_name CONFIG_NAME
                        Pretrained config name or path if not the same as
                        model_name
  --tokenizer_name TOKENIZER_NAME
                        Pretrained tokenizer name or path if not the same as
                        model_name
  --cache_dir CACHE_DIR
                        Where do you want to store the pretrained models
                        downloaded from huggingface.co
  --no_use_fast_tokenizer
                        Whether to use one of the fast tokenizers (backed by
                        the tokenizers library) or not.
  --model_revision MODEL_REVISION
                        The specific model version to use (can be a branch
                        name, tag name or commit id).
  --use_auth_token      Will use the token generated when running
                        `transformers-cli login` (necessary to use this script
                        with private models).
  --temp TEMP           Temperature for softmax.
  --pooler_type POOLER_TYPE
                        What kind of pooler to use (cls, cls_before_pooler,
                        avg, avg_top2, avg_first_last).
  --hard_negative_weight HARD_NEGATIVE_WEIGHT
                        The **logit** of weight for hard negatives (only
                        effective if hard negatives are used).
  --do_mlm              Whether to use MLM auxiliary objective.
  --mlm_weight MLM_WEIGHT
                        Weight for MLM auxiliary objective (only effective if
                        --do_mlm).
  --mlp_only_train      Use MLP only during training
  --dataset_name DATASET_NAME
                        The name of the dataset to use (via the datasets
                        library).
  --dataset_config_name DATASET_CONFIG_NAME
                        The configuration name of the dataset to use (via the
                        datasets library).
  --overwrite_cache     Overwrite the cached training and evaluation sets
  --validation_split_percentage VALIDATION_SPLIT_PERCENTAGE
                        The percentage of the train set used as validation set
                        in case there's no validation split
  --preprocessing_num_workers PREPROCESSING_NUM_WORKERS
                        The number of processes to use for the preprocessing.
  --train_file TRAIN_FILE
                        The training data file (.txt or .csv).
  --max_seq_length MAX_SEQ_LENGTH
                        The maximum total input sequence length after
                        tokenization. Sequences longer than this will be
                        truncated.
  --pad_to_max_length   Whether to pad all samples to `max_seq_length`. If
                        False, will pad the samples dynamically when batching
                        to the maximum length in the batch.
  --mlm_probability MLM_PROBABILITY
                        Ratio of tokens to mask for MLM (only effective if
                        --do_mlm)
  --output_dir OUTPUT_DIR
                        The output directory where the model predictions and
                        checkpoints will be written.
  --overwrite_output_dir
                        Overwrite the content of the output directory. Use this
                        to continue training if output_dir points to a
                        checkpoint directory.
  --do_train            Whether to run training.
  --do_eval             Whether to run eval on the dev set.
  --do_predict          Whether to run predictions on the test set.
  --evaluation_strategy {EvaluationStrategy.NO, EvaluationStrategy.STEPS, EvaluationStrategy.EPOCH}
                        The evaluation strategy to use.
  --prediction_loss_only
                        When performing evaluation and predictions, only
                        returns the loss.
  --per_device_train_batch_size PER_DEVICE_TRAIN_BATCH_SIZE
                        Batch size per GPU/TPU core/CPU for training.
  --per_device_eval_batch_size PER_DEVICE_EVAL_BATCH_SIZE
                        Batch size per GPU/TPU core/CPU for evaluation.
  --per_gpu_train_batch_size PER_GPU_TRAIN_BATCH_SIZE
                        Deprecated, the use of `--per_device_train_batch_size`
                        is preferred. Batch size per GPU/TPU core/CPU for
                        training.
  --per_gpu_eval_batch_size PER_GPU_EVAL_BATCH_SIZE
                        Deprecated, the use of `--per_device_eval_batch_size`
                        is preferred. Batch size per GPU/TPU core/CPU for
                        evaluation.
  --gradient_accumulation_steps GRADIENT_ACCUMULATION_STEPS
                        Number of updates steps to accumulate before
                        performing a backward/update pass.
  --eval_accumulation_steps EVAL_ACCUMULATION_STEPS
                        Number of predictions steps to accumulate before
                        moving the tensors to the CPU.
  --learning_rate LEARNING_RATE
                        The initial learning rate for Adam.
  --weight_decay WEIGHT_DECAY
                        Weight decay if we apply some.
  --adam_beta1 ADAM_BETA1
                        Beta1 for Adam optimizer
  --adam_beta2 ADAM_BETA2
                        Beta2 for Adam optimizer
  --adam_epsilon ADAM_EPSILON
                        Epsilon for Adam optimizer.
  --max_grad_norm MAX_GRAD_NORM
                        Max gradient norm.
  --num_train_epochs NUM_TRAIN_EPOCHS
                        Total number of training epochs to perform.
  --max_steps MAX_STEPS
                        If > 0: set total number of training steps to perform.
                        Override num_train_epochs.
  --lr_scheduler_type {SchedulerType.LINEAR, SchedulerType.COSINE,
                        SchedulerType.COSINE_WITH_RESTARTS,
                        SchedulerType.POLYNOMIAL, SchedulerType.CONSTANT,
                        SchedulerType.CONSTANT_WITH_WARMUP}
                        The scheduler type to use.
  --warmup_steps WARMUP_STEPS
                        Linear warmup over warmup_steps.
  --logging_dir LOGGING_DIR
                        Tensorboard log dir.
  --logging_first_step  Log the first global_step
  --logging_steps LOGGING_STEPS
                        Log every X updates steps.
  --save_steps SAVE_STEPS
                        Save checkpoint every X updates steps.
  --save_total_limit SAVE_TOTAL_LIMIT
                        Limit the total amount of checkpoints. Deletes the
                        older checkpoints in the output_dir. Default is
                        unlimited checkpoints
  --no_cuda             Do not use CUDA even when it is available
  --seed SEED           random seed for initialization
  --fp16                Whether to use 16-bit (mixed) precision (through
                        NVIDIA Apex) instead of 32-bit
  --fp16_opt_level FP16_OPT_LEVEL
                        For fp16: Apex AMP optimization level selected in
                        ['O0', 'O1', 'O2', and 'O3']. See details at
                        https://nvidia.github.io/apex/amp.html
  --fp16_backend {auto,amp,apex}
                        The backend to be used for mixed precision.
  --local_rank LOCAL_RANK
                        For distributed training: local_rank
  --tpu_num_cores TPU_NUM_CORES
                        TPU: Number of TPU cores (automatically passed by
                        launcher script)
  --tpu_metrics_debug   Deprecated, the use of `--debug` is preferred. TPU:
                        Whether to print debug metrics
  --debug               Whether to print debug metrics on TPU
  --dataloader_drop_last
                        Drop the last incomplete batch if it is not divisible
                        by the batch size.
  --eval_steps EVAL_STEPS
                        Run an evaluation every X steps.
  --dataloader_num_workers DATALOADER_NUM_WORKERS
                        Number of subprocesses to use for data loading
                        (PyTorch only). 0 means that the data will be loaded
                        in the main process.
  --past_index PAST_INDEX
                        If >=0, uses the corresponding part of the output as
                        the past state for next step.
  --run_name RUN_NAME   An optional descriptor for the run. Notably used for
                        wandb logging.
  --disable_tqdm DISABLE_TQDM
                        Whether or not to disable the tqdm progress bars.
  --no_remove_unused_columns
                        Remove columns not required by the model when using an
                        nlp.Dataset.
  --label_names LABEL_NAMES [LABEL_NAMES ...]
                        The list of keys in your dictionary of inputs that
                        correspond to the labels.
  --load_best_model_at_end
                        Whether or not to load the best model found during
                        training at the end of training.
  --metric_for_best_model METRIC_FOR_BEST_MODEL
                        The metric to use to compare two different models.
  --greater_is_better GREATER_IS_BETTER
                        Whether the `metric_for_best_model` should be
                        maximized or not.
  --ignore_data_skip    When resuming training, whether or not to skip the
                        first epochs and batches to get to the same training
                        data.
  --sharded_ddp         Whether or not to use sharded DDP training (in
                        distributed training only).
  --deepspeed DEEPSPEED
                        Enable deepspeed and pass the path to deepspeed json
                        config file (e.g. ds_config.json)
  --label_smoothing_factor LABEL_SMOOTHING_FACTOR
                        The label smoothing epsilon to apply (zero means no
                        label smoothing).
  --adafactor           Whether or not to replace Adam by Adafactor.
  --eval_transfer       Evaluate transfer task dev sets (in validation).
```

4. Notes on the training process

(1) Error raised after training finishes

```
    return comm.gather(inputs, ctx.dim, ctx.target_device)
  File "/home/huangjiayue/anaconda3/envs/jy-pharm_paper_search/lib/python3.6/site-packages/torch/nn/parallel/comm.py", line 230, in gather
    return torch._C._gather(tensors, dim, destination)
RuntimeError: Input tensor at index 1 has invalid shape [62, 62], but expected [62, 63]
```
  • Cause analysis: https://github.com/princeton-nlp/SimCSE/issues/147
  • It seems to be a GPU communication-related error. Maybe try limiting the number of GPUs to 1 and try again.

(2) Error when loading the dataset during training

```
Traceback (most recent call last):
  File "train.py", line 585, in <module>
    main()
  File "train.py", line 310, in main
    datasets = load_dataset(extension, data_files=data_files, cache_dir="./data/")
  File "/home/huangjiayue/anaconda3/envs/jy-pharm_paper_search/lib/python3.6/site-packages/datasets/load.py", line 591, in load_dataset
    path, script_version=script_version, download_config=download_config, download_mode=download_mode, dataset=True
  File "/home/huangjiayue/anaconda3/envs/jy-pharm_paper_search/lib/python3.6/site-packages/datasets/load.py", line 267, in prepare_module
    local_path = cached_path(file_path, download_config=download_config)
  File "/home/huangjiayue/anaconda3/envs/jy-pharm_paper_search/lib/python3.6/site-packages/datasets/utils/file_utils.py", line 343, in cached_path
    max_retries=download_config.max_retries,
  File "/home/huangjiayue/anaconda3/envs/jy-pharm_paper_search/lib/python3.6/site-packages/datasets/utils/file_utils.py", line 617, in get_from_cache
    raise ConnectionError("Couldn't reach {}".format(url))
ConnectionError: Couldn't reach https://raw.githubusercontent.com/huggingface/datasets/1.2.1/datasets/text/text.py
Traceback (most recent call last):
  File "/home/huangjiayue/anaconda3/envs/jy-pharm_paper_search/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/home/huangjiayue/anaconda3/envs/jy-pharm_paper_search/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/huangjiayue/anaconda3/envs/jy-pharm_paper_search/lib/python3.6/site-packages/torch/distributed/launch.py", line 260, in <module>
    main()
  File "/home/huangjiayue/anaconda3/envs/jy-pharm_paper_search/lib/python3.6/site-packages/torch/distributed/launch.py", line 256, in main
    cmd=cmd)
```

  • Solution: create `text/text.py` under the SimCSE project directory:

```python
"""
jy: this file was downloaded from:
    https://raw.githubusercontent.com/huggingface/datasets/1.2.1/datasets/text/text.py
"""
import logging
from dataclasses import dataclass

import pyarrow as pa

import datasets


logger = logging.getLogger(__name__)

FEATURES = datasets.Features(
    {
        "text": datasets.Value("string"),
    }
)


@dataclass
class TextConfig(datasets.BuilderConfig):
    """BuilderConfig for text files."""

    encoding: str = "utf-8"
    chunksize: int = 10 << 20  # 10MB


class Text(datasets.ArrowBasedBuilder):
    BUILDER_CONFIG_CLASS = TextConfig

    def _info(self):
        return datasets.DatasetInfo(features=FEATURES)

    def _split_generators(self, dl_manager):
        """The `data_files` kwarg in load_dataset() can be a str, List[str], Dict[str,str], or Dict[str,List[str]].

        If str or List[str], then the dataset returns only the 'train' split.
        If dict, then keys should be from the `datasets.Split` enum.
        """
        if not self.config.data_files:
            raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}")
        data_files = dl_manager.download_and_extract(self.config.data_files)
        if isinstance(data_files, (str, list, tuple)):
            files = data_files
            if isinstance(files, str):
                files = [files]
            return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"files": files})]
        splits = []
        for split_name, files in data_files.items():
            if isinstance(files, str):
                files = [files]
            splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files}))
        return splits

    def _generate_tables(self, files):
        for file_idx, file in enumerate(files):
            batch_idx = 0
            with open(file, "r", encoding=self.config.encoding) as f:
                while True:
                    batch = f.read(self.config.chunksize)
                    if not batch:
                        break
                    batch += f.readline()  # finish current line
                    batch = batch.splitlines()
                    pa_table = pa.Table.from_arrays([pa.array(batch)], schema=pa.schema({"text": pa.string()}))
                    # Uncomment for debugging (will print the Arrow table size and elements)
                    # logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}")
                    # logger.warning('\n'.join(str(pa_table.slice(i, 1).to_pydict()) for i in range(pa_table.num_rows)))
                    yield (file_idx, batch_idx), pa_table
                    batch_idx += 1
```

# 5. Walkthrough of the training process

## (1) unsupervised
```shell
:<<!
python train.py \
    --model_name_or_path bert-base-uncased \
    --train_file data/wiki1m_for_simcse.txt \
    --output_dir result/my-unsup-simcse-bert-base-uncased \
    --num_train_epochs 1 \
    --per_device_train_batch_size 64 \
    --learning_rate 3e-5 \
    --max_seq_length 32 \
    --evaluation_strategy steps \
    --metric_for_best_model stsb_spearman \
    --load_best_model_at_end \
    --eval_steps 125 \
    --pooler_type cls \
    --mlp_only_train \
    --overwrite_output_dir \
    --temp 0.05 \
    --do_train \
    --do_eval \
    --fp16 \
    "$@"
!

# jy: if no GPU is available, the --fp16 argument passed to train.py must be commented out;
python train.py \
    --model_name_or_path bert-base-uncased \
    --train_file data/wiki1m_for_simcse-20w.txt \
    --output_dir result/my-unsup-simcse-bert-base-uncased \
    --num_train_epochs 1 \
    --per_device_train_batch_size 64 \
    --learning_rate 3e-5 \
    --max_seq_length 32 \
    --evaluation_strategy steps \
    --metric_for_best_model stsb_spearman \
    --load_best_model_at_end \
    --eval_steps 125 \
    --pooler_type cls \
    --mlp_only_train \
    --overwrite_output_dir \
    --temp 0.05 \
    --do_train \
    --do_eval \
    "$@"