I'm trying to use a Huggingface pretrained model "GPT2dialog" as a encoder for sentences,But the textindexer confused me.
In detail ,I can run a unittest for dataset_reader with a pretrained indexer normally,when use the train command to train the model caused a Bug:
File "/home/lee/anaconda3/envs/allennlp/lib/python3.6/site-packages/allennlp/common/lazy.py", line 54, in constructor_to_use
return constructor.from_params(Params({}), **kwargs) # type: ignore[union-attr]
File "/home/lee/anaconda3/envs/allennlp/lib/python3.6/site-packages/allennlp/common/from_params.py", line 604, in from_params
**extras,
File "/home/lee/anaconda3/envs/allennlp/lib/python3.6/site-packages/allennlp/common/from_params.py", line 634, in from_params
return constructor_to_call(**kwargs) # type: ignore
File "/home/lee/anaconda3/envs/allennlp/lib/python3.6/site-packages/allennlp/data/vocabulary.py", line 310, in from_instances
instance.count_vocab_items(namespace_token_counts)
File "/home/lee/anaconda3/envs/allennlp/lib/python3.6/site-packages/allennlp/data/instance.py", line 60, in count_vocab_items
field.count_vocab_items(counter)
File "/home/lee/anaconda3/envs/allennlp/lib/python3.6/site-packages/allennlp/data/fields/text_field.py", line 78, in count_vocab_items
for indexer in self.token_indexers.values():
AttributeError: 'PretrainedTransformerIndexer' object has no attribute 'values'
Here is my dataset_reader code.
class MultiWozDatasetReader(DatasetReader):
def __init__(self,
lazy:bool = False,
tokenizer: Tokenizer = None,
tokenindexer:Dict[str, TokenIndexer] = None
) -> None:
super().__init__(lazy)
self._tokenizer = tokenizer or WhitespaceTokenizer()
self._tokenindexer = PretrainedTransformerIndexer("microsoft/DialoGPT-small")
#overrides
def read(self, file_path: str):
logger.warn("call read")
with open(file_path, 'r') as data_file:
dialogs = json.load(data_file)
for dialog in dialogs:
dialogue = dialog["dialogue"]
for turn_num in range(len(dialogue)):
dia_single_turn = dialogue[turn_num]
sys_utt = dia_single_turn["system_transcript"]
user_utt = dia_single_turn["transcript"]
state_category = dia_single_turn["state_category"]
span_info = dia_single_turn["span"]
yield self.text_to_instance(sys_utt, user_utt, state_category, span_info)
#overrides
def text_to_instance(self, sys_utt, user_utt, state_catgory, span_info):
tokenized_sys_utt = self._tokenizer.tokenize(sys_utt)
tokenized_user_utt = self._tokenizer.tokenize(user_utt)
tokenized_span_info = self._tokenizer.tokenize(span_info)
tokenized_classifier_input = self._tokenizer.tokenize("[CLS] "+ sys_utt + " [SEP] "+ user_utt)
sys_utt_field = TextField(tokenized_sys_utt, self._tokenindexer)
user_utt_field = TextField(tokenized_user_utt, self._tokenindexer)
classifier_filed = TextField(tokenized_classifier_input, self._tokenindexer)
span_field = TextField(tokenized_span_info, self._tokenindexer)
fields = {"sys_utt": sys_utt_field,"user_utt":user_utt_field,"classifier_input":classifier_filed,"span":span_field}
fields['label']=LabelField(state_catgory)
return Instance(fields)
I am searching for a long time on net. But no use. Please help or try to give some ideas how to achieve this.
The token_indexer needs to be a dictionary. It can be set as follows:
self._token_indexers = {"tokens": PretrainedTransformerIndexer("microsoft/DialoGPT-small")}
Related
I have a project that searches PDFs for URLs and in the process extracts the PDF Metadata. It works perfectly around 99.6% of the time without any errors. But every once in a while, a file throws the old "invalid token error. Traceback Below:
Traceback (most recent call last):
File "c:\python38\lib\runpy.py", line 193, in _run_module_as_main
return run_code(code, main_globals, None,
File "c:\python38\lib\runpy.py", line 86, in run_code
exec(code, run_globals)
File "C:\Python38\Scripts\linkrot.exe_main.py", line 7, in
File "c:\python38\lib\site-packages\linkrot\cli.py", line 182, in main
pdf = linkrot.linkrot(args.pdf)
File "c:\python38\lib\site-packages\linkrot_init.py", line 131, in init
self.reader = PDFMinerBackend(self.stream)
File "c:\python38\lib\site-packages\linkrot\backends.py", line 213, in init
self.metadata.update(xmp_to_dict(metadata))
File "c:\python38\lib\site-packages\linkrot\libs\xmp.py", line 92, in xmp_to_dict
return XmpParser(xmp).meta
File "c:\python38\lib\site-packages\linkrot\libs\xmp.py", line 41, in init
self.tree = ET.XML(xmp)
File "c:\python38\lib\xml\etree\ElementTree.py", line 1320, in XML
parser.feed(text)
xml.etree.ElementTree.ParseError: not well-formed (invalid token): line 55, column 10
My assumption is that there is some sort of issue with the XML extracted from the PDF, but I can't be sure. Is there a workaround? Some way the rest of the program could run when this error throws? The metadata is valuable to the process so I'd like to keep it if possible. I don't know etree that well, so I'd appreciate some help. The Code itself is below:
class XmpParser(object):
"""
Parses an XMP string into a dictionary.
Usage:
parser = XmpParser(xmpstring)
meta = parser.meta
"""
def __init__(self, xmp):
self.tree = ET.XML(xmp)
self.rdftree = self.tree.find(RDF_NS + "RDF")
#property
def meta(self):
""" A dictionary of all the parsed metadata. """
meta = defaultdict(dict)
if self.rdftree:
for desc in self.rdftree.findall(RDF_NS + "Description"):
for (
el
) in (
desc.iter()
):
ns, tag = self._parse_tag(el)
value = self._parse_value(el)
meta[ns][tag] = value
return dict(meta)
def _parse_tag(self, el):
""" Extract the namespace and tag from an element. """
ns = None
tag = el.tag
if tag[0] == "{":
ns, tag = tag[1:].split("}", 1)
if ns in NS_MAP:
ns = NS_MAP[ns]
return ns, tag
def _parse_value(self, el): # noqa: C901
""" Extract the metadata value from an element. """
if el.find(RDF_NS + "Bag") is not None:
value = []
for li in el.findall(RDF_NS + "Bag/" + RDF_NS + "li"):
value.append(li.text)
elif el.find(RDF_NS + "Seq") is not None:
value = []
for li in el.findall(RDF_NS + "Seq/" + RDF_NS + "li"):
value.append(li.text)
elif el.find(RDF_NS + "Alt") is not None:
value = {}
for li in el.findall(RDF_NS + "Alt/" + RDF_NS + "li"):
value[li.get(XML_NS + "lang")] = li.text
else:
value = el.text
return value
Any help or advice would be appreciated.
I am trying to pull a huge amount of data (in millions) and I am getting the following error when running my code. If I run the same code with a small range (to be exact a range of 2) it runs successfully. Please assist in helping me know if this is my issue or is coming from the API side
Thanks
The Error I am getting
DEBUG:google.api_core.bidi:Started helper thread Thread-ConsumeBidirectionalStream
DEBUG:google.api_core.bidi:Thread-ConsumeBidirectionalStream caught error 400 Request contains an invalid argument. and will exit. Generally this is due to the RPC itself being cancelled and the error will be surfaced to the calling code.
Traceback (most recent call last):
File "/home/coyugi/teltel_env/lib/python3.8/site-packages/google/api_core/grpc_helpers.py", line 147, in error_remapped_callable
return _StreamingResponseIterator(
File "/home/coyugi/teltel_env/lib/python3.8/site-packages/google/api_core/grpc_helpers.py", line 73, in __init__
self._stored_first_result = next(self._wrapped)
File "/home/coyugi/teltel_env/lib/python3.8/site-packages/grpc/_channel.py", line 426, in __next__
return self._next()
File "/home/coyugi/teltel_env/lib/python3.8/site-packages/grpc/_channel.py", line 826, in _next
raise self
grpc._channel._MultiThreadedRendezvous: <_MultiThreadedRendezvous of RPC that terminated with:
status = StatusCode.INVALID_ARGUMENT
details = "Request contains an invalid argument."
debug_error_string = "{"created":"#1652904360.179503883","description":"Error received from peer ipv4:173.194.76.95:443","file":"src/core/lib/surface/call.cc","file_line":952,"grpc_message":"Request contains an invalid argument.","grpc_status":3}"
>
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/coyugi/teltel_env/lib/python3.8/site-packages/google/api_core/bidi.py", line 636, in _thread_main
self._bidi_rpc.open()
File "/home/coyugi/teltel_env/lib/python3.8/site-packages/google/api_core/bidi.py", line 279, in open
call = self._start_rpc(iter(request_generator), metadata=self._rpc_metadata)
File "/home/coyugi/teltel_env/lib/python3.8/site-packages/google/cloud/bigquery_storage_v1/services/big_query_write/client.py", line 678, in append_rows
response = rpc(
File "/home/coyugi/teltel_env/lib/python3.8/site-packages/google/api_core/gapic_v1/method.py", line 154, in __call__
return wrapped_func(*args, **kwargs)
File "/home/coyugi/teltel_env/lib/python3.8/site-packages/google/api_core/retry.py", line 283, in retry_wrapped_func
return retry_target(
File "/home/coyugi/teltel_env/lib/python3.8/site-packages/google/api_core/retry.py", line 190, in retry_target
return target()
File "/home/coyugi/teltel_env/lib/python3.8/site-packages/google/api_core/grpc_helpers.py", line 151, in error_remapped_callable
raise exceptions.from_grpc_error(exc) from exc
google.api_core.exceptions.InvalidArgument: 400 Request contains an invalid argument.
INFO:google.api_core.bidi:Thread-ConsumeBidirectionalStream exiting
DEBUG:google.cloud.bigquery_storage_v1.writer:Finished stopping manager.
Traceback (most recent call last):
File "write_data_to_db2.py", line 207, in <module>
p.append_rows_pending(project_id='dwingestion', dataset_id='ke',
File "write_data_to_db2.py", line 188, in append_rows_pending
response_future_1 = append_rows_stream.send(request)
File "/home/coyugi/teltel_env/lib/python3.8/site-packages/google/cloud/bigquery_storage_v1/writer.py", line 234, in send
return self._open(request)
File "/home/coyugi/teltel_env/lib/python3.8/site-packages/google/cloud/bigquery_storage_v1/writer.py", line 207, in _open
raise request_exception
google.api_core.exceptions.Unknown: None There was a problem opening the stream. Try turning on DEBUG level logs to see the error.
Summary Of My Code
# PULLING DATA FROM THE API
def whole_teltel_raw_data():
# Creating a session to introduce network consistency
session = requests.Session()
retry = Retry(connect=3, backoff_factor=1.0)
adapter = HTTPAdapter(max_retries=retry)
session.mount('http://', adapter)
session.mount('https://', adapter)
url = "https://my_api_url"
the_headers = {"X-API-KEY": 'my key'}
offset_limit = 1249500
teltel_data = []
# Loop through the results and if present extend the teltel_data list
#======================================================================================================================
# WRITE THE DATA TO THE DATA WAREHOUSE
# ======================================================================================================================
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'dwingestion-b033d9535e9d.json'
def create_row_data(tuple_data):
call_id, starttime, stoptime, direction, type, status, duration_sec, rate, cost, transfer, extra_prefix, audio_url, \
hangup_element, caller_number, caller_type, caller_cid, caller_dnid, caller_user_id, caller_user_short, \
callee_number, calle_type, callee, hangup_element_name, hangup_element_element, callee_user_id, callee_user_short, \
caller = tuple_data
row = teltel_call_data_pb2.TeltelCall()
row.call_id = call_id
row.starttime = starttime
row.stoptime = stoptime
row.direction = direction
row.type = type
row.status = status
row.duration_sec = duration_sec
row.rate = rate
row.cost = cost
row.transfer = transfer
row.extra_prefix = extra_prefix
row.audio_url = audio_url
row.hangup_element = hangup_element
row.caller_number = caller_number
row.caller_type = caller_type
row.caller_cid = caller_cid
row.caller_dnid = caller_dnid
row.caller_user_id = caller_user_id
row.caller_user_short = caller_user_short
row.callee_number = callee_number
row.calle_type = calle_type
row.callee = callee
row.hangup_element_name = hangup_element_name
row.hangup_element_title = hangup_element_element
row.callee_user_id = callee_user_id
row.callee_user_short = callee_user_short
row.caller = caller
return row.SerializeToString()
# Creating connection to the data warehouse
def create_bigquery_storage_client(google_credentials):
return bigquery_storage_v1.client.BigQueryWriteClient(
credentials=google_credentials
)
class GcpBigqueryStorageService(object):
def __init__(self, google_credentials=None, gcp_config=None):
self.client = create_bigquery_storage_client(google_credentials)
self.config = gcp_config
def append_rows_pending(self, project_id: str, dataset_id: str, table_id: str):
"""Create a write stream, write some sample data, and commit the stream."""
# write_client = self.client
parent = self.client.table_path(project_id, dataset_id, table_id)
write_stream = types.WriteStream()
# When creating the stream, choose the type. Use the PENDING type to wait
write_stream.type_ = types.WriteStream.Type.PENDING
write_stream = self.client.create_write_stream(
parent=parent, write_stream=write_stream
)
stream_name = write_stream.name
# Create a template with fields needed for the first request.
request_template = types.AppendRowsRequest()
# The initial request must contain the stream name.
request_template.write_stream = stream_name
# So that BigQuery knows how to parse the serialized_rows, generate a
# protocol buffer representation of your message descriptor.
proto_schema = types.ProtoSchema()
proto_descriptor = descriptor_pb2.DescriptorProto()
teltel_call_data_pb2.TeltelCall.DESCRIPTOR.CopyToProto(proto_descriptor)
proto_schema.proto_descriptor = proto_descriptor
proto_data = types.AppendRowsRequest.ProtoData()
proto_data.writer_schema = proto_schema
request_template.proto_rows = proto_data
# Some stream types support an unbounded number of requests. Construct an
# AppendRowsStream to send an arbitrary number of requests to a stream.
append_rows_stream = writer.AppendRowsStream(self.client, request_template)
# Create a batch of row data by appending proto2 serialized bytes to the
# serialized_rows repeated field.
proto_rows = types.ProtoRows()
row_number = 0
for row in whole_teltel_raw_data():
proto_rows.serialized_rows.append(create_row_data(row))
# checking the writing progress
row_number = row_number + 1
print("Writing to the database row number", row_number)
# The first request must always have an offset of 0.
request = types.AppendRowsRequest()
proto_data = types.AppendRowsRequest.ProtoData()
proto_data.rows = proto_rows
request.proto_rows = proto_data
append_rows_stream.close()
# A PENDING type stream must be "finalized" before being committed. No new
# records can be written to the stream after this method has been called.
self.client.finalize_write_stream(name=write_stream.name)
# Commit the stream you created earlier.
batch_commit_write_streams_request = types.BatchCommitWriteStreamsRequest()
batch_commit_write_streams_request.parent = parent
batch_commit_write_streams_request.write_streams = [write_stream.name]
self.client.batch_commit_write_streams(batch_commit_write_streams_request)
print(f"Writes to stream: '{write_stream.name}' have been committed.")
p = GcpBigqueryStorageService()
p.append_rows_pending(project_id='my_project', dataset_id='my_id', table_id='teltel_call_2')
I'm following a YT tutorial, and I feel I've copied the code exactly, but keep getting this error:
AttributeError: Caught AttributeError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "C:\Users\Connor\Anaconda3\lib\site-packages\torch\utils\data\_utils\worker.py", line 287, in _worker_loop
data = fetcher.fetch(index)
File "C:\Users\Connor\Anaconda3\lib\site-packages\torch\utils\data\_utils\fetch.py", line 44, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "C:\Users\Connor\Anaconda3\lib\site-packages\torch\utils\data\_utils\fetch.py", line 44, in <listcomp>
data = [self.dataset[idx] for idx in possibly_batched_index]
File "C:\Users\Connor\OneDrive\Python\Kaggle\Facial Keypoint Detecetion\dataset.py", line 21, in __getitem__
image = np.array(self.data.iloc[index, 30].split()).astype(np.float32)
AttributeError: 'numpy.float64' object has no attribute 'split'
Below is the full code:
class FacialKeypointDataset(Dataset):
def __init__(self, csv_file, train=True, transform=None):
super().__init__()
self.data = pd.read_csv(csv_file)
self.category_names = ['left_eye_center_x', 'left_eye_center_y', 'right_eye_center_x', 'right_eye_center_y', 'left_eye_inner_corner_x', 'left_eye_inner_corner_y', 'left_eye_outer_corner_x', 'left_eye_outer_corner_y', 'right_eye_inner_corner_x', 'right_eye_inner_corner_y', 'right_eye_outer_corner_x', 'right_eye_outer_corner_y', 'left_eyebrow_inner_end_x', 'left_eyebrow_inner_end_y', 'left_eyebrow_outer_end_x', 'left_eyebrow_outer_end_y', 'right_eyebrow_inner_end_x', 'right_eyebrow_inner_end_y', 'right_eyebrow_outer_end_x', 'right_eyebrow_outer_end_y', 'nose_tip_x', 'nose_tip_y', 'mouth_left_corner_x', 'mouth_left_corner_y', 'mouth_right_corner_x', 'mouth_right_corner_y', 'mouth_center_top_lip_x', 'mouth_center_top_lip_y', 'mouth_center_bottom_lip_x', 'mouth_center_bottom_lip_y']
self.transform = transform
self.train = train
def __len__(self):
return self.data.shape[0]
def __getitem__(self, index):
if self.train:
image = np.array(self.data.iloc[index, 30].split()).astype(np.float32)
labels = np.array(self.data.iloc[index, :30].tolist())
labels[np.isnan(labels)] = -1
else:
image = np.array(self.data.iloc[index, 1].split()).astype(np.float32)
labels = np.zeros(30)
ignore_indices = labels == -1
labels = labels.reshape(15, 2)
if self.transform:
image = np.repeat(image.reshape(96, 96, 1), 3, 2).astype(np.uint8)
augmentations = self.transform(image=image, keypoints=labels)
image = augmentations["image"]
labels = augmentations["keypoints"]
labels = np.array(labels).reshape(-1)
labels[ignore_indices] = -1
return image, labels.astype(np.float32)
if __name__ == "__main__":
ds = FacialKeypointDataset(csv_file="data/train_4.csv", train=True, transform=config.train_transforms)
loader = DataLoader(ds, batch_size=1, shuffle=True, num_workers=0)
for idx, (x, y) in enumerate(loader):
plt.imshow(x[0][0].detach().cpu().numpy(), cmap='gray')
plt.plot(y[0][0::2].detach().cpu().numpy(), y[0][1::2].detach().cpu().numpy(), "go")
plt.show()
In the tutorial it has the same lines of code, but no error. Here is the link to the Gituhub:
https://github.com/aladdinpersson/Machine-Learning-Collection/tree/master/ML/Kaggles/Facial%20Keypoint%20Detection%20Competition
Any ideas what might be causing this?
The code seems to expect that the value returned by self.data.iloc[index, 30] will always be a string.
That might be ok for the project you are basing your code on, but if you pass a csv file that has floats instead of strings it will result on the error that you got.
Convert the data to float with .astype(np.float32) first. For example,
self.data = pd.read_csv(csv_file)
self.data = self.data.astype(np.float32)
Or
self.data = pd.read_csv(csv_file, dtype=np.float64)
If you got error, it means your csv_file has data type like strings, and this program cannot be used on your input data.
I only wrote a simple model in model.py, and when I ran it, it gave the following error.
2021-02-08 22:20:11.872409: E tensorflow/core/common_runtime/executor.cc:641] Executor failed to create kernel. Unimplemented: Cast string to int32 is not supported
[[{{node embedding/Cast}}]]
Traceback (most recent call last):
File "C:\Users\xiaoc\Anaconda3\lib\runpy.py", line 193, in _run_module_as_main
"main", mod_spec)
File "C:\Users\xiaoc\Anaconda3\lib\runpy.py", line 85, in _run_code
exec(code, run_globals)
File "C:\Users\xiaoc\AppData\Local\Google\Cloud SDK\trainer\task.py", line 55, in
train_model(args)
File "C:\Users\xiaoc\AppData\Local\Google\Cloud SDK\trainer\task.py", line 43, in train_model
validation_data=(eval_data, eval_labels))
File "C:\Users\xiaoc\Anaconda3\lib\site-packages\tensorflow\python\keras\engine\training.py", line 780, in fit
steps_name='steps_per_epoch')
File "C:\Users\xiaoc\Anaconda3\lib\site-packages\tensorflow\python\keras\engine\training_arrays.py", line 363, in model_iteration
batch_outs = f(ins_batch)
File "C:\Users\xiaoc\Anaconda3\lib\site-packages\tensorflow\python\keras\backend.py", line 3289, in call
self._make_callable(feed_arrays, feed_symbols, symbol_vals, session)
File "C:\Users\xiaoc\Anaconda3\lib\site-packages\tensorflow\python\keras\backend.py", line 3222, in _make_callable
callable_fn = session._make_callable_from_options(callable_opts)
File "C:\Users\xiaoc\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 1489, in _make_callable_from_options
return BaseSession._Callable(self, callable_options)
File "C:\Users\xiaoc\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 1446, in init
session._session, options_ptr)
tensorflow.python.framework.errors_impl.UnimplementedError: Cast string to int32 is not supported
[[{{node embedding/Cast}}]]
What is the problem? The requirement is to only make changes in model.py, not others. Thanks in advance!
Following are the three python files.
model.py
import tensorflow as tf
from tensorflow.keras.layers import Dense,Embedding,LSTM, Activation,Dropout
from tensorflow.keras import Model
def get_batch_size(): #size of training 8056 number of batches = 8056/128
return 128
def get_epochs():
return 50
def solution(input_layer):
max_len = 150
max_words = 200
# inputs = Input(name='inputs',shape=[max_len])
layer = Embedding(max_words,output_dim = 64, input_length=max_len)(input_layer)
# layer = LSTM(64,return_sequences=True)(input_layer)
layer = tf.expand_dims(layer, axis=-1)
layer = LSTM(64,return_sequences=True)(layer)
layer = Dense(256,)(layer)
layer = Activation('relu')(layer)
layer = Dropout(0.5)(layer)
layer = Dense(5)(layer)
# layer = Activation('softmax')(layer)
model = Model(inputs=input_layer,outputs=layer)
model.compile(loss='sparse_categorical_crossentropy',optimizer=tf.keras.optimizers.Adam(),metrics=['accuracy'])
return model
data.py
import csv
import numpy as np
label_map = {
0: 'A',
1: 'B',
2: 'C',
3: 'D',
4: 'E',
}
label_map_inv = dict(map(reversed, label_map.items()))
def load_dataset(dataset_file):
data = []
labels = []
with open(dataset_file, "r", encoding="utf-8") as f:
data_reader = csv.reader(f, delimiter=",", quotechar='"')
next(data_reader)
for lbl, desc in data_reader:
data.append(desc)
labels.append(label_map_inv[lbl])
return np.array(data), np.array(labels)
task.py
import os
import argparse
import logging
import numpy as np
import tensorflow as tf
import tensorflow.keras
import trainer.data as data
import trainer.model as model
def train_model(params):
(train_data, train_labels) = data.load_dataset("data/train.csv")
(eval_data, eval_labels) = data.load_dataset("data/eval.csv")
input_layer = tf.keras.Input(shape=(), name='input_text', dtype=tf.string)
ml_model = model.solution(input_layer)
if ml_model is None:
print("No model found. You need to implement one in model.py")
else:
ml_model.fit(train_data, train_labels,
batch_size=model.get_batch_size(),
epochs=model.get_epochs(),
validation_data=(eval_data, eval_labels))
_ = ml_model.evaluate(eval_data, eval_labels, verbose=1)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
args = parser.parse_args()
tf_logger = logging.getLogger("tensorflow")
tf_logger.setLevel(logging.INFO)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = str(tf_logger.level // 10)
train_model(args)
Environment:
Ubuntu 16.04;
TensorFlow v1.0.0 (CPU)
When attempting to import a saved graph using "tf.train.import_meta_graph('model.meta')," I get the following error:
Traceback (most recent call last):
File "test_load.py", line 19, in new_saver =
tf.train.import_meta_graph('model.meta')
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.py",
line 1577, in import_meta_graph **kwargs)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/meta_graph.py",
line 498, in import_scoped_meta_graph producer_op_list=producer_op_list)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/importer.py",
line 259, in import_graph_def
raise ValueError('No op named %s in defined operations.' % node.op)
ValueError: No op named attn_add_fun_f32f32f32 in defined operations.
This error isn't thrown when I retrain my model without attention and import the graph with the same line of code.
Is loading a model trained with attention not currently supported? Here is what my attention implementation looks like:
attention_states = tf.transpose(self.encoder_outputs, [1, 0, 2])
(attention_keys,
attention_values,
attention_score_fn,
attention_construct_fn) = seq2seq.prepare_attention(
attention_states = attention_states,
attention_option = "bahdanau",
num_units = self.decoder_cell.output_size)
decoder_fn_train = seq2seq.attention_decoder_fn_train(
encoder_state = self.encoder_state,
attention_keys = attention_keys,
attention_values = attention_values,
attention_score_fn = attention_score_fn,
attention_construct_fn = attention_construct_fn,
name = 'attention_decoder')
decoder_fn_inference = seq2seq.attention_decoder_fn_inference(
output_fn = output_fn,
encoder_state = self.encoder_state,
attention_keys = attention_keys,
attention_values = attention_values,
attention_score_fn = attention_score_fn,
attention_construct_fn = attention_construct_fn,
embeddings = self.embedding_matrix,
start_of_sequence_id = self.EOS,
end_of_sequence_id = self.EOS,
maximum_length = tf.reduce_max(self.encoder_inputs_length) + 3,
num_decoder_symbols = self.vocab_size,)
Thanks!