by sonoisa
Open source · 45k downloads · 51 likes
This model is an enhanced version of the Japanese Sentence-BERT, specifically designed to encode Japanese sentences into high-quality semantic vectors. It employs a more effective loss function than its predecessor, enabling it to achieve a 1.5 to 2-point improvement in accuracy on internal datasets. Trained on a pre-trained Japanese BERT base model, it excels in contextual understanding and semantic similarity between sentences. Its primary use cases include information retrieval, text classification, and plagiarism detection in Japanese. What sets it apart is its ability to generate more precise and consistent vector representations for applications requiring fine-grained analysis of the Japanese language.
This is a Japanese sentence-BERT model.
日本語用Sentence-BERTモデル(バージョン2)です。
バージョン1よりも良いロス関数であるMultipleNegativesRankingLossを用いて学習した改良版です。
手元の非公開データセットでは、バージョン1よりも1.5〜2ポイントほど精度が高い結果が得られました。
事前学習済みモデルとしてcl-tohoku/bert-base-japanese-whole-word-maskingを利用しました。
従って、推論の実行にはfugashiとipadicが必要です(pip install fugashi ipadic)。
https://qiita.com/sonoisa/items/1df94d0a98cd4f209051
モデル名を"sonoisa/sentence-bert-base-ja-mean-tokens-v2"に書き換えれば、本モデルを利用した挙動になります。
from transformers import BertJapaneseTokenizer, BertModel
import torch
class SentenceBertJapanese:
    """Sentence-BERT encoder for Japanese text.

    Wraps a HuggingFace BERT model and tokenizer and produces one fixed-size
    embedding per sentence by mean-pooling the token embeddings (padding
    positions excluded via the attention mask).
    """

    def __init__(self, model_name_or_path, device=None):
        """Load tokenizer and model, then move the model to *device*.

        Args:
            model_name_or_path: HuggingFace model id or local path.
            device: torch device string (e.g. "cuda" / "cpu"); auto-detected
                from CUDA availability when None.
        """
        self.tokenizer = BertJapaneseTokenizer.from_pretrained(model_name_or_path)
        self.model = BertModel.from_pretrained(model_name_or_path)
        self.model.eval()

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        # Use the normalized torch.device (not the raw argument) consistently.
        self.model.to(self.device)

    def _mean_pooling(self, model_output, attention_mask):
        """Average token embeddings over the sequence axis, ignoring padding.

        model_output[0] is the last hidden state, shape (batch, seq, hidden).
        The clamp guards against division by zero for an all-zero mask row.
        """
        token_embeddings = model_output[0]
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        summed = torch.sum(token_embeddings * input_mask_expanded, 1)
        counts = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return summed / counts

    @torch.no_grad()
    def encode(self, sentences, batch_size=8):
        """Encode *sentences* into a (len(sentences), hidden) CPU tensor.

        Args:
            sentences: sequence of strings to embed.
            batch_size: number of sentences tokenized/forwarded per batch.

        Returns:
            A 2-D float tensor of sentence embeddings on the CPU. An empty
            input yields an empty tensor (the previous version raised a
            RuntimeError from torch.stack on an empty list).
        """
        all_embeddings = []
        for start in range(0, len(sentences), batch_size):
            batch = sentences[start:start + batch_size]
            encoded_input = self.tokenizer.batch_encode_plus(
                batch, padding="longest", truncation=True, return_tensors="pt"
            ).to(self.device)
            model_output = self.model(**encoded_input)
            embeddings = self._mean_pooling(model_output, encoded_input["attention_mask"]).to("cpu")
            all_embeddings.extend(embeddings)

        if not all_embeddings:
            return torch.empty(0)
        return torch.stack(all_embeddings)
# Demo: embed two near-synonymous Japanese sentences and print the vectors.
MODEL_NAME = "sonoisa/sentence-bert-base-ja-mean-tokens-v2"  # note: the v2 checkpoint

# Downloads the checkpoint on first use.
model = SentenceBertJapanese(MODEL_NAME)

sentences = ["暴走したAI", "暴走した人工知能"]
sentence_embeddings = model.encode(sentences, batch_size=8)

print("Sentence embeddings:", sentence_embeddings)