A script that reranks an initial set of search results.
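The three input files are plain tab-separated text: --qid2query_file holds one qid<TAB>query per line, --passage_text_file holds one docid<TAB>passage text per line, and --search_output_file holds one qid<TAB>docid<TAB>rank-or-score line per retrieved passage (the third column is read but ignored). Below is a minimal sketch that writes toy versions of these files; the directory name, file names, ids, queries, and passages are made-up placeholders, not part of the original gist.

# Minimal sketch: write toy input files in the TSV layouts the script parses.
# All ids, queries, and passages below are made-up placeholders.
from pathlib import Path

toy = Path("toy_input")
toy.mkdir(exist_ok=True)

# --qid2query_file: one "qid<TAB>query" per line
(toy / "qid2query.tsv").write_text(
    "q1\twhat is passage reranking\n"
    "q2\thow does bm25 work\n",
    encoding="utf8")

# --passage_text_file: one "docid<TAB>passage text" per line
(toy / "passages.tsv").write_text(
    "d1\tPassage reranking reorders an initial candidate list with a stronger model.\n"
    "d2\tBM25 is a bag-of-words ranking function based on term frequencies.\n",
    encoding="utf8")

# --search_output_file: one "qid<TAB>docid<TAB>rank_or_score" per line
# (the third column is read but ignored by the reranker)
(toy / "search_output.tsv").write_text(
    "q1\td1\t1\n"
    "q1\td2\t2\n"
    "q2\td2\t1\n",
    encoding="utf8")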
'''
-*- coding: utf-8 -*-
Copyright (C) 2020/5/18
This script reranks an initial set of search results.

Usage:
    python rerank_custom_collection.py --search_output_file [OUTPUT_PATH] \
        --qid2query_file [QUERY_FILE_PATH] \
        --passage_text_file [PASSAGE_ID2TEXT_PATH] \
        --model_name_or_path [BERT_BASE_PSG_RETRIEVAL_MODEL_PATH] \
        --device [your_device_setting] --output_path [RERANKER_OUTPUT_PATH]
'''
import argparse

import torch
import tqdm
from pydantic import BaseModel
from transformers import (AutoTokenizer,
                          AutoModelForSequenceClassification,
                          BertForSequenceClassification)


class RerankInstance(BaseModel):
    qid: str
    query_text: str
    docid: str
    passage_text: str


METHOD_CHOICES = ('transformer', 'bm25', 't5', 'seq_class_transformer',
                  'random')


def construct_seq_class_transformer(args):
    # Try to load a PyTorch checkpoint first, then fall back to a TF checkpoint.
    try:
        model = AutoModelForSequenceClassification.from_pretrained(
            pretrained_model_name_or_path=args.model_name_or_path)
    except OSError:
        try:
            model = AutoModelForSequenceClassification.from_pretrained(
                pretrained_model_name_or_path=args.model_name_or_path, from_tf=True)
        except AttributeError:
            # Some TF checkpoints lack a classification head: register dummy
            # weights on the class so loading succeeds, then attach them.
            BertForSequenceClassification.bias = torch.nn.Parameter(
                torch.zeros(2))
            BertForSequenceClassification.weight = torch.nn.Parameter(
                torch.zeros(2, 768))
            model = BertForSequenceClassification.from_pretrained(
                pretrained_model_name_or_path=args.model_name_or_path, from_tf=True)
            model.classifier.weight = BertForSequenceClassification.weight
            model.classifier.bias = BertForSequenceClassification.bias
    device = torch.device(args.device)
    model = model.to(device).eval()
    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path=args.model_name_or_path)
    return model, tokenizer, device


def main():
    parser = argparse.ArgumentParser(description='Process some args.')
    parser.add_argument('--search_output_file', type=str, default="")
    parser.add_argument('--qid2query_file', type=str, default="")
    parser.add_argument('--passage_text_file', type=str, default="")
    parser.add_argument('--method', type=str, default="seq_class_transformer")
    parser.add_argument('--max_length', type=int, default=512)
    parser.add_argument('--model_name_or_path', type=str, default="")
    parser.add_argument('--device', type=str, default="")
    parser.add_argument('--output_path', type=str, default="")
    args = parser.parse_args()
    print(args)
    # Load reranking examples into a list
    qid2query = dict()
    passage_id2text = dict()
    for line_number, line in enumerate(open(args.qid2query_file, 'r', encoding='utf8')):
        qid, query = line.strip().split('\t')
        qid2query[qid] = query
    for line_number, line in enumerate(open(args.passage_text_file, 'r', encoding='utf8')):
        passage_id, passage_text = line.strip().split('\t')
        passage_id2text[passage_id] = passage_text
    reranking_examples = []
    for line_number, line in enumerate(open(args.search_output_file, 'r', encoding='utf8')):
        qid, docid, _ = line.strip().split('\t')
        query = qid2query[qid]
        passage_text = passage_id2text[docid]
        reranking_examples.append(
            RerankInstance(qid=qid, query_text=query, docid=docid, passage_text=passage_text))
    # You could customize this to a different reranker architecture here, e.g. T5
    construct_map = dict(
        seq_class_transformer=construct_seq_class_transformer,
    )
    reranker_model, reranker_tokenizer, reranker_device = construct_map[args.method](
        args)
    # Score each query-passage pair with the reranker
    result_tuples = []
    for example in tqdm.tqdm(reranking_examples):
        encoded_input = reranker_tokenizer.encode_plus(example.query_text, example.passage_text,
                                                       max_length=args.max_length, truncation=True,
                                                       return_token_type_ids=True,
                                                       return_tensors='pt')
        input_ids = encoded_input['input_ids'].to(reranker_device)
        tt_ids = encoded_input['token_type_ids'].to(reranker_device)
        with torch.no_grad():
            # The first element of the model output is the logits tensor.
            output = reranker_model(input_ids, token_type_ids=tt_ids)[0]
        if output.size(1) > 1:
            # Two-class head: score is the log-probability of the relevant class.
            score = torch.nn.functional.log_softmax(
                output, 1)[0, -1].item()
        else:
            # Single-logit head: use the raw score.
            score = output.item()
        result_tuples += [{
            "qid": example.qid,
            "query_text": example.query_text,
            "docid": example.docid,
            "passage_text": example.passage_text,
            "score": score,
        }]
    # Group results by query id, highest-scoring passages first within each query.
    result_tuples = sorted(result_tuples, key=lambda t: (t["qid"], -t["score"]))
    with open(args.output_path, "w") as fout:
        for result_tuple in result_tuples:
            fout.write(
                '{}\t{}\t{}\t{}\t{}\n'.format(result_tuple["qid"], result_tuple["query_text"],
                                              result_tuple["docid"],
                                              result_tuple["passage_text"],
                                              result_tuple["score"]))


if __name__ == '__main__':
    main()
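A minimal sketch of consuming the reranker output, assuming the tab-separated layout written above (qid, query_text, docid, passage_text, score) with rows for each qid already sorted by descending score; the output file name and the top_k value are placeholders, not part of the original gist.

# Minimal sketch: read the reranker output and keep the top-k passages per query.
import collections

top_k = 10  # placeholder cutoff
runs = collections.defaultdict(list)  # qid -> list of (docid, score)

with open("reranker_output.tsv", "r", encoding="utf8") as fin:
    for line in fin:
        # Layout written by the script: qid, query_text, docid, passage_text, score
        qid, _query_text, docid, _passage_text, score = line.rstrip("\n").split("\t")
        if len(runs[qid]) < top_k:
            runs[qid].append((docid, float(score)))

for qid, ranked in runs.items():
    print(qid, ranked)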