Commit bb4616ff authored by Bernhard Liebl's avatar Bernhard Liebl
Browse files

refactor DocEmbedder

parent 88754b2b
...@@ -253,16 +253,154 @@ class TSNECallback(openTSNE.callbacks.Callback): ...@@ -253,16 +253,154 @@ class TSNECallback(openTSNE.callbacks.Callback):
def __call__(self, iteration, error, embedding): def __call__(self, iteration, error, embedding):
pass pass
class DocEmbedderFactory:
    """Factory that creates DocEmbedder instances sharing one session/nlp pair.

    Args:
        session: vectorian session providing embeddings and partitions.
        nlp: spaCy pipeline used to prepare documents.
        doc_encoders: optional mapping of name -> dedicated document encoder.
    """

    def __init__(self, session, nlp, doc_encoders=None):
        self._session = session
        self._nlp = nlp
        # FIX: avoid the shared mutable-default-argument pitfall
        # (a {} default is created once and shared across all instances).
        self._doc_encoders = {} if doc_encoders is None else doc_encoders

    def create(self, encoder):
        """Build a DocEmbedder preselecting the option matching `encoder` (a filter string)."""
        return DocEmbedder(self._session, self._nlp, self._doc_encoders, encoder)
class EmbeddingPlotter:
class DocEmbedder:
    """Selector widget pair (embedding + aggregator) that yields a document encoder.

    Presents the session's token embeddings (reduced per document via a
    selectable numpy aggregator) and any dedicated document encoders as a
    single list of options, and exposes the currently selected encoder.

    Args:
        session: vectorian session; provides `embeddings` and `partition`.
        nlp: spaCy pipeline used by `encode` to prepare documents.
        doc_encoders: optional mapping of name -> dedicated document encoder.
        encoder: optional filter string selecting the initial option.
    """

    def __init__(self, session, nlp, doc_encoders=None, encoder=None):
        # FIX: replaced the shared mutable default `doc_encoders={}` with None.
        if doc_encoders is None:
            doc_encoders = {}

        self._session = session
        self._nlp = nlp
        self._doc_encoders = doc_encoders
        self._callbacks = []
        self._partition = session.partition("document")

        # Each option carries either a token embedding (aggregated later)
        # or a ready-made document encoder, never both.
        Option = collections.namedtuple(
            "Option", ["name", "token_embedding", "doc_encoder"])

        options = []
        for k, v in self._session.embeddings.items():
            options.append(Option(format_embedding_name(k) + " [token]", v, None))
        for k, v in doc_encoders.items():
            options.append(Option(format_embedding_name(k) + " [doc]", None, v))
        self._options = options

        default_option = options[0].name
        if encoder is not None:
            # `encoder` is a filter string matched against option names.
            default_option = options[find_index_by_filter(
                [x.name for x in options], encoder)].name

        agg_options = ["mean", "median", "max", "min"]

        if _display_mode.bokeh:
            self._embedding_select = bokeh.models.Select(
                title="",
                value=default_option,
                options=[option.name for option in options])
            self._aggregator = bokeh.models.Select(
                title="",
                value=agg_options[0],
                options=agg_options)

            def embedding_changed_shim(attr, old, new):
                self.embedding_changed()

            def aggregator_changed_shim(attr, old, new):
                self.aggregator_changed()

            self._embedding_select.on_change("value", embedding_changed_shim)
            self.embedding_changed()
            self._aggregator.on_change("value", aggregator_changed_shim)
        else:
            self._embedding_select = widgets.Dropdown(
                title="",
                style={'description_width': 'initial', 'width': 'max'},
                value=default_option,
                options=[option.name for option in options])
            # FIX: was bokeh.models.Dropdown, which does not support the
            # ipywidgets `observe` API used below; this branch is the
            # ipywidgets path, so use widgets.Dropdown like the select above.
            self._aggregator = widgets.Dropdown(
                title="",
                style={'description_width': 'initial', 'width': 'max'},
                value=agg_options[0],
                options=agg_options)

            self._embedding_select.observe(
                lambda changed: self.embedding_changed(), names="value")
            self._aggregator.observe(
                lambda changed: self.aggregator_changed(), names="value")

    @property
    def disabled(self):
        return self._embedding_select.disabled

    @disabled.setter
    def disabled(self, value):
        self._embedding_select.disabled = value
        self._aggregator.disabled = value

    def _change_occured(self):
        # Notify all registered listeners of any selection change.
        for cb in self._callbacks:
            cb()

    def embedding_changed(self):
        # The aggregator only applies to token embeddings; hide it otherwise.
        self._aggregator.visible = self.option.token_embedding is not None
        self._change_occured()

    def aggregator_changed(self):
        self._change_occured()

    @property
    def session(self):
        return self._session

    @property
    def option(self):
        """Return the currently selected Option namedtuple."""
        return self._options[
            [x.name for x in self._options].index(self._embedding_select.value)]

    def on_change(self, callback):
        """Register `callback` to be invoked on any selection change."""
        self._callbacks.append(callback)

    @property
    def widget(self):
        if _display_mode.bokeh:
            return bokeh.layouts.row(self._embedding_select, self._aggregator)
        else:
            return widgets.HBox([self._embedding_select, self._aggregator])

    def display(self):
        if _display_mode.bokeh:
            bokeh.io.show(lambda doc: doc.add_root(self.widget))
        else:
            display(self.widget)

    @property
    def encoder(self):
        """Return the encoder for the current selection.

        Dedicated document encoders are returned as-is; token embeddings are
        wrapped in a cached aggregating encoder using the selected numpy
        reduction (mean/median/max/min).
        """
        option = self.option
        if option.doc_encoder is not None:
            return option.doc_encoder
        else:
            agg = getattr(np, self._aggregator.value)
            return CachedPartitionEncoder(
                TokenEmbeddingAggregator(option.token_embedding.factory, agg))

    @property
    def partition(self):
        return self._partition

    def mk_query(self, text):
        """Turn free text into a query document over this embedder's partition."""
        return DummyIndex(self.partition).make_query(text)

    def encode(self, docs):
        """Encode `docs` with the currently selected encoder; returns raw vectors."""
        return self.encoder.encode(
            prepare_docs(docs, self._nlp), self.partition).unmodified
class EmbeddingPlotter:
def __init__(self, embedder, gold):
self._embedder = embedder
self._gold = gold self._gold = gold
self._current_selection = None self._current_selection = None
self._id_to_doc = dict((doc.unique_id, doc) for doc in self._session.documents) self._id_to_doc = dict((doc.unique_id, doc) for doc in self.session.documents)
self._doc_formatter = DocFormatter(gold) self._doc_formatter = DocFormatter(gold)
DocData = collections.namedtuple("DocData", ["doc", "query", "work"]) DocData = collections.namedtuple("DocData", ["doc", "query", "work"])
...@@ -275,14 +413,7 @@ class EmbeddingPlotter: ...@@ -275,14 +413,7 @@ class EmbeddingPlotter:
query=pattern.phrase, query=pattern.phrase,
work=occ.source.work)) work=occ.source.work))
self._docs = docs self._docs = docs
self._partition = session.partition("document")
self.encoders = dict()
for k, embedding in session.embeddings.items():
self.encoders[format_embedding_name(k) + f" ({aggregator.__name__})"] = CachedPartitionEncoder(
TokenEmbeddingAggregator(embedding.factory, aggregator))
self._doc_emb_tooltips = """ self._doc_emb_tooltips = """
<span style="font-variant:small-caps">@work</span> <span style="font-variant:small-caps">@work</span>
<br> <br>
...@@ -330,14 +461,17 @@ class EmbeddingPlotter: ...@@ -330,14 +461,17 @@ class EmbeddingPlotter:
self._figures = [] self._figures = []
self._figures_html = [] self._figures_html = []
@property
def session(self):
return self._embedder.session
@property @property
def partition(self): def partition(self):
return self._partition return self._embedder.partition
def _compute_source_data(self, embedding, intruder): def _compute_source_data(self, intruder):
encoder = self.encoders[embedding] intruder_doc = self._embedder.mk_query(intruder)
intruder_doc = DummyIndex(self.partition).make_query(intruder)
id_to_doc = self._id_to_doc id_to_doc = self._id_to_doc
query_docs = [] query_docs = []
...@@ -362,8 +496,7 @@ class EmbeddingPlotter: ...@@ -362,8 +496,7 @@ class EmbeddingPlotter:
'work': works, 'work': works,
'query': phrases, 'query': phrases,
'context': contexts, 'context': contexts,
'vector': encoder.encode( 'vector': self._embedder.encode(query_docs)
prepare_docs(query_docs, self._nlp), self.partition).unmodified
} }
include_intruder = not np.any(np.isnan(data['vector'])) include_intruder = not np.any(np.isnan(data['vector']))
...@@ -421,19 +554,10 @@ class EmbeddingPlotter: ...@@ -421,19 +554,10 @@ class EmbeddingPlotter:
script_code.append(x.string) script_code.append(x.string)
self._pw.js_init("\n".join(script_code)) self._pw.js_init("\n".join(script_code))
def mk_plot(self, bokeh_doc, encoder=0, selection=[], locator=None, has_tok_emb=True, plot_width=1200): def mk_plot(self, bokeh_doc, selection=[], locator=None, plot_width=1200):
encoder_names = sorted(self.encoders.keys()) has_tok_emb = self._embedder.option.token_embedding is not None
if isinstance(encoder, str):
encoder = find_index_by_filter(encoder_names, encoder)
if _display_mode.bokeh: if _display_mode.bokeh:
embedding_select = bokeh.models.Select(
title="",
value=encoder_names[encoder],
options=encoder_names,
margin=(0, 20, 0, 0))
intruder_select = bokeh.models.Select( intruder_select = bokeh.models.Select(
title="", title="",
value=self._gold.patterns[0].phrase, value=self._gold.patterns[0].phrase,
...@@ -448,10 +572,6 @@ class EmbeddingPlotter: ...@@ -448,10 +572,6 @@ class EmbeddingPlotter:
options_cb = bokeh.models.CheckboxButtonGroup( options_cb = bokeh.models.CheckboxButtonGroup(
labels=["legend"], active=[0]) labels=["legend"], active=[0])
else: else:
embedding_select = widgets.Dropdown(
value=encoder_names[encoder],
options=encoder_names)
intruder_select = widgets.Dropdown( intruder_select = widgets.Dropdown(
value=[p.phrase for p in self._gold.patterns][0], value=[p.phrase for p in self._gold.patterns][0],
options=[p.phrase for p in self._gold.patterns]) options=[p.phrase for p in self._gold.patterns])
...@@ -465,8 +585,7 @@ class EmbeddingPlotter: ...@@ -465,8 +585,7 @@ class EmbeddingPlotter:
query_tabs.set_title(1, "fixed locator") query_tabs.set_title(1, "fixed locator")
query_tabs.set_title(2, "free locator") query_tabs.set_title(2, "free locator")
source = dict((k, bokeh.models.ColumnDataSource(v)) for k, v in self._compute_source_data( source = dict((k, bokeh.models.ColumnDataSource(v)) for k, v in self._compute_source_data("").items())
embedding_select.value, "").items())
cmap = bokeh.transform.factor_cmap( cmap = bokeh.transform.factor_cmap(
'query', 'query',
...@@ -561,26 +680,28 @@ class EmbeddingPlotter: ...@@ -561,26 +680,28 @@ class EmbeddingPlotter:
set_tok_emb_status("") set_tok_emb_status("")
def update_token_plot(max_token_count=750): def update_token_plot(max_token_count=750):
selected = source['docs'].selected.indices
self._current_selection = [self._docs[i].doc.unique_id for i in selected]
if tok_emb_p is None: if tok_emb_p is None:
return return
embedding = self.encoders[embedding_select.value].embedding embedding = self._embedder.encoder.embedding
if embedding is None: if embedding is None:
clear_token_plot() clear_token_plot()
set_tok_emb_status("No token embedding.")
return return
selected = source['docs'].selected.indices
if not selected: if not selected:
clear_token_plot() clear_token_plot()
set_tok_emb_status("No selection.")
return return
self._current_selection = [self._docs[i].doc.unique_id for i in selected]
token_embedding_data = [] token_embedding_data = []
for i in selected: for i in selected:
doc_data = self._docs[i] doc_data = self._docs[i]
for span in doc_data.doc.spans(self._partition): for span in doc_data.doc.spans(self._embedder.partition):
texts = [token.text for token in span] texts = [token.text for token in span]
for i, token in enumerate(span): for i, token in enumerate(span):
token_embedding_data.append({ token_embedding_data.append({
...@@ -599,7 +720,7 @@ class EmbeddingPlotter: ...@@ -599,7 +720,7 @@ class EmbeddingPlotter:
set_tok_emb_status("Selection is too large.<br>Please select fewer documents.") set_tok_emb_status("Selection is too large.<br>Please select fewer documents.")
return return
token_embedding_vecs = np.array(self._session.word_vec( token_embedding_vecs = np.array(self.session.word_vec(
embedding, [x['token'] for x in token_embedding_data])) embedding, [x['token'] for x in token_embedding_data]))
mag = np.linalg.norm(token_embedding_vecs, axis=1) mag = np.linalg.norm(token_embedding_vecs, axis=1)
...@@ -655,13 +776,18 @@ class EmbeddingPlotter: ...@@ -655,13 +776,18 @@ class EmbeddingPlotter:
intruder = "" intruder = ""
else: else:
intruder = [intruder_select, intruder_free][active - 1].value intruder = [intruder_select, intruder_free][active - 1].value
for k, v in self._compute_source_data( for k, v in self._compute_source_data(intruder).items():
embedding_select.value, intruder).items():
source[k].data = v source[k].data = v
update_token_plot() #update_token_plot()
if not _display_mode.fully_interactive: if not _display_mode.fully_interactive:
self._update_figures_html() self._update_figures_html()
def encoder_changed():
update_document_embedding_plot()
if selection:
id_to_index = dict((doc_data.doc.unique_id, i) for i, doc_data in enumerate(self._docs))
source['docs'].selected.indices = [id_to_index[x] for x in selection]
def clear_token_plot(): def clear_token_plot():
if tok_emb_p is None: if tok_emb_p is None:
...@@ -682,7 +808,8 @@ class EmbeddingPlotter: ...@@ -682,7 +808,8 @@ class EmbeddingPlotter:
update_document_embedding_plot() update_document_embedding_plot()
if _display_mode.static: if _display_mode.static:
embedding_select.disabled = True # self._embedder.disabled = True
#embedding_select.disabled = True
intruder_select.disabled = True intruder_select.disabled = True
intruder_free.disabled = True intruder_free.disabled = True
query_tabs.disabled = True query_tabs.disabled = True
...@@ -690,7 +817,8 @@ class EmbeddingPlotter: ...@@ -690,7 +817,8 @@ class EmbeddingPlotter:
# x.disabled = True # x.disabled = True
options_cb.disabled = True options_cb.disabled = True
else: else:
embedding_select.on_change("value", update_document_embedding_plot_shim) self._embedder.on_change(encoder_changed)
#embedding_select.on_change("value", update_document_embedding_plot_shim)
intruder_select.on_change("value", update_document_embedding_plot_shim) intruder_select.on_change("value", update_document_embedding_plot_shim)
intruder_free.on_change("value", update_document_embedding_plot_shim) intruder_free.on_change("value", update_document_embedding_plot_shim)
query_tabs.on_change("active", update_document_embedding_plot_shim) query_tabs.on_change("active", update_document_embedding_plot_shim)
...@@ -700,7 +828,8 @@ class EmbeddingPlotter: ...@@ -700,7 +828,8 @@ class EmbeddingPlotter:
def update_document_embedding_plot_shim(changed): def update_document_embedding_plot_shim(changed):
update_document_embedding_plot() update_document_embedding_plot()
embedding_select.observe(update_document_embedding_plot_shim, names="value") self._embedder.on_change(encoder_changed)
#embedding_select.observe(update_document_embedding_plot_shim, names="value")
intruder_select.observe(update_document_embedding_plot_shim, names="value") intruder_select.observe(update_document_embedding_plot_shim, names="value")
intruder_free.observe(update_document_embedding_plot_shim, names="value") intruder_free.observe(update_document_embedding_plot_shim, names="value")
query_tabs.observe(update_document_embedding_plot_shim, names="selected_index") query_tabs.observe(update_document_embedding_plot_shim, names="selected_index")
...@@ -787,7 +916,8 @@ class EmbeddingPlotter: ...@@ -787,7 +916,8 @@ class EmbeddingPlotter:
figure_widget = doc_emb_p figure_widget = doc_emb_p
return bokeh.layouts.column( return bokeh.layouts.column(
bokeh.layouts.column(embedding_select, query_tabs, background="#F0F0F0"), self._embedder.widget,
bokeh.layouts.column(query_tabs, background="#F0F0F0"),
figure_widget, figure_widget,
options_cb, options_cb,
sizing_mode="stretch_width") sizing_mode="stretch_width")
...@@ -814,7 +944,8 @@ class EmbeddingPlotter: ...@@ -814,7 +944,8 @@ class EmbeddingPlotter:
display(self._pw) display(self._pw)
root_widgets = [ root_widgets = [
widgets.VBox([embedding_select, query_tabs]), self._embedder.widget,
query_tabs,
figure_widget figure_widget
] ]
...@@ -826,19 +957,22 @@ class EmbeddingPlotter: ...@@ -826,19 +957,22 @@ class EmbeddingPlotter:
def plot_doc_embeddings(session, nlp, gold, plot_args, aggregator=np.mean, extra_encoders={}): def plot_doc_embeddings(embedder_factory, gold, plot_args):
plotters = [] plotters = []
for args in plot_args: for args in plot_args:
plotter = EmbeddingPlotter(session, nlp, gold, aggregator) plotter = EmbeddingPlotter(
for k, v in extra_encoders.items(): embedder_factory.create(args.get("encoder")),
plotter.encoders[k] = v gold)
plotters.append(plotter) plotters.append(plotter)
if _display_mode.bokeh: if _display_mode.bokeh:
def mk_root(bokeh_doc): def mk_root(bokeh_doc):
widgets = [] widgets = []
for plotter, kwargs in zip(plotters, plot_args): for plotter, kwargs in zip(plotters, plot_args):
kwargs = kwargs.copy()
if "encoder" in kwargs:
del kwargs["encoder"]
widgets.append(plotter.mk_plot(bokeh_doc, **kwargs)) widgets.append(plotter.mk_plot(bokeh_doc, **kwargs))
return bokeh.layouts.row(widgets) return bokeh.layouts.row(widgets)
...@@ -861,19 +995,15 @@ def plot_doc_embeddings(session, nlp, gold, plot_args, aggregator=np.mean, extra ...@@ -861,19 +995,15 @@ def plot_doc_embeddings(session, nlp, gold, plot_args, aggregator=np.mean, extra
class DocEmbeddingExplorer:
    """Convenience wrapper: builds a DocEmbedderFactory and plots with it.

    All positional and keyword arguments except `gold` are forwarded
    verbatim to DocEmbedderFactory.
    """

    def __init__(self, *args, gold, **kwargs):
        self._embedder_factory = DocEmbedderFactory(*args, **kwargs)
        self._gold = gold

    def plot(self, args):
        """Render document-embedding plots for the given list of plot specs."""
        factory = self._embedder_factory
        return plot_doc_embeddings(factory, self._gold, args)
......
...@@ -74,11 +74,11 @@ ...@@ -74,11 +74,11 @@
"id": "77676262-9b46-430d-b575-636e8a50e20d", "id": "77676262-9b46-430d-b575-636e8a50e20d",
"metadata": {}, "metadata": {},
"source": [ "source": [
"There are now various established ways to compute embeddings for similarity tasks. A first important distinction is between *token* embeddings and *document* embeddings (see diagram below) - note that we use the terms \"token embeddings\" and \"word embeddings\" interchangeably. While the former imply one embedding (i.e. numeric vector) per token, the latter operate by mapping a whole document (a set of tokens) into one single embedding.\n", "There are now various established ways to compute embeddings for word similarity tasks. A first important distinction is between *token* embeddings and *document* embeddings (see diagram below) - note that we use the terms \"token embeddings\" and \"word embeddings\" interchangeably. While the former imply one embedding (i.e. numeric vector) per token, the latter operate by mapping a whole document (a set of tokens) into one single embedding.\n",
"\n", "\n",
"There are two common ways to compute document embeddings. One way is to derive them from token embeddings - for example by averaging over them. More complex approaches train dedicated models that are optimized to produce good document embeddings.\n", "There are two common ways to compute document embeddings. One way is to derive them from token embeddings - for example by averaging over them. More complex approaches train dedicated models that are optimized to produce good document embeddings.\n",
"\n", "\n",
"So on this level, we can differentiate between three kinds of embeddings: pure token embeddings, document embeddings derived from token embeddings, and - finally - document embeddings from dedicated document embedding models (e.g. SBERT).\n", "So on this level, we can differentiate between three kinds of embeddings: pure token embeddings, document embeddings derived from token embeddings, and - finally - document embeddings from dedicated document embedding models - e.g. models like Sentence-BERT (Reimers and Gurevych, 2019).\n",
"\n", "\n",
"![Different kinds of embeddings](miscellaneous/diagram_embeddings_1.svg)\n" "![Different kinds of embeddings](miscellaneous/diagram_embeddings_1.svg)\n"
] ]
...@@ -336,7 +336,7 @@ ...@@ -336,7 +336,7 @@
"source": [ "source": [
"from vectorian.embeddings import SentenceBertEmbedding\n", "from vectorian.embeddings import SentenceBertEmbedding\n",
"\n", "\n",
"the_embeddings['sbert'] = SentenceBertEmbedding(nlp)" "the_embeddings['sbert'] = SentenceBertEmbedding(nlp, 768)"
] ]
}, },
{ {
...@@ -736,7 +736,7 @@ ...@@ -736,7 +736,7 @@
"source": [ "source": [
"Before we turn to alignment strategies to match sentences token by token, we first look at representing each document with one single embedding in order to gather an understanding how different embedding strategies relate to the nearness of documents. We will later return to individual token embeddings.\n", "Before we turn to alignment strategies to match sentences token by token, we first look at representing each document with one single embedding in order to gather an understanding how different embedding strategies relate to the nearness of documents. We will later return to individual token embeddings.\n",
"\n", "\n",
"We will use two strategies for computing document embeddings:\n", "We will use the two strategies for computing document embeddings we mentioned earlier:\n",
"\n", "\n",
"* averaging over token embeddings\n", "* averaging over token embeddings\n",
"* computing document embeddings through a dedicated model" "* computing document embeddings through a dedicated model"
...@@ -774,37 +774,34 @@ ...@@ -774,37 +774,34 @@
"id": "12d93851-2895-422d-92d0-586aa540d480", "id": "12d93851-2895-422d-92d0-586aa540d480",
"metadata": {}, "metadata": {},
"source": [ "source": [
"In order to achieve the former, we configure a helper class instance to use averaging to build document embeddings from token embeddings.\n", "In order to achieve the former, we configure a helper class instance to use averaging to build document embeddings from token embeddings. Interactive readers may want to try changing the \"mean\" (i.e. averaging) method to other methods for computing document tokens as well."
"\n",
"Interactive readers may want to change the \"mean\" (averaging) method to other methods for computing document tokens as well."
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "2f6fb8cc-5328-483c-becd-9a86fb091192", "id": "5ddd998c-2a3a-4110-b5ef-39303bf3c62d",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"doc_embedding_explorer = nbutils.DocEmbeddingExplorer(\n", "import os\n",
" session=session, nlp=nlp, gold=gold_data, extra_encoders={sbert_encoder_name: sbert_encoder})" "\n",
"import importlib\n",
"importlib.reload(nbutils)\n",
"importlib.reload(gold)\n",
"\n",
"nbutils.initialize(\"auto\")"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "d271e325-7408-4ff1-a274-89f34251ff3b", "id": "2f6fb8cc-5328-483c-becd-9a86fb091192",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"@interact(embedding=widgets.Dropdown(\n", "doc_embedding_explorer = nbutils.DocEmbeddingExplorer(\n",
" options=[(k, v) for k, v in the_embeddings.items() if not v.is_contextual],\n", " session=session, nlp=nlp, gold=gold_data, doc_encoders={sbert_encoder_name: sbert_encoder})"
" value=the_embeddings[\"numberbatch\"]), normalize=False)\n",
"def plot(embedding, normalize):\n",
" nbutils.plot_embedding_vectors_val(\n",
" [\"sail\", \"boat\", \"coffee\", \"tea\", \"guitar\", \"piano\"],\n",
" get_vec=lambda w: session.word_vec(embedding, w),\n",
" normalize=normalize)"
] ]
}, },
{ {
...@@ -815,8 +812,8 @@ ...@@ -815,8 +812,8 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"doc_embedding_explorer.plot([\n", "doc_embedding_explorer.plot([\n",
" {\"encoder\": \"paraphrase_distilroberta\", \"locator\": (\"fixed\", \"carry coals\"), 'has_tok_emb': False},\n", " {\"encoder\": \"paraphrase_distilroberta\", \"locator\": (\"fixed\", \"carry coals\")},\n",
" {\"encoder\": \"paraphrase_distilroberta\", \"locator\": (\"fixed\", \"an old man is twice\"), 'has_tok_emb': False}\n", " {\"encoder\": \"paraphrase_distilroberta\", \"locator\": (\"fixed\", \"an old man is twice\")}\n",
"]);" "]);"
] ]
}, },
...@@ -903,7 +900,7 @@ ...@@ -903,7 +900,7 @@
"id": "60da85ca-b472-4e59-846d-0c120454807a", "id": "60da85ca-b472-4e59-846d-0c120454807a",
"metadata": {}, "metadata": {},
"source": [ "source": [
"A different approach to compute a measure of similarity between bag of words is the so-called Word Mover's Distance introduced by Kusner et al. (Kusner et al., 2015)." "A different approach to compute a measure of similarity between bag of words is the so-called Word Mover's Distance introduced by Kusner et al. (Kusner et al., 2015). The main idea is computing similarity through finding the optimal solution of a transportation problem between words."
] ]
}, },
{ {
...@@ -982,6 +979,14 @@ ...@@ -982,6 +979,14 @@
"We first define a strategy for searching the corpus. In the summary below you will find the strategy used for the non-interactive version of this text. In the interactive version, you can click on \"Edit\" and change these settings and rerun the following sections of the notebook accordingly." "We first define a strategy for searching the corpus. In the summary below you will find the strategy used for the non-interactive version of this text. In the interactive version, you can click on \"Edit\" and change these settings and rerun the following sections of the notebook accordingly."
] ]
}, },
{
"cell_type": "markdown",
"id": "40a082f5-99cf-4b21-8d65-d824f2199e0f",
"metadata": {},
"source": [
"We investigate two variants of WMD. First the classic variant as described by Kusner et al., where a transportation problem is solved over the normalized bag of words (nbow) vector. We also introduce a new variant of WMD, where we keep the bag of words (bow) unnormalized - i.e. we pose the transportation problem on absolute word occurrence counts."
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
...@@ -1171,7 +1176,9 @@ ...@@ -1171,7 +1176,9 @@
"id": "c637b6d2-6271-48ed-9375-41b62c2874d9", "id": "c637b6d2-6271-48ed-9375-41b62c2874d9",
"metadata": {}, "metadata": {},
"source": [ "source": [
"The distributions of score contributions we just observed are the motivation for our approach to tag-weighted alignments, that are described in (Liebl and Burghardt, 2020). We demonstrate it now, by using a tag-weighted alignment that will weight nouns like \"madness\" and \"method\" 3 times more than other word types. Let's set it up (\"NN\" is a Penn Treebank tag and identifies singular nouns):" "The distributions of score contributions we just observed are the motivation for our approach to tag-weighted alignments, that are described in (Liebl and Burghardt, 2020). Nagoudi and Schwab used similar ideas of POS weighting for computing sentence similarity, but did not combine it with alignments (Nagoudi and Schwab, 2017).\n",
"\n",
"We now demonstrate tag-weighted alignments, by using a tag-weighted alignment that will weight nouns like \"madness\" and \"method\" 3 times more than other word types. \"NN\" is a Penn Treebank tag and identifies singular nouns."
] ]
}, },
{ {
...@@ -1204,11 +1211,7 @@ ...@@ -1204,11 +1211,7 @@
"id": "ca966318-3d37-464c-a73d-fd3d153dcabb", "id": "ca966318-3d37-464c-a73d-fd3d153dcabb",
"metadata": {}, "metadata": {},
"source": [ "source": [
"This tag-weighting allows to fix move the correct results far to the top, namely to ranks 1, 2, 4 and 6.\n", "Tag-weighting moves the correct results far to the top, namely to ranks 1, 2, 4 and 6. By increasing the NN weight to 5, it is possible to bring rank 73 to rank 15. This is sort of an extreme measure though and we will not investigate it further here. Instead we investigate how the weighting affects the other queries. Therefore, we re-run the NDCG computation and compare it against unweighted WSB."
"\n",
"Note that we can bring rank 73 to rank 15 by increasing the NN weight to 5. But this is sort of an extreme measure and we will not follow it here.\n",
"\n",
"Instead we wonder: how will the weighting affect the other queries? Let's re-run the NDCG computation and compare it against unweighted WSB."
] ]
}, },
{ {
...@@ -1248,7 +1251,7 @@ ...@@ -1248,7 +1251,7 @@
"id": "be0ea111-9c4a-4ec1-8dee-6a5b1a5ef619", "id": "be0ea111-9c4a-4ec1-8dee-6a5b1a5ef619",
"metadata": {}, "metadata": {},
"source": [ "source": [
"# The Influence of Embeddings" "# The Influence of Different Embeddings"
] ]
}, },
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment