root/branches/debug/extra/tsp/src/tsp.c

Revision 128, 11.9 KB (checked in by anton, 3 years ago)

Debugging branch added

Line 
1/*
2 * Traveling Salesman Problem solution algorithm for PostgreSQL
3 *
4 * Copyright (c) 2006 Anton A. Patrushev, Orkney, Inc.
5 *
6 * This program is free software; you can redistribute it and/or modify
7 * it under the terms of the GNU General Public License as published by
8 * the Free Software Foundation; either version 2 of the License, or
9 * (at your option) any later version.
10 *
11 * This program is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 * GNU General Public License for more details.
15 *
16 * You should have received a copy of the GNU General Public License
17 * along with this program; if not, write to the Free Software
18 * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
19 *
20 */
21
22#include "tsp.h"
23
24#include "postgres.h"
25#include "executor/spi.h"
26#include "funcapi.h"
27
28#include "string.h"
29#include "math.h"
30
31#include "fmgr.h"
32
/* Emit the module magic block on servers whose headers provide the macro. */
#ifdef PG_MODULE_MAGIC
PG_MODULE_MAGIC;
#endif


// ------------------------------------------------------------------------

/*
 * Define PROFILE to have profiling enabled.
 *
 * With PROFILE defined, profstart(t) stores the current time in the
 * struct timeval t, and profstop(name, t) reports via elog(NOTICE) the
 * microseconds elapsed since then.  Without it both macros are no-ops.
 *
 * Each macro expands to a bare do { ... } while (0) with no trailing
 * semicolon, so an invocation behaves like a single statement
 * (e.g. it is safe as the body of an if/else).
 */
//#define PROFILE

#ifdef PROFILE
#include <sys/time.h>

struct timeval prof_tsp, prof_store, prof_extract, prof_total;
long proftime[5];
long profipts1, profipts2, profopts;

#define profstart(x) do { gettimeofday(&(x), NULL); } while (0)

/*
 * Seconds and microseconds are subtracted separately before scaling so
 * the intermediate value stays small instead of scaling absolute epoch
 * seconds (which overflows a 32-bit long).
 */
#define profstop(n, x) do {                                             \
    struct timeval _profstop;                                           \
    long _proftime;                                                     \
    gettimeofday(&_profstop, NULL);                                     \
    _proftime = (long) (_profstop.tv_sec - (x).tv_sec) * 1000000L       \
              + (long) (_profstop.tv_usec - (x).tv_usec);               \
    elog(NOTICE, "PRF(%s) %ld (%f ms)", (n), _proftime,                 \
         _proftime / 1000.0);                                           \
  } while (0)

#else

#define profstart(x) do { } while (0)
#define profstop(n, x) do { } while (0)

#endif // PROFILE
69
70// ------------------------------------------------------------------------
71
72Datum tsp(PG_FUNCTION_ARGS);
73
74#undef DEBUG
75//#define DEBUG 1
76
77#ifdef DEBUG
78#define DBG(format, arg...)                     \
79    elog(NOTICE, format , ## arg)
80#else
81#define DBG(format, arg...) do { ; } while (0)
82#endif
83
84// The number of tuples to fetch from the SPI cursor at each iteration
85#define TUPLIMIT 1000
86
87// Apologies for using fixed-length arrays.  But this is an example, not
88// production code ;)
89//#define MAX_TOWNS 40
90
91float DISTANCE[MAX_TOWNS][MAX_TOWNS];
92float x[MAX_TOWNS],y[MAX_TOWNS];
93int total_tuples;
94
95
96static char *
97text2char(text *in)
98{
99  char *out = (char*)palloc(VARSIZE(in));
100
101  memcpy(out, VARDATA(in), VARSIZE(in) - VARHDRSZ);
102  out[VARSIZE(in) - VARHDRSZ] = '\0';
103  return out;
104}
105
106static int
107finish(int code, int ret)
108{
109  code = SPI_finish();
110  if (code  != SPI_OK_FINISH )
111  {
112    elog(ERROR,"couldn't disconnect from SPI");
113    return -1 ;
114  }
115  return ret;
116}
117                 
118
/*
 * Attribute numbers of the columns the user's query must return, as
 * obtained from SPI_fnumber() and later passed to SPI_getbinval().  They
 * are column positions, not coordinates, so they are plain ints.
 * solve_tsp uses -1 to mean "not looked up yet".
 */
typedef struct point_columns
{
  int id;
  int x;
  int y;
} point_columns_t;
125
126
127static int
128fetch_point_columns(SPITupleTable *tuptable, point_columns_t *point_columns)
129{
130  DBG("Fetching point");
131
132  point_columns->id = SPI_fnumber(SPI_tuptable->tupdesc, "source_id");
133  point_columns->x = SPI_fnumber(SPI_tuptable->tupdesc, "x");
134  point_columns->y = SPI_fnumber(SPI_tuptable->tupdesc, "y");
135  if (point_columns->id == SPI_ERROR_NOATTRIBUTE ||
136      point_columns->x == SPI_ERROR_NOATTRIBUTE ||
137      point_columns->y == SPI_ERROR_NOATTRIBUTE)
138    {
139      elog(ERROR, "Error, query must return columns "
140           "'source_id', 'x' and 'y'");
141      return -1;
142    }
143   
144  DBG("* Point %i [%f, %f]", point_columns->id, point_columns->x,
145      point_columns->y);
146
147  return 0;
148}
149
150static void
151fetch_point(HeapTuple *tuple, TupleDesc *tupdesc,
152            point_columns_t *point_columns, point_t *point)
153{
154  Datum binval;
155  bool isnull;
156
157  DBG("inside fetch_point\n");
158
159  binval = SPI_getbinval(*tuple, *tupdesc, point_columns->id, &isnull);
160  DBG("Got id\n");
161
162  if (isnull)
163    elog(ERROR, "id contains a null value");
164
165  point->id = DatumGetInt32(binval);
166
167  DBG("id = %i\n", point->id);
168
169  binval = SPI_getbinval(*tuple, *tupdesc, point_columns->x, &isnull);
170  if (isnull)
171    elog(ERROR, "x contains a null value");
172
173  point->x = DatumGetFloat8(binval);
174
175  DBG("x = %f\n", point->x);
176
177  binval = SPI_getbinval(*tuple, *tupdesc, point_columns->y, &isnull);
178
179  if (isnull)
180    elog(ERROR, "y contains a null value");
181
182  point->y = DatumGetFloat8(binval);
183
184  DBG("y = %f\n", point->y);
185}
186
187
188static int solve_tsp(char* sql, char* p_ids,
189                     int source, path_element_t* path)
190{
191  int SPIcode;
192  void *SPIplan;
193  Portal SPIportal;
194  bool moredata = TRUE;
195  int ntuples;
196
197  //todo replace path (vector of path_element_t) with this array
198  int ids[MAX_TOWNS];
199
200  point_t *points=NULL;
201  point_columns_t point_columns = {id: -1, x: -1, y:-1};
202
203  char *err_msg = NULL;
204  int ret = -1;
205   
206  char *p;
207  int   z = 0;
208 
209  int    tt, cc;
210  double dx, dy;
211  float  fit=0.0;
212
213  DBG("inside tsp\n");
214
215  //int total_tuples = 0;
216  total_tuples = 0;
217
218  p = strtok(p_ids, ",");
219  while(p != NULL)
220    {
221      //      ((path_element_t*)path)[z].vertex_id = atoi(p);
222      ids[z]=atoi(p);
223      p = strtok(NULL, ",");
224      z++;
225      if(z >= MAX_TOWNS)
226      {
227        elog(ERROR, "Number of points exeeds max number.");
228        break;
229      }
230    }
231   
232  DBG("ZZZ %i\n",z);
233  DBG("start tsp\n");
234       
235  SPIcode = SPI_connect();
236
237  if (SPIcode  != SPI_OK_CONNECT)
238    {
239      elog(ERROR, "tsp: couldn't open a connection to SPI");
240      return -1;
241    }
242
243  SPIplan = SPI_prepare(sql, 0, NULL);
244
245  if (SPIplan  == NULL)
246    {
247      elog(ERROR, "tsp: couldn't create query plan via SPI");
248      return -1;
249    }
250
251  if ((SPIportal = SPI_cursor_open(NULL, SPIplan, NULL, NULL, true)) == NULL)
252    {
253      elog(ERROR, "tsp: SPI_cursor_open('%s') returns NULL", sql);
254      return -1;
255    }
256   
257  DBG("Query: %s\n",sql);
258  DBG("Query executed\n");
259
260  while (moredata == TRUE)
261    {
262      SPI_cursor_fetch(SPIportal, TRUE, TUPLIMIT);
263
264      if (point_columns.id == -1)
265        {
266          if (fetch_point_columns(SPI_tuptable, &point_columns) == -1)
267            return finish(SPIcode, ret);
268        }
269
270      ntuples = SPI_processed;
271
272      total_tuples += ntuples;
273
274      DBG("Tuples: %i\n", total_tuples);
275
276      if (!points)
277        points = palloc(total_tuples * sizeof(point_t));
278      else
279        points = repalloc(points, total_tuples * sizeof(point_t));
280                                       
281      if (points == NULL)
282        {
283          elog(ERROR, "Out of memory");
284          return finish(SPIcode, ret);
285        }
286
287      if (ntuples > 0)
288        {
289          int t;
290          SPITupleTable *tuptable = SPI_tuptable;
291          TupleDesc tupdesc = SPI_tuptable->tupdesc;
292
293          DBG("Got tuple desc\n");
294               
295          for (t = 0; t < ntuples; t++)
296            {
297              HeapTuple tuple = tuptable->vals[t];
298              DBG("Before point fetched\n");
299              fetch_point(&tuple, &tupdesc, &point_columns,
300                          &points[total_tuples - ntuples + t]);
301              DBG("Point fetched\n");
302            }
303
304          SPI_freetuptable(tuptable);
305        }
306      else
307        {
308          moredata = FALSE;
309        }                                       
310    }
311
312 
313  DBG("Calling TSP\n");
314       
315  profstop("extract", prof_extract);
316  profstart(prof_tsp);
317
318  DBG("Total tuples: %i\n", total_tuples);
319
320  for(tt=0;tt<total_tuples;++tt)
321    {
322      //((path_element_t*)path)[tt].vertex_id = points[tt].id;
323      ids[tt] = points[tt].id;
324      x[tt] = points[tt].x;
325      y[tt] = points[tt].y;
326 
327      DBG("Point at %i: %i [%f, %f]\n",  tt, ids[tt], x[tt], y[tt]);
328           
329      // ((path_element_t*)path)[tt].vertex_id, x[tt], y[tt]);
330
331      for(cc=0;cc<total_tuples;++cc)
332        {
333          dx=x[tt]-x[cc]; dy=y[tt]-y[cc];
334          DISTANCE[tt][cc] = DISTANCE[cc][tt] = sqrt(dx*dx+dy*dy);
335        }
336    }
337
338  DBG("DISTANCE counted\n");
339  pfree(points);
340   
341  //ret = find_tsp_solution(total_tuples, DISTANCE,
342  //   path, source, &fit, &err_msg);
343
344  ret = find_tsp_solution(total_tuples, DISTANCE, ids,
345                          source, &fit, err_msg);
346
347  for(tt=0;tt<total_tuples;++tt)
348    {
349      ((path_element_t*)path)[tt].vertex_id = ids[tt];
350    }
351   
352  DBG("TSP solved!\n");
353  DBG("Score: %f\n", fit);
354
355  profstop("tsp", prof_tsp);
356  profstart(prof_store);
357
358  DBG("Profile changed and ret is %i", ret);
359
360  if (ret < 0)
361    {
362      //elog(ERROR, "Error computing path: %s", err_msg);
363      ereport(ERROR, (errcode(ERRCODE_E_R_E_CONTAINING_SQL_NOT_PERMITTED), errmsg("Error computing path: %s", err_msg)));
364    }
365
366  return finish(SPIcode, ret);   
367}
368
369PG_FUNCTION_INFO_V1(tsp);
370Datum
371tsp(PG_FUNCTION_ARGS)
372{
373  FuncCallContext     *funcctx;
374  int                  call_cntr;
375  int                  max_calls;
376  TupleDesc            tuple_desc;
377  path_element_t        *path;
378   
379  /* stuff done only on the first call of the function */
380  if (SRF_IS_FIRSTCALL())
381    {
382      MemoryContext   oldcontext;
383      //int path_count;
384      int ret=-1;
385
386      // XXX profiling messages are not thread safe
387      profstart(prof_total);
388      profstart(prof_extract);
389
390      /* create a function context for cross-call persistence */
391      funcctx = SRF_FIRSTCALL_INIT();
392
393      /* switch to memory context appropriate for multiple function calls */
394      oldcontext = MemoryContextSwitchTo(funcctx->multi_call_memory_ctx);
395       
396
397      path = (path_element_t *)palloc(sizeof(path_element_t)*MAX_TOWNS);
398
399
400      ret = solve_tsp(text2char(PG_GETARG_TEXT_P(0)),
401                      text2char(PG_GETARG_TEXT_P(1)),
402                      PG_GETARG_INT32(2),
403                      path);
404
405      /* total number of tuples to be returned */
406      DBG("Counting tuples number\n");
407
408      funcctx->max_calls = total_tuples;
409
410      funcctx->user_fctx = path;
411
412      funcctx->tuple_desc = BlessTupleDesc(
413                              RelationNameGetTupleDesc("path_result"));
414      MemoryContextSwitchTo(oldcontext);
415    }
416
417  /* stuff done on every call of the function */
418  funcctx = SRF_PERCALL_SETUP();
419
420  call_cntr = funcctx->call_cntr;
421  max_calls = funcctx->max_calls;
422  tuple_desc = funcctx->tuple_desc;
423
424  path = (path_element_t *)funcctx->user_fctx;
425
426  DBG("Trying to allocate some memory\n");
427  DBG("call_cntr = %i, max_calls = %i\n", call_cntr, max_calls);
428
429  if (call_cntr < max_calls)    /* do when there is more left to send */
430    {
431      HeapTuple    tuple;
432      Datum        result;
433      Datum *values;
434      char* nulls;
435
436      /* This will work for some compilers. If it crashes with segfault, try to change the following block with this one   
437
438      values = palloc(4 * sizeof(Datum));
439      nulls = palloc(4 * sizeof(char));
440
441      values[0] = call_cntr;
442      nulls[0] = ' ';
443      values[1] = Int32GetDatum(path[call_cntr].vertex_id);
444      nulls[1] = ' ';
445      values[2] = Int32GetDatum(path[call_cntr].edge_id);
446      nulls[2] = ' ';
447      values[3] = Float8GetDatum(path[call_cntr].cost);
448      nulls[3] = ' ';
449      */
450   
451      values = palloc(3 * sizeof(Datum));
452      nulls = palloc(3 * sizeof(char));
453
454      values[0] = Int32GetDatum(path[call_cntr].vertex_id);
455      nulls[0] = ' ';
456      values[1] = Int32GetDatum(path[call_cntr].edge_id);
457      nulls[1] = ' ';
458      values[2] = Float8GetDatum(path[call_cntr].cost);
459      nulls[2] = ' ';
460     
461      DBG("Heap making\n");
462
463      tuple = heap_formtuple(tuple_desc, values, nulls);
464
465      DBG("Datum making\n");
466
467      /* make the tuple into a datum */
468      result = HeapTupleGetDatum(tuple);
469
470      DBG("VAL: %i\n, %i", values[0], result);
471      DBG("Trying to free some memory\n");
472   
473      /* clean up (this is not really necessary) */
474      pfree(values);
475      pfree(nulls);
476       
477
478      SRF_RETURN_NEXT(funcctx, result);
479    }
480  else    /* do when there is no more left */
481    {
482      DBG("Ending function\n");
483      profstop("store", prof_store);
484      profstop("total", prof_total);
485      DBG("Profiles stopped\n");
486
487      pfree(path);
488
489      DBG("Path cleared\n");
490
491      SRF_RETURN_DONE(funcctx);
492    }
493}
Note: See TracBrowser for help on using the browser.