Skip to content

Instantly share code, notes, and snippets.

@xeniaqian94
Created July 16, 2020 02:45
Show Gist options
  • Select an option

  • Save xeniaqian94/02a33ae450b3d53047b205525d4d09e6 to your computer and use it in GitHub Desktop.

Select an option

Save xeniaqian94/02a33ae450b3d53047b205525d4d09e6 to your computer and use it in GitHub Desktop.
A script that reranks an initial set of search results.
'''
-*- coding: utf-8 -*-
Copyright (C) 2020/5/18
This script reranks from an initial search result.
Usage:
python rerank_custom_collection.py --search_output_file [OUTPUT_PATH] \
--qid2query_file [QUERY_FILE_PATH] \
--passage_text_file [PASSAGE_ID2TEXT_PATH] \
--model_name_or_path [BERT_BASE_PSG_RETRIEVAL_MODEL_PATH] \
--device [your_device_setting] --output_path [RERANKER_OUTPUT_PATH]
'''
import argparse
from typing import Optional, List
from pathlib import Path
import logging
import tqdm
from pydantic import BaseModel, validator
from transformers import (AutoModel,
AutoTokenizer,
AutoModelForSequenceClassification,
BertForSequenceClassification)
import torch
class RerankInstance(BaseModel):
    """One (query, candidate passage) pair to be scored by the reranker.

    Built from one line of the search output joined with the query-text and
    passage-text lookup tables.
    """
    # Query identifier, as it appears in the search output and query file.
    qid: str
    # Full text of the query looked up by ``qid``.
    query_text: str
    # Candidate passage identifier, as it appears in the search output.
    docid: str
    # Full text of the passage looked up by ``docid``.
    passage_text: str
# Names of reranking methods.
# NOTE(review): not referenced by the visible code, and only
# 'seq_class_transformer' has a model constructor wired up.
METHOD_CHOICES = (
    'transformer',
    'bm25',
    't5',
    'seq_class_transformer',
    'random',
)
def _load_tf_bert_classifier(model_name_or_path):
    """Load a TF BERT checkpoint that the stock loader rejects with AttributeError.

    The TF weight loader looks up ``weight``/``bias`` attributes on the model
    itself for this kind of checkpoint. NOTE(review): presumably the classifier
    head is stored at the checkpoint's top level rather than under
    ``classifier``; this is inferred from the workaround, so confirm it against
    the checkpoint.

    Temporary ``weight``/``bias`` parameters are attached to the class so the
    loader can fill them in place. They are removed again afterwards and then
    installed as the model's classifier head.
    """
    config = BertForSequenceClassification.config_class.from_pretrained(model_name_or_path)
    # Size the temporary head from the config rather than assuming a
    # 2-label, 768-dim (BERT-base) model.
    weight = torch.nn.Parameter(torch.zeros(config.num_labels, config.hidden_size))
    bias = torch.nn.Parameter(torch.zeros(config.num_labels))
    BertForSequenceClassification.weight = weight
    BertForSequenceClassification.bias = bias
    try:
        model = BertForSequenceClassification.from_pretrained(
            pretrained_model_name_or_path=model_name_or_path, from_tf=True)
    finally:
        # Undo the class-level patch so it cannot leak into other instances.
        del BertForSequenceClassification.weight
        del BertForSequenceClassification.bias
    model.classifier.weight = weight
    model.classifier.bias = bias
    return model


def construct_seq_class_transformer(args):
    """Build a sequence-classification reranker, its tokenizer and target device.

    Loading is attempted in order:
      1. A PyTorch checkpoint via ``AutoModelForSequenceClassification``.
      2. On ``OSError``, the same loader with ``from_tf=True``.
      3. On ``AttributeError`` from that, a BERT-specific workaround for TF
         checkpoints (see ``_load_tf_bert_classifier``).

    Args:
        args: parsed CLI namespace; reads ``model_name_or_path`` and ``device``.
            An empty ``device`` selects CUDA when available, otherwise the CPU.

    Returns:
        ``(model, tokenizer, device)``. The model is already moved to
        ``device`` and put in eval mode.
    """
    try:
        model = AutoModelForSequenceClassification.from_pretrained(
            pretrained_model_name_or_path=args.model_name_or_path)
    except OSError:
        try:
            model = AutoModelForSequenceClassification.from_pretrained(
                pretrained_model_name_or_path=args.model_name_or_path, from_tf=True)
        except AttributeError:
            model = _load_tf_bert_classifier(args.model_name_or_path)
    # torch.device("") raises, and "" is the CLI default, so pick a sensible device.
    device = torch.device(args.device or ('cuda' if torch.cuda.is_available() else 'cpu'))
    model = model.to(device).eval()
    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=args.model_name_or_path)
    return model, tokenizer, device
def _read_id2text(path):
    """Read a two-column ``id<TAB>text`` file into a dict mapping id to text.

    Only the first tab separates the columns, so text that itself contains tabs
    is kept intact. Blank lines are skipped.
    """
    id2text = {}
    with open(path, 'r', encoding='utf8') as fin:
        for line in fin:
            line = line.strip()
            if not line:
                continue
            key, text = line.split('\t', 1)
            id2text[key] = text
    return id2text


def _load_reranking_examples(search_output_file, qid2query, passage_id2text):
    """Turn each three-column ``qid<TAB>docid<TAB>...`` search hit into a RerankInstance.

    Raises KeyError if a qid or docid has no entry in the lookup tables.
    """
    examples = []
    with open(search_output_file, 'r', encoding='utf8') as fin:
        for line in fin:
            line = line.strip()
            if not line:
                continue
            qid, docid, _ = line.split('\t')
            examples.append(RerankInstance(qid=qid, query_text=qid2query[qid], docid=docid,
                                           passage_text=passage_id2text[docid]))
    return examples


def _score_example(model, tokenizer, device, example, max_length):
    """Return the model's relevance score for one (query, passage) pair.

    For a multi-logit head the score is the log-probability of the last class.
    For a single-logit head it is the raw logit.
    """
    encoded_input = tokenizer.encode_plus(example.query_text, example.passage_text,
                                          max_length=max_length, return_token_type_ids=True,
                                          return_tensors='pt')
    input_ids = encoded_input['input_ids'].to(device)
    token_type_ids = encoded_input['token_type_ids'].to(device)
    # Inference only: skip building autograd graphs.
    with torch.no_grad():
        # Index rather than tuple-unpack: newer transformers releases return a
        # ModelOutput, which supports [0] but cannot be unpacked directly.
        logits = model(input_ids, token_type_ids=token_type_ids)[0]
    if logits.size(1) > 1:
        return torch.nn.functional.log_softmax(logits, 1)[0, -1].item()
    return logits.item()


def _write_results(output_path, results):
    """Write results as UTF-8 TSV lines: qid, query_text, docid, passage_text, score."""
    with open(output_path, "w", encoding='utf8') as fout:
        for result in results:
            fout.write('{}\t{}\t{}\t{}\t{}\n'.format(result["qid"], result["query_text"], result["docid"],
                                                     result["passage_text"], result["score"]))


def main():
    """Rerank an initial search result with a transformer and write scored pairs.

    Reads the search output (``qid<TAB>docid<TAB>...``), the query file and the
    passage file (both ``id<TAB>text``). It scores every (query, passage) pair
    and writes them to ``--output_path``, grouped by qid in ascending order
    with scores descending within each query.
    """
    # Reranker constructors keyed by --method. Register other architectures
    # (e.g. T5) here.
    construct_map = dict(
        seq_class_transformer=construct_seq_class_transformer,
    )
    parser = argparse.ArgumentParser(description='Process some args.')
    parser.add_argument('--search_output_file', type=str, default="")
    parser.add_argument('--qid2query_file', type=str, default="")
    parser.add_argument('--passage_text_file', type=str, default="")
    parser.add_argument('--method', type=str, default="seq_class_transformer",
                        choices=sorted(construct_map))
    parser.add_argument('--max_length', type=int, default=512)
    parser.add_argument('--model_name_or_path', type=str, default="")
    # torch.device("") raises, so default to a device that actually exists.
    parser.add_argument('--device', type=str,
                        default='cuda' if torch.cuda.is_available() else 'cpu')
    parser.add_argument('--output_path', type=str, default="")
    args = parser.parse_args()
    print(args)

    qid2query = _read_id2text(args.qid2query_file)
    passage_id2text = _read_id2text(args.passage_text_file)
    reranking_examples = _load_reranking_examples(args.search_output_file, qid2query, passage_id2text)

    reranker_model, reranker_tokenizer, reranker_device = construct_map[args.method](args)

    results = [
        {
            "qid": example.qid,
            "query_text": example.query_text,
            "docid": example.docid,
            "passage_text": example.passage_text,
            "score": _score_example(reranker_model, reranker_tokenizer, reranker_device,
                                    example, args.max_length),
        }
        for example in tqdm.tqdm(reranking_examples)
    ]
    # Group by qid (ascending), best-scoring passage first within each query.
    results.sort(key=lambda result: (result["qid"], -result["score"]))
    _write_results(args.output_path, results)
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment